带有ranger包的r-fit_resamples失败



尝试使用交叉折叠重采样,并从ranger包中拟合一个随机林。不重新采样的拟合有效,但一旦我尝试重新采样拟合,它就会失败,并出现以下错误。

考虑以下df

df<-structure(list(a = c(1379405931, 732812609, 18614430, 1961678341, 
2362202769, 55687714, 72044715, 236503454, 61988734, 2524712675, 
98081131, 1366513385, 48203585, 697397991, 28132854), b = structure(c(1L, 
6L, 2L, 5L, 7L, 8L, 8L, 1L, 3L, 4L, 3L, 5L, 7L, 2L, 2L), .Label = c("CA", 
"IA", "IL", "LA", "MA", "MN", "TX", "WI"), class = "factor"), 
c = structure(c(2L, 2L, 1L, 2L, 2L, 1L, 1L, 2L, 1L, 2L, 1L, 
2L, 2L, 2L, 1L), .Label = c("R", "U"), class = "factor"), 
d = structure(c(3L, 3L, 1L, 3L, 3L, 1L, 1L, 3L, 1L, 3L, 1L, 
3L, 2L, 3L, 1L), .Label = c("CAH", "LTCH", "STH"), class = "factor"), 
e = structure(c(3L, 2L, 3L, 3L, 1L, 3L, 3L, 3L, 2L, 4L, 2L, 
2L, 3L, 3L, 3L), .Label = c("cancer", "general long term", 
"psychiatric", "rehabilitation"), class = "factor")), row.names = c(NA, 
-15L), class = c("tbl_df", "tbl", "data.frame"))

以下简单的配合工作没有问题

library(tidymodels)
library(ranger)
rf_spec <- rand_forest(mode = 'regression') %>% 
set_engine('ranger')

rf_spec %>% 
fit(a ~. , data = df)

但只要我想通过运行交叉验证

rf_folds <- vfold_cv(df, strata = c)
fit_resamples(a ~ . ,
rf_spec,
rf_folds)

跟随错误

