使用 mlrMBO / rBayesianOptimization 通过 caret 对 R keras 模型进行贝叶斯优化时出错



我正在尝试通过 caret 运行一个由 Keras 包（及 TensorFlow）实现的多层感知器（MLP），并想使用贝叶斯优化来调优该算法的超参数。不过，我收到一条错误消息："ValueError: rate 既不是标量也不是标量张量"，随后 keras 会打印出 dropout 参数的随机取值。接着，caret 也报错："重新采样的性能度量中缺少值"。对于非 caret/keras 的算法，我可以让这个流程正常运行。

以下是我在 iris（鸢尾花）数据集上运行的代码，应该能重现该错误：

library(caret)
library(rBayesianOptimization) # to create cv folds and for bayesian optimisation
library(mlrMBO)  # for bayesian optimisation
library(tensorflow)
library(keras)

# Binary target: setosa vs. everything else, with the FALSE/TRUE levels
# renamed to nonset/set.
iris$Speciesset <- as.factor(iris$Species == "setosa")
levels(iris$Speciesset) <- c("nonset", "set")

rounds <- 5

# Tuning via Bayesian search: mlrMBO proposes one hyperparameter setting per
# call, and caret cross-validates that single grid row with `rounds` folds.
ctrl <- trainControl(
  method = "cv", number = rounds,
  summaryFunction = twoClassSummary,
  classProbs = TRUE, search = "grid",
  verboseIter = FALSE, savePredictions = "all"
)

# Objective function: maximise the cross-validated ROC (the `metric` used by
# caret and the value returned below) by tuning the MLP hyperparameters.
# NOTE(review): this is the version that reproduces the reported error. `fn`
# relies on smoof's default simple signature and reads values with x["..."];
# the working version in the answer below sets has.simple.signature = FALSE
# and reads values with x$name.
obj.fun <- smoof::makeSingleObjectiveFunction(
  name = "mlp_cv_bayes",
  fn = function(x) {
    train_model <- caret::train(
      Speciesset ~ .,
      data = iris,
      trControl = ctrl,
      metric = "ROC",
      method = "mlpKerasDropout",
      tuneGrid = expand.grid(
        size = x["size"],
        dropout = x["dropout"],
        batch_size = x["batch_size"],
        lr = x["lr"],
        rho = x["rho"],
        activation = x["activation"],
        decay = x["decay"]
      )
    )
    train_model$results$ROC
  },
  par.set = makeParamSet(
    makeIntegerParam("size", lower = 10, upper = 500),
    makeNumericParam("dropout", lower = 0.1, upper = .9),
    # NOTE(review): iris has only 150 rows, so batch sizes of 2000-15000 exceed
    # the data size; the answer below uses 32-64.
    makeIntegerParam("batch_size", lower = 2000, upper = 15000),
    makeNumericParam("lr", lower = 0.01, upper = .9),
    makeNumericParam("rho", lower = 0.01, upper = .9),
    makeNumericParam("decay", lower = 0.00001, upper = .9),
    makeDiscreteParam("activation", values = c("relu", "tanh", "sigmoid"))
  ),
  minimize = FALSE
)

control <- makeMBOControl()
control <- setMBOControlTermination(control, iters = 20)

# Fixed typo: the objective is `obj.fun` (was `obj.funnb`, an undefined object).
des <- generateDesign(
  par.set = getParamSet(obj.fun),
  fun = lhs::randomLHS
)

run <- mbo(
  fun = obj.fun,
  control = control,
  show.info = TRUE,
  design = des
)

这只是一个语法问题：目标函数需要设置 has.simple.signature = FALSE，这样参数 x 会以命名列表的形式传入，然后用 x$size 这样的写法取值（而不是 x["size"]）。

以下内容对我有用:

library(caret)
library(rBayesianOptimization) # I don't think this is needed for the example
library(mlrMBO)
library(tensorflow)
library(keras)

# Binary target: setosa vs. everything else, with the FALSE/TRUE levels
# renamed to nonset/set.
iris$Speciesset <- as.factor(iris$Species == "setosa")
levels(iris$Speciesset) <- c("nonset", "set")

# 5-fold CV; twoClassSummary with class probabilities produces the "ROC"
# column that the objective below reads.
ctrl <- trainControl(
  method = "cv",
  number = 5,
  summaryFunction = twoClassSummary,
  classProbs = TRUE,
  search = "grid",
  verboseIter = FALSE,
  savePredictions = "all"
)

