我定义了一个自定义度量,它在评估标准度量(如 rmse)之前,先使用一个外部函数转换 prediction$data。
如果我在没有并行化的情况下调整参数,一切都会顺利进行;但是如果我启动并行化会话,它似乎就找不到这个外部函数了,尽管该函数已在全局环境中声明。
library(compiler)
library(mlr)
library(parallelMap)
library(parallel)
# Define the back-transformation applied to truth/response values before a
# measure is computed. Here it simply squares its input.
inverse_fun <- function(x) {
  x^2
}
# Wrap the function with Vectorize() so it is applied element by element.
inverse_fun <- Vectorize(inverse_fun)
# Byte-compile the wrapper. TRUE is spelled out because `T` is an ordinary
# variable that can be reassigned.
inverse_fun <- cmpfun(inverse_fun, options = list(suppressUndefined = TRUE))
# Bind the function in the global environment explicitly, so the binding
# exists there even when this code is evaluated from another scope.
assign("inverse_fun", inverse_fun, envir = .GlobalEnv)
tuning_criterion <- "rmse"
# Define a new measure that applies inverse_fun to the prediction and then
# evaluates the measure named by `tuning_criterion`.
# The measure object is looked up by name with get() rather than
# eval(parse(text = ...)), which would execute any code held in the string.
original_measure <- get(tuning_criterion)

# Measure function: back-transform truth and response, then delegate to the
# original measure's own function.
# NOTE: `inverse_fun` and `original_measure` are free variables here, so both
# must be resolvable in whichever R process evaluates this function.
transf_measure_fun <- function(task, model, pred, feats, extra.args) {
  # transform back to original value
  pred$data$truth <- inverse_fun(pred$data$truth)
  pred$data$response <- inverse_fun(pred$data$response)
  original_measure$fun(task, model, pred, feats, extra.args)
}

# Reuse the original measure's metadata so the new measure is optimised and
# aggregated the same way.
transf_measure <- makeMeasure(
  id = "ii", name = "ccc",
  properties = original_measure$properties,
  minimize = original_measure$minimize,
  best = original_measure$best,
  worst = original_measure$worst,
  fun = transf_measure_fun
)
transf_measure <- setAggregation(transf_measure, original_measure$aggr)

# The same measure under four aggregation schemes.
aggregated_measure <- list(
  transf_measure,
  setAggregation(transf_measure, test.sd),
  setAggregation(transf_measure, train.mean),
  setAggregation(transf_measure, train.sd)
)
# Fit the regr.ksvm learner on bh.task and predict on the same task.
lrn.lm <- makeLearner("regr.ksvm")
mod.lm <- train(lrn.lm, bh.task)
task.pred.lm <- predict(mod.lm, task = bh.task)

# Build a copy of the prediction whose truth and response columns have been
# back-transformed by hand.
inv_pred <- task.pred.lm
inv_pred$data[c("truth", "response")] <- lapply(
  inv_pred$data[c("truth", "response")],
  inverse_fun
)

# Sanity check: the custom measure on the raw prediction should match plain
# rmse on the manually transformed prediction.
performance(task.pred.lm, measures = transf_measure)
performance(inv_pred, measures = rmse)
# Grid search over C and sigma, scored by the custom measure with 3-fold CV.
discrete_ps <- makeParamSet(
  makeDiscreteParam(id = "C", values = c(0.5, 1, 1.5, 2)),
  makeDiscreteParam(id = "sigma", values = c(0.5, 1, 1.5, 2))
)
ctrl <- makeTuneControlGrid()
rdesc <- makeResampleDesc(method = "CV", iters = 3L)