model:解析中出错。formula(formula,data,env=parent.frame(((:错误:公式接口中的列名非法。修复列名称或在ranger中使用替代接口。

上面的注释是正确的,这里的问题来源是因子列中的空格。重新采样的函数和普通旧拟合的函数目前处理这一问题的方式不同,我们正在积极研究如何为用户解决这一问题。谢谢你的耐心!

同时,我建议设置一个简单的workflow()和一个recipe(),它们将一起为您处理所有必要的伪变量munging。

library(tidymodels)
rf_spec <- rand_forest(mode = "regression") %>% 
set_engine("ranger")
rf_wf <- workflow() %>%
add_model(rf_spec) %>%
add_recipe(recipe(a ~ ., data = df))

fit(rf_wf, data = df)
#> ══ Workflow [trained] ═══════════════════════════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: rand_forest()
#> 
#> ── Preprocessor ─────────────────────────────────────────────────────────────────────────────────────────────────
#> 0 Recipe Steps
#> 
#> ── Model ────────────────────────────────────────────────────────────────────────────────────────────────────────
#> Ranger result
#> 
#> Call:
#>  ranger::ranger(formula = formula, data = data, num.threads = 1,      verbose = FALSE, seed = sample.int(10^5, 1)) 
#> 
#> Type:                             Regression 
#> Number of trees:                  500 
#> Sample size:                      15 
#> Number of independent variables:  4 
#> Mtry:                             2 
#> Target node size:                 5 
#> Variable importance mode:         none 
#> Splitrule:                        variance 
#> OOB prediction error (MSE):       4.7042e+17 
#> R squared (OOB):                  0.4341146
rf_folds <- vfold_cv(df, strata = c)
fit_resamples(rf_wf,
rf_folds)
#> #  10-fold cross-validation using stratification 
#> # A tibble: 9 x 4
#>   splits         id    .metrics         .notes          
#>   <list>         <chr> <list>           <list>          
#> 1 <split [13/2]> Fold1 <tibble [2 × 3]> <tibble [0 × 1]>
#> 2 <split [13/2]> Fold2 <tibble [2 × 3]> <tibble [0 × 1]>
#> 3 <split [13/2]> Fold3 <tibble [2 × 3]> <tibble [0 × 1]>
#> 4 <split [13/2]> Fold4 <tibble [2 × 3]> <tibble [0 × 1]>
#> 5 <split [13/2]> Fold5 <tibble [2 × 3]> <tibble [0 × 1]>
#> 6 <split [13/2]> Fold6 <tibble [2 × 3]> <tibble [0 × 1]>
#> 7 <split [14/1]> Fold7 <tibble [2 × 3]> <tibble [0 × 1]>
#> 8 <split [14/1]> Fold8 <tibble [2 × 3]> <tibble [0 × 1]>
#> 9 <split [14/1]> Fold9 <tibble [2 × 3]> <tibble [0 × 1]>

由reprex包(v0.3.0(于2020-03-20创建

Julia和我打赌,这样她就得到了因果报应。我得到了同样的答案(我做她做的事,但速度较慢(:

这是一个类型的错误,我们一直在努力使其不出错。这很复杂。让我解释一下。

ranger是少数几个公式方法不创建伪变量的R包之一(因为它不需要它们(。

tune中的基础结构使用workflows包来处理公式,然后将结果数据移交给ranger默认情况下,工作流确实会创建伪变量,并且由于某些因子级别不是有效的R列名(例如"general long term"(,ranger()会引发错误。

(我知道你没有使用工作流,但这就是幕后发生的事情(。

我们正在做最好的事情,因为大多数用户不知道许多基于树的模型包不会产生伪变量。更复杂的是,parsnip(还(不使用工作流,也没有给你一个错误。

目前的解决方案

使用简单的配方而不是配方:

library(tidymodels)
#> ── Attaching packages ─────────────────────────────────────────────────────────────────────────────────── tidymodels 0.1.0 ──
#> ✓ broom     0.5.4      ✓ recipes   0.1.10
#> ✓ dials     0.0.4      ✓ rsample   0.0.5 
#> ✓ dplyr     0.8.5      ✓ tibble    2.1.3 
#> ✓ ggplot2   3.3.0      ✓ tune      0.0.1 
#> ✓ infer     0.5.1      ✓ workflows 0.1.0 
#> ✓ parsnip   0.0.5      ✓ yardstick 0.0.5 
#> ✓ purrr     0.3.3
#> ── Conflicts ────────────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard()  masks scales::discard()
#> x dplyr::filter()   masks stats::filter()
#> x dplyr::lag()      masks stats::lag()
#> x ggplot2::margin() masks dials::margin()
#> x recipes::step()   masks stats::step()
df<-structure(list(a = c(1379405931, 732812609, 18614430, 1961678341, 
2362202769, 55687714, 72044715, 236503454, 61988734, 2524712675, 
98081131, 1366513385, 48203585, 697397991, 28132854), b = structure(c(1L, 
6L, 2L, 5L, 7L, 8L, 8L, 1L, 3L, 4L, 3L, 5L, 7L, 2L, 2L), .Label = c("CA", 
"IA", "IL", "LA", "MA", "MN", "TX", "WI"), class = "factor"), 
c = structure(c(2L, 2L, 1L, 2L, 2L, 1L, 1L, 2L, 1L, 2L, 1L, 
2L, 2L, 2L, 1L), .Label = c("R", "U"), class = "factor"), 
d = structure(c(3L, 3L, 1L, 3L, 3L, 1L, 1L, 3L, 1L, 3L, 1L, 
3L, 2L, 3L, 1L), .Label = c("CAH", "LTCH", "STH"), class = "factor"), 
e = structure(c(3L, 2L, 3L, 3L, 1L, 3L, 3L, 3L, 2L, 4L, 2L, 
2L, 3L, 3L, 3L), .Label = c("cancer", "general long term", 
"psychiatric", "rehabilitation"), class = "factor")), row.names = c(NA, 
-15L), class = c("tbl_df", "tbl", "data.frame"))

library(tidymodels)
library(ranger)
rf_spec <- rand_forest(mode = 'regression') %>% 
set_engine('ranger')
rf_folds <- vfold_cv(df, strata = c)
fit_resamples(recipe(a ~ ., data = df),  rf_spec, rf_folds)
#> #  10-fold cross-validation using stratification 
#> # A tibble: 9 x 4
#>   splits         id    .metrics         .notes          
#> * <list>         <chr> <list>           <list>          
#> 1 <split [13/2]> Fold1 <tibble [2 × 3]> <tibble [0 × 1]>
#> 2 <split [13/2]> Fold2 <tibble [2 × 3]> <tibble [0 × 1]>
#> 3 <split [13/2]> Fold3 <tibble [2 × 3]> <tibble [0 × 1]>
#> 4 <split [13/2]> Fold4 <tibble [2 × 3]> <tibble [0 × 1]>
#> 5 <split [13/2]> Fold5 <tibble [2 × 3]> <tibble [0 × 1]>
#> 6 <split [13/2]> Fold6 <tibble [2 × 3]> <tibble [0 × 1]>
#> 7 <split [14/1]> Fold7 <tibble [2 × 3]> <tibble [0 × 1]>
#> 8 <split [14/1]> Fold8 <tibble [2 × 3]> <tibble [0 × 1]>
#> 9 <split [14/1]> Fold9 <tibble [2 × 3]> <tibble [0 × 1]>
# FYI `tune` 0.0.2 will require a different argument order: 
# rf_spec %>% fit_resamples(recipe(a ~ ., data = df), rf_folds)

由reprex包(v0.3.0(于2020-03-20创建

相关内容

  • 没有找到相关文章

最新更新