# Objective: cross-validated ROC of a caret/keras MLP for one hyperparameter
# setting proposed by mlrMBO (maximised, see `minimize = FALSE`).
obj.fun <- smoof::makeSingleObjectiveFunction(
  name = "mlp_cv_bayes",
  # With a non-simple signature `fn` receives `x` as a named list, so each
  # hyperparameter is read with x$name and keeps its own type (numeric,
  # integer or the discrete activation value).
  has.simple.signature = FALSE,
  fn = function(x) {
    train_model <- caret::train(
      Speciesset ~ .,
      data = iris,
      trControl = ctrl,
      # Fixed: was " ROC" (leading space), which does not match the "ROC"
      # column produced by twoClassSummary.
      metric = "ROC",
      method = "mlpKerasDropout",
      # A single-row grid: caret only evaluates the setting chosen by mlrMBO.
      tuneGrid = expand.grid(
        size = x$size,
        batch_size = x$batch_size,
        dropout = x$dropout,
        lr = x$lr,
        rho = x$rho,
        activation = x$activation,
        decay = x$decay
      ),
      preProcess = c("center", "scale"),
      epochs = 10
    )
    train_model$results$ROC
  },
  par.set = makeParamSet(
    makeIntegerParam("size", lower = 8, upper = 32),
    makeNumericParam("dropout", lower = 0.5, upper = .9),
    makeIntegerParam("batch_size", lower = 32, upper = 64),
    makeNumericParam("lr", lower = 0.01, upper = .9),
    makeNumericParam("rho", lower = 0.01, upper = .9),
    makeNumericParam("decay", lower = 0.00001, upper = .9),
    makeDiscreteParam("activation", values = c("relu", "tanh", "sigmoid"))
  ),
  minimize = FALSE
)

# 20 sequential MBO iterations after the initial Latin hypercube design.
control <- makeMBOControl()
control <- setMBOControlTermination(control, iters = 20)
des <- generateDesign(
  par.set = getParamSet(obj.fun),
  fun = lhs::randomLHS
)
run <- mbo(
  fun = obj.fun,
  control = control,
  show.info = TRUE,
  design = des
)

即使在 GPU 上，运行也需要一些时间。

# Print the mlrMBO result: recommended parameters, objective value and the
# tail of the optimization path.
run
# Output:
Recommended parameters:
size=14; dropout=0.675; batch_size=58; lr=0.295; rho=0.776; decay=0.613; activation=tanh
Objective: y = 1.000
Optimization path
10 + 20 entries in total, displaying last 10 (or less):
size   dropout batch_size        lr       rho      decay activation y dob eol error.message exec.time cb error.model train.time prop.type propose.time se mean
21   17 0.8112161         55 0.2455501 0.1980725 0.84595374       relu 1  11  NA          <NA>     22.84 -1        <NA>       0.02 infill_cb         0.89  0    1
22   29 0.8613471         55 0.8087169 0.1619325 0.58929373    sigmoid 1  12  NA          <NA>     22.60 -1        <NA>       0.00 infill_cb         0.72  0    1
23   10 0.6228074         44 0.1214947 0.7515075 0.34196674       relu 1  13  NA          <NA>     22.73 -1        <NA>       0.00 infill_cb         0.92  0    1
24   23 0.5021470         51 0.8890780 0.3033280 0.75097924    sigmoid 1  14  NA          <NA>     22.63 -1        <NA>       0.00 infill_cb         0.73  0    1
25   26 0.5572763         52 0.2083211 0.6842752 0.12736857    sigmoid 1  15  NA          <NA>     22.54 -1        <NA>       0.01 infill_cb         0.92  0    1
26   20 0.6904176         46 0.4408440 0.8439430 0.53462843       tanh 1  16  NA          <NA>     22.70 -1        <NA>       0.01 infill_cb         0.75  0    1
27   32 0.8357865         62 0.8108571 0.4191330 0.02935206       tanh 1  17  NA          <NA>     22.45 -1        <NA>       0.02 infill_cb         0.95  0    1
28   23 0.8332311         45 0.3894060 0.7166899 0.24697168       relu 1  18  NA          <NA>     22.77 -1        <NA>       0.00 infill_cb         0.81  0    1
29   30 0.6880777         58 0.3077176 0.8634141 0.41809902    sigmoid 1  19  NA          <NA>     22.78 -1        <NA>       0.01 infill_cb         1.00  0    1
30   27 0.6603150         46 0.3338476 0.1976979 0.17276289       tanh 1  20  NA          <NA>     22.88 -1        <NA>       0.00 infill_cb         0.86  0    1
lambda
21      2
22      2
23      2
24      2
25      2
26      2
27      2
28      2
29      2
30      2

