我在使用 mlr 框架时,某个环节出了问题,导致 FeatureImp
对每个特征的重要性都返回 1,但我找不出具体原因。下面是一个可复现的示例:
library(caret)
#> Carregando pacotes exigidos: lattice
#> Carregando pacotes exigidos: ggplot2
library(mlr)
#> Carregando pacotes exigidos: ParamHelpers
#>
#> Attaching package: 'mlr'
#> The following object is masked from 'package:caret':
#>
#> train
library(iml)
# Load iris and turn it into a two-class problem: drop "setosa", then recode
# Species as a 0/1 factor where 1 marks "virginica".
data("iris")
iris <- iris[iris$Species != "setosa", ]
iris$Species <- as.factor(ifelse(iris$Species == "virginica", 1, 0))

# Split the rows with caret (p = 0.8 goes to the training set); the index
# object is only needed for the split and is removed afterwards.
ind <- createDataPartition(iris$Species, times = 1, p = 0.8, list = FALSE)
train <- iris[ind, ]
test <- iris[-ind, ]
rm(ind)

# Wrap both partitions as mlr classification tasks targeting Species.
train.task <- makeClassifTask(data = train, target = "Species", positive = 1)
test.task <- makeClassifTask(data = test, target = "Species", positive = 1)
# Base learners, keyed by a short name; each one is created with
# predict.type = "prob".
learner <- lapply(
  c(
    xgboost = "classif.xgboost",
    ksvm = "classif.ksvm",
    nnet = "classif.nnet",
    randomForest = "classif.randomForest"
  ),
  makeLearner,
  predict.type = "prob"
)

# Fit every base learner on the training task (list names are preserved).
model <- lapply(learner, function(lrn) train(lrn, train.task))
#> # weights: 19
#> initial value 57.506055
#> iter 10 value 52.109027
#> iter 20 value 7.798098
#> iter 30 value 5.401193
#> iter 40 value 4.707935
#> iter 50 value 4.702049
#> final value 4.701710
#> converged
# Predict with each fitted base model on the held-out task.
prediction <- lapply(model, function(fit) predict(fit, test.task))

# Stacked learner built from the base learners, with a random forest as the
# super learner (method "stack.cv", use.feat = FALSE).
ensemble <- makeStackedLearner(
  learner,
  super.learner = "classif.randomForest",
  predict.type = "prob",
  method = "stack.cv",
  use.feat = FALSE
)

# Train the stack and store it next to the base models.
model$ensemble <- train(ensemble, train.task)
#> # weights: 19
#> initial value 43.712841
#> iter 10 value 5.444287
#> iter 20 value 4.536990
#> iter 30 value 4.527489
#> iter 40 value 4.481401
#> iter 50 value 4.481221
#> iter 50 value 4.481221
#> iter 50 value 4.481221
#> final value 4.481221
#> converged
#> # weights: 19
#> initial value 52.864011
#> iter 10 value 33.347827
#> iter 20 value 2.926847
#> iter 30 value 0.011104
#> final value 0.000055
#> converged
#> # weights: 19
#> initial value 44.627604
#> iter 10 value 31.360597
#> iter 20 value 5.798769
#> iter 30 value 4.290623
#> iter 40 value 3.751202
#> iter 50 value 3.547856
#> iter 60 value 3.469366
#> iter 70 value 3.373487
#> iter 80 value 3.317680
#> iter 90 value 3.310354
#> iter 100 value 3.301115
#> final value 3.301115
#> stopped after 100 iterations
#> # weights: 19
#> initial value 46.410266
#> iter 10 value 29.975896
#> iter 20 value 1.266423
#> iter 30 value 0.004667
#> final value 0.000052
#> converged
#> # weights: 19
#> initial value 52.665930
#> final value 44.361399
#> converged
#> # weights: 19
#> initial value 60.471973
#> iter 10 value 50.475349
#> iter 20 value 7.580138
#> iter 30 value 4.828646
#> iter 40 value 4.543112
#> iter 50 value 2.995374
#> iter 60 value 2.636710
#> iter 70 value 2.539857
#> iter 80 value 2.497281
#> iter 90 value 2.427158
#> iter 100 value 2.370383
#> final value 2.370383
#> stopped after 100 iterations
# Predict with the stacked model on the held-out task.
prediction$ensemble <- predict(model$ensemble, test.task)

