我有以下预测模型:
library(tidymodels)
data(ames)
set.seed(4595)
data_split <- initial_split(ames, strata = "Sale_Price", prop = 0.75)
ames_train <- training(data_split)
ames_test <- testing(data_split)
rec <- recipe(Sale_Price ~ ., data = ames_train)
norm_trans <- rec %>%
step_zv(all_predictors()) %>%
step_nzv(all_predictors()) %>%
step_corr(all_numeric_predictors(), threshold = 0.1)
# Preprocessing
norm_obj <- prep(norm_trans, training = ames_train)
rf_ames_train <- bake(norm_obj, ames_train) %>%
dplyr::select(Sale_Price, everything()) %>%
as.data.frame()
dim(rf_ames_train )
rf_xy_fit <- rand_forest(mode = "regression") %>%
set_engine("ranger") %>%
fit_xy(
x = rf_ames_train,
y = log10(rf_ames_train$Sale_Price)
)
注意,经过预处理步骤后,特征的数量从74个减少到33个。
dim(rf_ames_train )
# 33
目前,我必须显式地传递函数中的预测器:
preds <- colnames(rf_ames_train)
my_pred_function <- function (fit = NULL, test_data = NULL, predictors = NULL) {
test_results <- test_data %>%
select(Sale_Price) %>%
mutate(Sale_Price = log10(Sale_Price)) %>%
bind_cols(
predict(fit, new_data = ames_test[, predictors])
)
test_results
}
my_pred_function(fit = rf_xy_fit, test_data = ames_test, predictors = preds)
在上面的函数调用中显示为predictors = preds
。
在实践中,我必须将rf_xy_fit
和preds
保存为两个RDS文件,然后再次读取它们。这很容易出错,很麻烦。
我想绕过这个显式传递。有没有办法直接从rf_xy_fit
中提取出来?
在这种情况下,您将受益于使用工作流包。这允许您将预处理代码与模型拟合代码组合在一起
library(tidymodels)
data(ames)
set.seed(4595)
# Notice how I did log transformation before doing the splitting to assure that it is not on both testing and training data sets.
ames <- ames %>%
mutate(Sale_Price = log10(Sale_Price))
data_split <- initial_split(ames, strata = "Sale_Price", prop = 0.75)
ames_train <- training(data_split)
ames_test <- testing(data_split)
rec <- recipe(Sale_Price ~ ., data = ames_train)
norm_trans <- rec %>%
step_zv(all_predictors()) %>%
step_nzv(all_predictors()) %>%
step_corr(all_numeric_predictors(), threshold = 0.1)
rf_spec <- rand_forest(mode = "regression") %>%
set_engine("ranger")
rf_wf <- workflow() %>%
add_recipe(norm_trans) %>%
add_model(rf_spec)
rf_fit <- fit(rf_wf, ames_train)
predict(rf_fit, new_data = ames_train)
#> # A tibble: 2,197 × 1
#> .pred
#> <dbl>
#> 1 5.09
#> 2 5.12
#> 3 5.01
#> 4 4.99
#> 5 5.12
#> 6 5.07
#> 7 4.90
#> 8 5.09
#> 9 5.13
#> 10 5.08
#> # … with 2,187 more rows
创建于2022-11-21与reprex v2.0.2
根据你的评论补充email的回答…
请记住,大多数R建模函数将期望原始的特性集,即使其中一些根本没有使用。这是R的公式/model.matrix()
机器的副产品。
对于菜谱,这取决于您使用的步骤。
你可以在没有它们的情况下改装最终模型,但你可能不会得到完全相同的模型。在许多情况下,获得特征子集的过程取决于最初传递了多少。
我正在为此编写一个tidymodels api,但是插入符号有一个可以获得模型实际使用的预测器列表。参见示例:
library(caret)
#> Loading required package: ggplot2
#> Loading required package: lattice
library(tidymodels)
tidymodels_prefer()
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)
data(ames)
set.seed(4595)
ames <- ames %>%
mutate(Sale_Price = log10(Sale_Price))
data_split <- initial_split(ames, strata = "Sale_Price", prop = 0.75)
ames_train <- training(data_split)
ames_test <- testing(data_split)
rec <- recipe(Sale_Price ~ ., data = ames_train)
norm_trans <- rec %>%
step_zv(all_predictors()) %>%
step_nzv(all_predictors()) %>%
step_corr(all_numeric_predictors(), threshold = 0.1)
rf_spec <- rand_forest(mode = "regression") %>%
set_engine("ranger")
rf_wf <- workflow() %>%
add_recipe(norm_trans) %>%
add_model(rf_spec)
rf_fit <- fit(rf_wf, ames_train)
# get predictor set:
rf_features <-
rf_fit %>%
extract_fit_engine() %>%
predictors() #<- the caret funciton
head(rf_features)
#> [1] "MS_SubClass" "MS_Zoning" "Lot_Frontage" "Lot_Shape" "Lot_Config"
#> [6] "Neighborhood"
# You get an error here:
ames_test %>%
select(all_of(rf_features)) %>%
predict(rf_fit, new_data = .)
#> Error in `validate_column_names()`:
#> ! The following required columns are missing: 'Lot_Area',
#> 'Street', 'Alley', 'Land_Contour', 'Utilities', 'Land_Slope',
#> 'Condition_2', 'Year_Built', 'Year_Remod_Add', 'Roof_Matl',
#> 'Mas_Vnr_Area', 'Bsmt_Cond', 'BsmtFin_SF_1', 'BsmtFin_Type_2',
#> 'BsmtFin_SF_2', 'Bsmt_Unf_SF', 'Total_Bsmt_SF', 'Heating',
#> 'First_Flr_SF', 'Second_Flr_SF', 'Gr_Liv_Area', 'Bsmt_Full_Bath',
#> 'Full_Bath', 'Half_Bath', 'Bedroom_AbvGr', 'Kitchen_AbvGr',
#> 'TotRms_AbvGrd', 'Functional', 'Fireplaces', 'Garage_Cars',
#> 'Garage_Area', 'Wood_Deck_SF', 'Open_Porch_SF', 'Enclosed_Porch',
#> 'Three_season_porch', 'Screen_Porch', 'Pool_Area', 'Pool_QC',
#> 'Misc_Feature', 'Misc_Val', 'Mo_Sold', 'Latitude'.
由reprex包(v2.0.1)创建于2022-11-21
这个错误来自工作流包,但是底层的建模包也会出错。