# Sequential run: this works.
res <- tuneParams(
  learner = lrn.lm, task = bh.task, resampling = rdesc,
  par.set = discrete_ps, control = ctrl, measures = transf_measure
)
# try with parallelization - doesn't work
current_os <- Sys.info()[["sysname"]] # detect OS
if (current_os == "Windows") {
  # Socket-mode workers; the cluster's RNG streams are seeded as well.
  set.seed(1, "L'Ecuyer-CMRG")
  parallelStart(mode = "socket", cpus = detectCores(), show.info = FALSE)
  parallel::clusterSetRNGStream(iseed = 1)
} else if (current_os == "Linux") {
  set.seed(1, "L'Ecuyer-CMRG")
  parallelStart(mode = "multicore", cpus = detectCores(), show.info = FALSE)
} else {
  # No backend is started for any other system; only a notice is printed.
  cat("\n\n#### OS not recognized, check parallelization init\n\n")
}
# Same tuning call as the sequential run, now dispatched to the workers.
res <- tuneParams(lrn.lm,
  task = bh.task, resampling = rdesc,
  par.set = discrete_ps, control = ctrl, measures = transf_measure
)
parallelStop()
收到以下错误:
Error in stopWithJobErrorMessages(inds, vcapply(result.list[inds], as.character)) :
Errors occurred in 16 slave jobs, displaying at most 10 of them:
00001: Error in inverse_fun(pred$data$truth) :
cannot find "inverse_fun"
我试图通过 extra.args 传递该函数,但出现了错误:
# Second attempt: read the transformation from extra.args instead of looking
# it up as a global variable.
original_measure = eval(parse(text=tuning_criterion))
transf_measure_fun = function(task, model, pred, feats, extra.args){
# transform back to original value
pred$data$truth = extra.args$inv_fun(pred$data$truth)
pred$data$response = extra.args$inv_fun(pred$data$response)
return(original_measure$fun(task, model, pred, feats, extra.args))
}
transf_measure = makeMeasure(
id = 'ii', name = 'ccc',
properties = original_measure$properties,
minimize = original_measure$minimize, best = original_measure$best, worst = original_measure$worst,
# NOTE(review): the next line CALLS transf_measure_fun with only extra.args
# supplied, rather than passing the function object itself, so `pred` is
# missing when its body runs. makeMeasure() appears to take a separate
# `extra.args` argument for this purpose -- confirm against the mlr docs.
fun = transf_measure_fun(extra.args = list(inv_fun = inverse_fun))
)
我得到Error in FUN(X[[i]], ...) : argument "pred" is missing, with no default
提前致谢
您需要使用 parallelMap::parallelExport() 把自定义对象导出到各个工作进程。socket 模式下的工作进程是独立的 R 会话,看不到主会话全局环境中的对象,因此度量函数中引用的对象必须显式导出。
library(mlr)
#> Loading required package: ParamHelpers
library(parallelMap)
library(compiler)
# Define the back-transformation applied to truth/response values before a
# measure is computed. Here it simply squares its input.
inverse_fun <- function(x) {
  x^2
}
# Wrap the function with Vectorize() so it is applied element by element.
inverse_fun <- Vectorize(inverse_fun)
# Byte-compile the wrapper. TRUE is spelled out because `T` is an ordinary
# variable that can be reassigned.
inverse_fun <- cmpfun(inverse_fun, options = list(suppressUndefined = TRUE))
# Bind the function in the global environment explicitly, so the binding
# exists there even when this code is evaluated from another scope.
assign("inverse_fun", inverse_fun, envir = .GlobalEnv)
tuning_criterion <- "rmse"
# Define a new measure that applies inverse_fun to the prediction and then
# evaluates the measure named by `tuning_criterion`.
# The measure object is looked up by name with get() rather than
# eval(parse(text = ...)), which would execute any code held in the string.
original_measure <- get(tuning_criterion)

# Measure function: back-transform truth and response, then delegate to the
# original measure's own function.
# NOTE: `inverse_fun` and `original_measure` are free variables here, so both
# must be resolvable in whichever R process evaluates this function.
transf_measure_fun <- function(task, model, pred, feats, extra.args) {
  # transform back to original value
  pred$data$truth <- inverse_fun(pred$data$truth)
  pred$data$response <- inverse_fun(pred$data$response)
  original_measure$fun(task, model, pred, feats, extra.args)
}