# iml Predictor: feature columns only (everything except Species) as data,
# and the target converted from the 0/1 factor to numeric 0/1 as y.
predictor <- Predictor$new(
  model$ensemble,
  data = train.task$env$data[setdiff(names(train.task$env$data), "Species")],
  y = as.numeric(train.task$env$data$Species) - 1
)

# Feature importance with loss = "ce"; print the results table.
imp <- FeatureImp$new(predictor, loss = "ce")
imp$results
#> feature importance.05 importance importance.95 permutation.error
#> 1 Sepal.Length 1 1 1 1
#> 2 Sepal.Width 1 1 1 1
#> 3 Petal.Length 1 1 1 1
#> 4 Petal.Width 1 1 1 1
由 reprex 软件包 (v0.3.0) 创建于 2020-01-23
这个问题似乎已在 {iml} 的开发版本中修复。
使用当前的 CRAN 版本时,我可以重现您的问题;下面是使用开发版本运行的结果:
library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
library(mlr)
#> Loading required package: ParamHelpers
#> 'mlr' is in maintenance mode since July 2019. Future development
#> efforts will go into its successor 'mlr3' (<https://mlr3.mlr-org.com>).
#>
#> Attaching package: 'mlr'
#> The following object is masked from 'package:caret':
#>
#> train
library(iml)
# Binary version of iris: remove "setosa", then encode "virginica" as 1 and
# the remaining class as 0, stored as a factor.
data("iris")
iris <- iris[iris$Species != "setosa", ]
iris$Species <- factor(as.numeric(iris$Species == "virginica"))

# Partition the rows with caret (p = 0.8 for training) and drop the index
# once both subsets exist.
ind <- createDataPartition(iris$Species, times = 1, p = 0.8, list = FALSE)
train <- iris[ind, ]
test <- iris[-ind, ]
rm(ind)

# mlr classification tasks for the training and the held-out rows.
train.task <- makeClassifTask(data = train, target = "Species", positive = 1)
test.task <- makeClassifTask(data = test, target = "Species", positive = 1)
# Named list of base learners, all created with predict.type = "prob".
learner <- list(
  xgboost      = makeLearner("classif.xgboost", predict.type = "prob"),
  ksvm         = makeLearner("classif.ksvm", predict.type = "prob"),
  nnet         = makeLearner("classif.nnet", predict.type = "prob"),
  randomForest = makeLearner("classif.randomForest", predict.type = "prob")
)

# Train each base learner on the training task, keeping the list names.
model <- lapply(learner, function(base) train(base, train.task))
#> # weights: 19
#> initial value 59.040647
#> iter 10 value 54.908003
#> iter 20 value 8.784817
#> iter 30 value 2.906017
#> iter 40 value 0.187334
#> iter 50 value 0.000610
#> final value 0.000059
#> converged
# Held-out predictions for every base model.
prediction <- lapply(model, function(m) predict(m, test.task))

# Define the stacked learner over the base learners; the super learner is a
# random forest (method "stack.cv", use.feat = FALSE).
ensemble <- makeStackedLearner(
  learner,
  super.learner = "classif.randomForest",
  predict.type = "prob",
  method = "stack.cv",
  use.feat = FALSE
)

