R - How to save and later reload an H2O object built with parsnip/agua



I wrote the following script using the agua package from tidymodels:

library(tidymodels)
library(agua)
library(ggplot2)
theme_set(theme_bw())
h2o_start()
data(concrete)
set.seed(4595)
concrete_split <- initial_split(concrete, strata = compressive_strength)
concrete_train <- training(concrete_split)
concrete_test <- testing(concrete_split)
# run for a maximum of 120 seconds
auto_spec <-
  auto_ml() %>%
  set_engine("h2o", max_runtime_secs = 120, seed = 1) %>%
  set_mode("regression")
normalized_rec <-
  recipe(compressive_strength ~ ., data = concrete_train) %>%
  step_normalize(all_predictors())
auto_wflow <-
  workflow() %>%
  add_model(auto_spec) %>%
  add_recipe(normalized_rec)
auto_fit <- fit(auto_wflow, data = concrete_train)
saveRDS(auto_fit, file = "test.h2o.auto_fit.rds") #save the object
h2o_end()

I am trying to save the auto_fit object to a file. But when I try to read it back and use it to predict on the test data:

h2o_start()
auto_fit <- readRDS("test.h2o.auto_fit.rds")
predict(auto_fit, new_data = concrete_test)

I get an error:

Error in `h2o_get_model()`:
! Model id does not exist on the h2o server.

How should I do this?

The expected result is:

predict(auto_fit, new_data = concrete_test)
#> # A tibble: 260 × 1
#>    .pred
#>    <dbl>
#>  1  40.0
#>  2  43.0
#>  3  38.2
#>  4  55.7
#>  5  41.4
#>  6  28.1
#>  7  53.2
#>  8  34.5
#>  9  51.1
#> 10  37.9
#> # … with 250 more rows

Update

After following Simon Couch's suggestion:

auto_fit <- fit(auto_wflow, data = concrete_train)
auto_fit_bundle <- bundle(auto_fit)
saveRDS(auto_fit_bundle, file = "test.h2o.auto_fit.rds") #save the object
h2o_end()
# to reload
h2o_start()
auto_fit_bundle <- readRDS("test.h2o.auto_fit.rds")
auto_fit <- unbundle(auto_fit_bundle)
predict(auto_fit, new_data = concrete_test)
rank_results(auto_fit)

I get this error message:

Error in UseMethod("rank_results") : 
no applicable method for 'rank_results' applied to an object of class "c('H2ORegressionModel', 'H2OModel', 'Keyed')"

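Some model objects in R need their own native serialization methods to be saved to and reloaded from a file, and h2o objects (and the tidymodels objects that wrap them) are one example. The R object only holds an id that points at the model living in the H2O server's memory, so saveRDS() stores the id but not the model; once the server is shut down, that id no longer refers to anything, hence "Model id does not exist on the h2o server".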
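To illustrate what native serialization means for h2o, here is a minimal sketch that uses the h2o package directly, without agua or bundle. It assumes a local H2O cluster can be started with h2o.init(), and it fits a small h2o.glm() on the built-in mtcars data purely as a stand-in for the AutoML leader. h2o.saveModel() writes the model itself (not just the handle) to disk, h2o.removeAll() simulates the server losing its state, and h2o.loadModel() registers the model on the server again. This is only an illustration of the idea, not necessarily how bundle implements its h2o methods internally.

```r
# Sketch: h2o's own save/load, illustrating why saveRDS() on the handle is not enough.
library(h2o)

h2o.init()
cars_hf <- as.h2o(mtcars)

# The R object `fit` is only a handle; the model itself lives on the H2O server
fit <- h2o.glm(x = c("wt", "hp"), y = "mpg", training_frame = cars_hf)
fit@model_id  # the key the R object uses to find the model on the server

# Write the model to disk using h2o's native serializer
model_path <- h2o.saveModel(fit, path = tempdir(), force = TRUE)

# Wipe everything from the server, similar to ending the h2o session.
# Calling h2o.predict(fit, ...) after this would fail: the id no longer exists.
h2o.removeAll()

# Reload the model from disk; this re-registers it on the server
restored <- h2o.loadModel(model_path)

# The training frame was removed too, so upload fresh data for prediction
new_hf <- as.h2o(mtcars[1:3, ])
preds <- as.data.frame(h2o.predict(restored, new_hf))
print(preds)

h2o.shutdown(prompt = FALSE)
```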

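Posit's tidymodels and vetiver teams recently collaborated on bundle, a package that provides a consistent interface to these native serialization methods. Its documentation includes a page specifically on h2o.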

library(bundle)

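In short, you'll want to bundle() the object you're about to save, save it with the usual saveRDS(), and then, in your new session, readRDS() the file and unbundle() the loaded object. The output of unbundle() is a ready-to-use model object. :)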

# to save:
auto_fit <- fit(auto_wflow, data = concrete_train)
auto_fit_bundle <- bundle(auto_fit)
saveRDS(auto_fit_bundle, file = "test.h2o.auto_fit.rds") #save the object
h2o_end()
# to reload
h2o_start()
auto_fit_bundle <- readRDS("test.h2o.auto_fit.rds")
auto_fit <- unbundle(auto_fit_bundle)
predict(auto_fit, new_data = concrete_test)
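The reload half above assumes concrete_test and the packages are still available, which is not the case in a genuinely fresh R session. A standalone reload script might look like the sketch below. It assumes test.h2o.auto_fit.rds was written by the save step above, and that re-running initial_split() with the same seed and strata reproduces the original test set.

```r
# Sketch: run in a fresh R session after the save step has written the .rds file.
library(tidymodels)  # recipes, workflows, parsnip, rsample; modeldata provides `concrete`
library(agua)        # h2o engine for parsnip, plus h2o_start()/h2o_end()
library(bundle)      # unbundle()

# unbundle() needs a running H2O server to load the model into
h2o_start()

# Recreate the same test set: same seed and strata as in the training session
data(concrete)
set.seed(4595)
concrete_split <- initial_split(concrete, strata = compressive_strength)
concrete_test <- testing(concrete_split)

# Restore the bundled workflow and predict
auto_fit_bundle <- readRDS("test.h2o.auto_fit.rds")
auto_fit <- unbundle(auto_fit_bundle)
preds <- predict(auto_fit, new_data = concrete_test)
print(preds)

h2o_end()
```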