# Reuse the original measure's metadata so the new measure is optimised and
# aggregated the same way.
transf_measure <- makeMeasure(
  id = "ii", name = "ccc",
  properties = original_measure$properties,
  minimize = original_measure$minimize,
  best = original_measure$best,
  worst = original_measure$worst,
  fun = transf_measure_fun
)
transf_measure <- setAggregation(transf_measure, original_measure$aggr)
# Tuning setup: a 4 x 4 grid over C and sigma, 3-fold cross-validation, and
# the regr.ksvm learner.
discrete_ps <- makeParamSet(
  makeDiscreteParam(id = "C", values = c(0.5, 1, 1.5, 2)),
  makeDiscreteParam(id = "sigma", values = c(0.5, 1, 1.5, 2))
)
ctrl <- makeTuneControlGrid()
rdesc <- makeResampleDesc(method = "CV", iters = 3L)
lrn.lm <- makeLearner(cl = "regr.ksvm")
# Seed with the L'Ecuyer-CMRG generator, then start two socket workers.
set.seed(1, "L'Ecuyer-CMRG")
parallelStart(mode = "socket", cpus = 2, show.info = FALSE)
# Export the objects that the measure function looks up by name, so they are
# available on the workers.
parallelExport("inverse_fun", "original_measure")
res <- tuneParams(lrn.lm,
  task = bh.task, resampling = rdesc,
  par.set = discrete_ps, control = ctrl, measures = transf_measure
)
#> [Tune] Started tuning learner regr.ksvm for parameter set:
#> Type len Def Constr Req Tunable Trafo
#> C discrete - - 0.5,1,1.5,2 - TRUE -
#> sigma discrete - - 0.5,1,1.5,2 - TRUE -
#> With control class: TuneControlGrid
#> Imputation value: Inf
#> [Tune] Result: C=2; sigma=0.5 : ii.test.rmse=270.8008465
# Shut the workers down again.
parallelStop()
创建于 2019-10-08,由 reprex 软件包(v0.3.0)生成
会话信息
# Print session details (R version, platform, and package versions).
devtools::session_info()
#> ─ Session info ──────────────────────────────────────────────────────────
#> setting value
#> version R version 3.6.1 (2019-07-05)
#> os Arch Linux
#> system x86_64, linux-gnu
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Europe/Berlin
#> date 2019-10-08
#>
#> ─ Packages ──────────────────────────────────────────────────────────────
#> ! package * version date lib
#> assertthat 0.2.1 2019-03-21 [1]
#> backports 1.1.5 2019-10-02 [1]
#> BBmisc 1.11 2017-03-10 [1]
#> callr 3.3.2 2019-09-22 [1]
#> checkmate 1.9.4 2019-07-04 [1]
#> cli 1.1.0 2019-03-19 [1]
#> colorspace 1.4-1 2019-03-18 [1]
#> crayon 1.3.4 2017-09-16 [1]
#> data.table 1.12.4 2019-10-03 [1]
#> desc 1.2.0 2018-05-01 [1]
#> devtools 2.2.1 2019-09-24 [1]
#> digest 0.6.21 2019-09-20 [1]
#> dplyr 0.8.3 2019-07-04 [1]
#> ellipsis 0.3.0 2019-09-20 [1]
#> evaluate 0.14 2019-05-28 [1]
#> fastmatch 1.1-0 2017-01-28 [1]
#> fs 1.3.1 2019-05-06 [1]
#> ggplot2 3.2.1 2019-08-10 [1]
#> glue 1.3.1 2019-03-12 [1]
#> gtable 0.3.0 2019-03-25 [1]
#> highr 0.8 2019-03-20 [1]
#> htmltools 0.4.0 2019-10-04 [1]
#> kernlab 0.9-27 2018-08-10 [1]
#> knitr 1.25 2019-09-18 [1]
#> lattice 0.20-38 2018-11-04 [1]
#> lazyeval 0.2.2 2019-03-15 [1]
#> magrittr 1.5 2014-11-22 [1]
#> Matrix 1.2-17 2019-03-22 [1]
#> memoise 1.1.0 2017-04-21 [1]
#> mlr * 2.15.0.9000 2019-10-08 [1]
#> munsell 0.5.0 2018-06-12 [1]
#> parallelMap * 1.4 2019-05-17 [1]
#> ParamHelpers * 1.12 2019-01-18 [1]
#> pillar 1.4.2 2019-06-29 [1]
#> pkgbuild 1.0.5 2019-08-26 [1]
#> pkgconfig 2.0.3 2019-09-22 [1]
#> pkgload 1.0.2 2018-10-29 [1]
#> prettyunits 1.0.2 2015-07-13 [1]
#> processx 3.4.1 2019-07-18 [1]
#> ps 1.3.0 2018-12-21 [1]
#> purrr 0.3.2 2019-03-15 [1]
#> R6 2.4.0 2019-02-14 [1]
#> Rcpp 1.0.2 2019-07-25 [1]
#> remotes 2.1.0 2019-06-24 [1]
#> rlang 0.4.0 2019-06-25 [1]
#> rmarkdown 1.16 2019-10-01 [1]
#> rprojroot 1.3-2 2018-01-03 [1]
#> scales 1.0.0 2018-08-09 [1]
#> sessioninfo 1.1.1 2018-11-05 [1]
#> stringi 1.4.3 2019-03-12 [1]
#> stringr 1.4.0 2019-02-10 [1]
#> R survival 2.44-1.1 <NA> [2]
#> testthat 2.2.1 2019-07-25 [1]
#> tibble 2.1.3 2019-06-06 [1]
#> tidyselect 0.2.5 2018-10-11 [1]
#> usethis 1.5.1.9000 2019-10-07 [1]
#> withr 2.1.2 2018-03-15 [1]
#> xfun 0.10 2019-10-01 [1]
#> XML 3.98-1.20 2019-06-06 [1]
#> yaml 2.2.0 2018-07-25 [1]
#> source
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> local
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> <NA>
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#> Github (r-lib/usethis@3015465)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.0)
#>
#> [1] /home/pjs/R/x86_64-pc-linux-gnu-library/3.6
#> [2] /usr/lib/R/library
#>
#> R ── Package was removed from disk.