# Fit the stack on the training task and add it to the model list.
model$ensemble <- train(ensemble, train.task)
#> # weights: 19
#> initial value 44.537254
#> iter 10 value 6.716784
#> iter 20 value 4.750452
#> iter 30 value 4.487501
#> iter 40 value 4.481250
#> final value 4.481222
#> converged
#> # weights: 19
#> initial value 54.135701
#> iter 10 value 13.081961
#> iter 20 value 1.676063
#> iter 30 value 0.002261
#> final value 0.000044
#> converged
#> # weights: 19
#> initial value 42.621635
#> iter 10 value 5.201573
#> iter 20 value 2.878946
#> iter 30 value 1.133911
#> iter 40 value 0.002784
#> iter 50 value 0.000726
#> final value 0.000037
#> converged
#> # weights: 19
#> initial value 43.795663
#> iter 10 value 4.478310
#> iter 20 value 1.811306
#> iter 30 value 0.027775
#> iter 40 value 0.004873
#> iter 50 value 0.001480
#> iter 60 value 0.000230
#> iter 70 value 0.000221
#> final value 0.000089
#> converged
#> # weights: 19
#> initial value 44.433321
#> iter 10 value 7.252874
#> iter 20 value 1.200457
#> iter 30 value 0.001668
#> final value 0.000063
#> converged
#> # weights: 19
#> initial value 67.012204
#> final value 55.451774
#> converged
# Held-out predictions from the stacked model.
prediction$ensemble <- predict(model$ensemble, test.task)

# Build the iml Predictor from the task data: all columns except Species are
# the features, and Species (0/1 factor) becomes a numeric 0/1 target.
predictor <- Predictor$new(
  model$ensemble,
  data = train.task$env$data[names(train.task$env$data) != "Species"],
  y = as.numeric(train.task$env$data$Species) - 1
)

