R - How to extract predictors from a parsnip fit object



I have the following prediction model:

library(tidymodels)
data(ames)
set.seed(4595)
data_split <- initial_split(ames, strata = "Sale_Price", prop = 0.75)
ames_train <- training(data_split)
ames_test  <- testing(data_split)
rec <- recipe(Sale_Price ~ ., data = ames_train)
norm_trans <- rec %>%
  step_zv(all_predictors()) %>%
  step_nzv(all_predictors()) %>%
  step_corr(all_numeric_predictors(), threshold = 0.1)
# Preprocessing
norm_obj <- prep(norm_trans, training = ames_train)
rf_ames_train <- bake(norm_obj, ames_train) %>%
  dplyr::select(Sale_Price, everything()) %>%
  as.data.frame()
dim(rf_ames_train)

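# Note: x still contains Sale_Price at this point, so the outcome is also
# passed to the model as a predictor.
rf_xy_fit <- rand_forest(mode = "regression") %>%
  set_engine("ranger") %>%
  fit_xy(
    x = rf_ames_train,
    y = log10(rf_ames_train$Sale_Price)
  )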

Note that after the preprocessing steps, the number of features drops from 74 to 33.

dim(rf_ames_train )
# 33

Currently, I have to pass the predictors to my function explicitly:

preds <- colnames(rf_ames_train)
my_pred_function <- function(fit = NULL, test_data = NULL, predictors = NULL) {
  # Use the test_data argument rather than the global ames_test
  test_results <- test_data %>%
    select(Sale_Price) %>%
    mutate(Sale_Price = log10(Sale_Price)) %>%
    bind_cols(
      predict(fit, new_data = test_data[, predictors])
    )
  test_results
}
my_pred_function(fit = rf_xy_fit, test_data = ames_test, predictors = preds)

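This is the predictors = preds argument in the function call above.

In practice, I have to save rf_xy_fit and preds as two RDS files and then read them back in. This is error-prone and cumbersome.

I would like to avoid passing the predictors explicitly. Is there a way to extract them directly from rf_xy_fit?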

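In this case, you would benefit from using the workflows package. It lets you bundle the preprocessing code together with the model-fitting code.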

library(tidymodels)
data(ames)
set.seed(4595)
# Log-transform the outcome before splitting so that the transformation is applied to both the training and testing sets.
ames <- ames %>%
  mutate(Sale_Price = log10(Sale_Price))

data_split <- initial_split(ames, strata = "Sale_Price", prop = 0.75)
ames_train <- training(data_split)
ames_test  <- testing(data_split)
rec <- recipe(Sale_Price ~ ., data = ames_train)
norm_trans <- rec %>%
  step_zv(all_predictors()) %>%
  step_nzv(all_predictors()) %>%
  step_corr(all_numeric_predictors(), threshold = 0.1)
rf_spec <- rand_forest(mode = "regression") %>%
  set_engine("ranger")
rf_wf <- workflow() %>%
  add_recipe(norm_trans) %>%
  add_model(rf_spec)
rf_fit <- fit(rf_wf, ames_train)
predict(rf_fit, new_data = ames_train)
#> # A tibble: 2,197 × 1
#>    .pred
#>    <dbl>
#>  1  5.09
#>  2  5.12
#>  3  5.01
#>  4  4.99
#>  5  5.12
#>  6  5.07
#>  7  4.90
#>  8  5.09
#>  9  5.13
#> 10  5.08
#> # … with 2,187 more rows

Created on 2022-11-21 with reprex v2.0.2
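
For the two-RDS-file problem described in the question, a fitted workflow can be stored as a single object, since it carries both the prepped recipe and the model. The following is a minimal sketch, not a definitive recipe: it assumes the ranger engine, whose fit is a plain R object that survives saveRDS() (engines that hold external pointers may need extra care), and fit_path is a hypothetical temporary file chosen only for illustration. The workflow is rebuilt here so that the block runs on its own.

```r
library(tidymodels)

data(ames)
set.seed(4595)
ames <- ames %>%
  mutate(Sale_Price = log10(Sale_Price))
data_split <- initial_split(ames, strata = "Sale_Price", prop = 0.75)
ames_train <- training(data_split)
ames_test  <- testing(data_split)

rf_wf <- workflow() %>%
  add_recipe(
    recipe(Sale_Price ~ ., data = ames_train) %>%
      step_zv(all_predictors()) %>%
      step_nzv(all_predictors()) %>%
      step_corr(all_numeric_predictors(), threshold = 0.1)
  ) %>%
  add_model(rand_forest(mode = "regression") %>% set_engine("ranger"))

rf_fit <- fit(rf_wf, ames_train)

# One file holds both the preprocessing and the model.
fit_path <- tempfile(fileext = ".rds")  # hypothetical location
saveRDS(rf_fit, fit_path)

# Later, possibly in a new session with tidymodels and ranger installed:
rf_loaded <- readRDS(fit_path)

# The raw test data is passed as-is; the workflow applies the recipe itself,
# so no separate vector of predictor names is needed.
test_results <- ames_test %>%
  select(Sale_Price) %>%
  bind_cols(predict(rf_loaded, new_data = ames_test))
test_results
```

With this layout, the prediction function from the question no longer needs a predictors argument, because column selection happens inside the workflow.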

To add to Emil's answer, based on your comment…

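Keep in mind that most R modeling functions will expect the original set of features, even if some of them are not used at all. This is a by-product of R's formula/model.matrix() machinery.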
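
As a small illustration of this point, here is a base R sketch using lm() on mtcars. The column wt2 is a made-up variable added only for this example: it is an exact multiple of wt, so lm() cannot estimate a coefficient for it, yet predict() still asks for it because the formula mentions it.

```r
d <- mtcars[, c("mpg", "wt")]
d$wt2 <- 2 * d$wt  # hypothetical column, perfectly collinear with wt

fit <- lm(mpg ~ wt + wt2, data = d)

# wt2 is aliased, so its coefficient is NA and it plays no part in the fit.
coef(fit)

# Prediction still fails when wt2 is absent from the new data, because
# model.frame() evaluates every variable named in the formula.
res <- tryCatch(
  predict(fit, newdata = d["wt"]),
  error = function(e) e
)
inherits(res, "error")
conditionMessage(res)
```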

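For recipes, it depends on which steps you use.

You can refit the final model without the unused features, but you may not get exactly the same model. In many cases, the process of arriving at the feature subset depends on how many features were passed in initially.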
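
To see which columns a recipe starts from and which ones survive its steps, the prepped recipe can be inspected directly. This is a sketch under the assumption that the recipe matches the one used in this thread; the names required_cols, kept_cols, and dropped_cols are introduced here for illustration only.

```r
library(tidymodels)

data(ames)
set.seed(4595)
ames <- ames %>%
  mutate(Sale_Price = log10(Sale_Price))
ames_train <- training(initial_split(ames, strata = "Sale_Price", prop = 0.75))

prepped <- recipe(Sale_Price ~ ., data = ames_train) %>%
  step_zv(all_predictors()) %>%
  step_nzv(all_predictors()) %>%
  step_corr(all_numeric_predictors(), threshold = 0.1) %>%
  prep(training = ames_train)

# Variables as declared to the recipe vs. variables left after all steps.
vars_in  <- summary(prepped, original = TRUE)
vars_out <- summary(prepped)

required_cols <- vars_in$variable[vars_in$role == "predictor"]
kept_cols     <- vars_out$variable[vars_out$role == "predictor"]
dropped_cols  <- setdiff(required_cols, kept_cols)

length(required_cols)
length(kept_cols)

# tidy() lists the steps; with `number` it shows what one step removed
# (here the third step, step_corr).
tidy(prepped)
tidy(prepped, number = 3)
```

Because each filter step sees only what the previous steps left behind, rerunning the same recipe on kept_cols alone is not guaranteed to reproduce the same selection.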

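I'm working on a tidymodels API for this, but caret has a function that returns the list of predictors that were actually used by the model. See the example: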

library(caret)
#> Loading required package: ggplot2
#> Loading required package: lattice
library(tidymodels)

tidymodels_prefer()
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)

data(ames)
set.seed(4595)
ames <- ames %>%
  mutate(Sale_Price = log10(Sale_Price))
data_split <- initial_split(ames, strata = "Sale_Price", prop = 0.75)
ames_train <- training(data_split)
ames_test  <- testing(data_split)
rec <- recipe(Sale_Price ~ ., data = ames_train)
norm_trans <- rec %>%
  step_zv(all_predictors()) %>%
  step_nzv(all_predictors()) %>%
  step_corr(all_numeric_predictors(), threshold = 0.1)
rf_spec <- rand_forest(mode = "regression") %>%
  set_engine("ranger")
rf_wf <- workflow() %>%
  add_recipe(norm_trans) %>%
  add_model(rf_spec)
rf_fit <- fit(rf_wf, ames_train)
# get predictor set:
rf_features <-
  rf_fit %>%
  extract_fit_engine() %>%
  predictors()  #<- the caret function
head(rf_features)
#> [1] "MS_SubClass"  "MS_Zoning"    "Lot_Frontage" "Lot_Shape"    "Lot_Config"  
#> [6] "Neighborhood"
# You get an error here:
ames_test %>%
  select(all_of(rf_features)) %>%
  predict(rf_fit, new_data = .)
#> Error in `validate_column_names()`:
#> ! The following required columns are missing: 'Lot_Area', 
#> 'Street', 'Alley', 'Land_Contour', 'Utilities', 'Land_Slope',
#> 'Condition_2', 'Year_Built', 'Year_Remod_Add', 'Roof_Matl', 
#> 'Mas_Vnr_Area', 'Bsmt_Cond', 'BsmtFin_SF_1', 'BsmtFin_Type_2', 
#> 'BsmtFin_SF_2', 'Bsmt_Unf_SF', 'Total_Bsmt_SF', 'Heating', 
#> 'First_Flr_SF', 'Second_Flr_SF', 'Gr_Liv_Area', 'Bsmt_Full_Bath', 
#> 'Full_Bath', 'Half_Bath', 'Bedroom_AbvGr', 'Kitchen_AbvGr', 
#> 'TotRms_AbvGrd', 'Functional', 'Fireplaces', 'Garage_Cars',
#> 'Garage_Area', 'Wood_Deck_SF', 'Open_Porch_SF', 'Enclosed_Porch',
#> 'Three_season_porch', 'Screen_Porch', 'Pool_Area', 'Pool_QC',
#> 'Misc_Feature', 'Misc_Val', 'Mo_Sold', 'Latitude'.

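Created on 2022-11-21 by the reprex package (v2.0.1)

This error comes from the workflows package, but the underlying modeling package would also raise an error.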
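
Both column sets can also be read off the fitted workflow itself. The sketch below is one possible approach and not the API mentioned above: it relies on the structure returned by extract_mold(), which exposes workflows/hardhat internals that could change between versions, and on the independent.variable.names element that ranger stores in its forest. The names required_cols, model_cols, and ranger_cols are introduced here for illustration.

```r
library(tidymodels)

data(ames)
set.seed(4595)
ames <- ames %>%
  mutate(Sale_Price = log10(Sale_Price))
data_split <- initial_split(ames, strata = "Sale_Price", prop = 0.75)
ames_train <- training(data_split)
ames_test  <- testing(data_split)

rf_fit <- workflow() %>%
  add_recipe(
    recipe(Sale_Price ~ ., data = ames_train) %>%
      step_zv(all_predictors()) %>%
      step_nzv(all_predictors()) %>%
      step_corr(all_numeric_predictors(), threshold = 0.1)
  ) %>%
  add_model(rand_forest(mode = "regression") %>% set_engine("ranger")) %>%
  fit(ames_train)

mold <- extract_mold(rf_fit)

# Raw columns the workflow checks for in new_data (assumed blueprint layout).
required_cols <- colnames(mold$blueprint$ptypes$predictors)

# Columns that remain after the recipe and are handed to ranger.
model_cols <- colnames(mold$predictors)

# Cross-check against the names stored inside the ranger object.
ranger_cols <- extract_fit_engine(rf_fit)$forest$independent.variable.names

# Prediction needs the raw columns, not just the ones the model kept.
preds <- predict(rf_fit, new_data = ames_test[, required_cols])
head(preds)
```

Here model_cols answers the question of which predictors the model saw, while required_cols is what new data has to contain for predict() to pass the column check.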