这是我的会话信息（sessionInfo）：

# R version, platform and package versions used for the run above.
sessionInfo()
R version 4.0.0 (2020-04-24)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 18362)
Matrix products: default
Random number generation:
RNG:     Mersenne-Twister 
Normal:  Inversion 
Sample:  Rounding 
locale:
[1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252    LC_MONETARY=English_United States.1252 LC_NUMERIC=C                          
[5] LC_TIME=English_United States.1252    
attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     
other attached packages:
[1] keras_2.3.0.0.9000          tensorflow_2.2.0            mlrMBO_1.1.4                smoof_1.6.0.2               checkmate_2.0.0            
[6] mlr_2.17.1                  ParamHelpers_1.14           rBayesianOptimization_1.1.0 caret_6.0-86                ggplot2_3.3.1              
[11] lattice_0.20-41            
loaded via a namespace (and not attached):
[1] nlme_3.1-148         biomartr_0.9.2       lubridate_1.7.8      bit64_0.9-7          RColorBrewer_1.1-2   progress_1.2.2       httr_1.4.1          
[8] tools_4.0.0          backports_1.1.7      R6_2.4.1             rpart_4.1-15         lazyeval_0.2.2       DBI_1.1.0            BiocGenerics_0.34.0 
[15] colorspace_1.4-1     nnet_7.3-13          withr_2.2.0          tidyselect_1.1.0     prettyunits_1.1.1    mco_1.0-15.1         bit_1.1-15.2        
[22] curl_4.3             compiler_4.0.0       parallelMap_1.5.0    Biobase_2.48.0       plotly_4.9.2.1       scales_1.1.1         plot3D_1.3          
[29] askpass_1.1          tfruns_1.4           rappdirs_0.3.1       stringr_1.4.0        digest_0.6.25        XVector_0.28.0       base64enc_0.1-3     
[36] htmltools_0.4.0      pkgconfig_2.0.3      lhs_1.0.2            dbplyr_1.4.4         htmlwidgets_1.5.1    rlang_0.4.6          readxl_1.3.1        
[43] rstudioapi_0.11      RSQLite_2.2.0        BBmisc_1.11          generics_0.0.2       jsonlite_1.6.1       dplyr_1.0.0.9000     ModelMetrics_1.2.2.2
[50] zip_2.0.4            magrittr_1.5         Matrix_1.2-18        Rcpp_1.0.4.6         munsell_0.5.0        S4Vectors_0.26.1     reticulate_1.16     
[57] lifecycle_0.2.0      whisker_0.4          stringi_1.4.6        pROC_1.16.2          yaml_2.2.1           RJSONIO_1.3-1.4      MASS_7.3-51.6       
[64] zlibbioc_1.34.0      plyr_1.8.6           recipes_0.1.12       BiocFileCache_1.12.0 misc3d_0.8-4         grid_4.0.0           blob_1.2.1          
[71] parallel_4.0.0       crayon_1.3.4         Biostrings_2.56.0    splines_4.0.0        hms_0.5.3            zeallot_0.1.0        pillar_1.4.4        
[78] reshape2_1.4.4       codetools_0.2-16     biomaRt_2.44.0       stats4_4.0.0         fastmatch_1.1-0      XML_3.99-0.3         glue_1.4.1          
[85] data.table_1.12.8    vctrs_0.3.0          foreach_1.5.0        cellranger_1.1.0     tidyr_1.1.0          gtable_0.3.0         openssl_1.4.1       
[92] purrr_0.3.4          assertthat_0.2.1     xfun_0.14            gower_0.2.1          openxlsx_4.1.5       prodlim_2019.11.13   viridisLite_0.3.0   
[99] class_7.3-17         survival_3.1-12      timeDate_3043.102    tibble_3.0.1         iterators_1.0.12     tinytex_0.23         AnnotationDbi_1.50.0
[106] memoise_1.1.0        IRanges_2.22.2       lava_1.6.7           ellipsis_0.3.1       ipred_0.9-9 

相关内容

  • 没有找到相关文章

最新更新