# Compute feature importance with loss = "ce" and show the results.
imp <- FeatureImp$new(predictor, loss = "ce")
imp$results
#> feature importance.05 importance importance.95 permutation.error
#> 1 Petal.Width 11.1 12.0 14.2 0.3000
#> 2 Petal.Length 10.3 11.5 13.1 0.2875
#> 3 Sepal.Length 3.3 4.5 6.3 0.1125
#> 4 Sepal.Width 2.1 3.5 4.0 0.0875
由 reprex 软件包 (v0.3.0) 创建于 2020-01-23
会话信息
# Print session details (R version, platform, and package versions/sources)
# so the environment that produced the output above can be identified.
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 3.6.2 Patched (2019-12-12 r77564)
#> os macOS Mojave 10.14.6
#> system x86_64, darwin15.6.0
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Europe/Berlin
#> date 2020-01-23
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date lib
#> assertthat 0.2.1 2019-03-21 [1]
#> backports 1.1.5 2019-10-02 [1]
#> BBmisc 1.11 2017-03-10 [1]
#> callr 3.4.0 2019-12-09 [1]
#> caret * 6.0-85 2020-01-07 [1]
#> checkmate 1.9.4 2019-07-04 [1]
#> class 7.3-15 2019-01-01 [2]
#> cli 2.0.1.9000 2020-01-12 [1]
#> codetools 0.2-16 2018-12-24 [2]
#> colorspace 1.4-1 2019-03-18 [1]
#> crayon 1.3.4 2017-09-16 [1]
#> data.table 1.12.8 2019-12-09 [1]
#> desc 1.2.0 2018-05-01 [1]
#> devtools 2.2.1 2019-09-24 [1]
#> digest 0.6.23 2019-11-23 [1]
#> dplyr 0.8.3 2019-07-04 [1]
#> ellipsis 0.3.0 2019-09-20 [1]
#> evaluate 0.14 2019-05-28 [1]
#> fansi 0.4.1 2020-01-08 [1]
#> fastmatch 1.1-0 2017-01-28 [1]
#> foreach 1.4.7 2019-07-27 [1]
#> fs 1.3.1 2019-05-06 [1]
#> generics 0.0.2 2018-11-29 [1]
#> ggplot2 * 3.2.1 2019-08-10 [1]
#> glue 1.3.1 2019-03-12 [1]
#> gower 0.2.1 2019-05-14 [1]
#> gridExtra 2.3 2017-09-09 [1]
#> gtable 0.3.0 2019-03-25 [1]
#> highr 0.8 2019-03-20 [1]
#> htmltools 0.4.0 2019-10-04 [1]
#> iml * 0.9.0 2020-01-23 [1]
#> ipred 0.9-9 2019-04-28 [1]
#> iterators 1.0.12 2019-07-26 [1]
#> kernlab 0.9-29 2019-11-12 [1]
#> knitr 1.27 2020-01-16 [1]
#> lattice * 0.20-38 2018-11-04 [2]
#> lava 1.6.6 2019-08-01 [1]
#> lazyeval 0.2.2 2019-03-15 [1]
#> lifecycle 0.1.0 2019-08-01 [1]
#> lubridate 1.7.4 2018-04-11 [1]
#> magrittr 1.5 2014-11-22 [1]
#> MASS 7.3-51.4 2019-03-31 [1]
#> Matrix 1.2-18 2019-11-27 [2]
#> memoise 1.1.0 2017-04-21 [1]
#> Metrics 0.1.4 2018-07-09 [1]
#> mlr * 2.17.0.9000 2020-01-13 [1]
#> ModelMetrics 1.2.2.1 2020-01-13 [1]
#> munsell 0.5.0 2018-06-12 [1]
#> nlme 3.1-143 2019-12-10 [2]
#> nnet 7.3-12 2016-02-02 [2]
#> parallelMap 1.4 2019-05-17 [1]
#> ParamHelpers * 1.13.0.9000 2019-12-11 [1]
#> pillar 1.4.3 2019-12-20 [1]
#> pkgbuild 1.0.6 2019-10-09 [1]
#> pkgconfig 2.0.3 2019-09-22 [1]
#> pkgload 1.0.2 2018-10-29 [1]
#> plyr 1.8.5 2019-12-10 [1]
#> prediction 0.3.14 2019-06-17 [1]
#> prettyunits 1.1.0 2020-01-09 [1]
#> pROC 1.16.1 2020-01-14 [1]
#> processx 3.4.1 2019-07-18 [1]
#> prodlim 2019.11.13 2019-11-17 [1]
#> ps 1.3.0 2018-12-21 [1]
#> purrr 0.3.3 2019-10-18 [1]
#> R6 2.4.1 2019-11-12 [1]
#> randomForest 4.6-14 2018-03-25 [1]
#> Rcpp 1.0.3 2019-11-08 [1]
#> recipes 0.1.9 2020-01-07 [1]
#> remotes 2.1.0 2019-06-24 [1]
#> reshape2 1.4.3 2017-12-11 [1]
#> rlang 0.4.3 2020-01-22 [1]
#> rmarkdown 2.1 2020-01-20 [1]
#> rpart 4.1-15 2019-04-12 [1]
#> rprojroot 1.3-2 2018-01-03 [1]
#> scales 1.1.0 2019-11-18 [1]
#> sessioninfo 1.1.1 2018-11-05 [1]
#> stringi 1.4.5 2020-01-11 [1]
#> stringr 1.4.0 2019-02-10 [1]
#> survival 3.1-8 2019-12-03 [2]
#> testthat 2.3.1 2019-12-01 [1]
#> tibble 2.1.3 2019-06-06 [1]
#> tidyselect 0.2.5 2018-10-11 [1]
#> timeDate 3043.102 2018-02-21 [1]
#> usethis 1.5.1.9000 2020-01-17 [1]
#> withr 2.1.2 2018-03-15 [1]
#> xfun 0.12 2020-01-13 [1]
#> xgboost 0.90.0.2 2019-08-01 [1]
#> XML 3.99-0.3 2020-01-20 [1]
#> yaml 2.2.0 2018-07-25 [1]
#> source
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> Github (r-lib/cli@f786d87)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> Github (christophM/iml@54b2ce2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> local
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.2)
#> Github (berndbischl/ParamHelpers@c2d989c)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> Github (r-lib/rlang@624c5c3)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> Github (pat-s/usethis@0251102)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.0)
#>
#> [1] /Users/pjs/Library/R/3.6/library
#> [2] /Library/Frameworks/R.framework/Versions/3.6/Resources/library