R 语言 - 在 mlr 的 makeClassifTask 中将 ID 列排除出任务后,如何把预测结果与 ID 对应起来



我的数据中有一个 ID 列。由于它不是特征(feature),我在创建 trainTask 时删除了此列。但是,我想把预测概率与数据中实际的 ID 编号对应起来。

我要匹配的列是 Init_Acct,它是 data.frame 中的 ID 号。

我的代码如下:

# Make classif tasks ----
# Init_Acct is an account identifier, not a predictive feature, so it is
# dropped from the task data. Row i of each task still corresponds to row i of
# train.df / test.df, which is what allows predictions to be mapped back to IDs.
trainTask <- makeClassifTask(
  data = train.df %>% dplyr::select(-Init_Acct)
  , target = "READMIT_FLAG"
  , positive = "Y"
)
testTask <- makeClassifTask(
  data = test.df %>% dplyr::select(-Init_Acct)
  , target = "READMIT_FLAG"
  , positive = "Y"
)
# Oversample the positive ("Y") class in the TRAINING data only.
# SMOTE is deliberately not applied to testTask: it would add synthetic rows
# that have no Init_Acct (breaking the row <-> ID mapping) and the test
# metrics would then describe synthetic rather than real patients.
# NOTE(review): applying SMOTE before cross-validation lets synthetic
# neighbours of a validation row leak into the training folds; consider
# makeSMOTEWrapper(learner, sw.rate = 6) so SMOTE runs inside each fold.
trainTask <- smote(trainTask, rate = 6)
# GBM ####
# List the hyper-parameters exposed by classif.gbm (interactive inspection)
getParamSet('classif.gbm')
# gbm learner returning class probabilities (prob.N / prob.Y in predictions)
gbm.learner <- makeLearner(
  'classif.gbm'
  , predict.type = 'prob'
)
plotLearnerPrediction(gbm.learner, trainTask)
# Tune model: random search evaluating 50 sampled configurations
gbm.tune.ctl <- makeTuneControlRandom(maxit = 50L)
# Cross validation: each configuration is scored by 3-fold CV
gbm.cv <- makeResampleDesc("CV", iters = 3L)
# Hyper-parameter space sampled by the random search above (not a grid)
gbm.par <- makeParamSet(
  makeDiscreteParam('distribution', values = 'bernoulli')
  , makeIntegerParam('n.trees', lower = 10, upper = 1000)
  , makeIntegerParam('interaction.depth', lower = 2, upper = 10)
  , makeIntegerParam('n.minobsinnode', lower = 10, upper = 80)
  , makeNumericParam('shrinkage', lower = 0.01, upper = 1)
)
# Tune hyper-parameters on 4 socket workers, parallelised at the
# mlr.tuneParams level. parallelStop() sits in `finally` so the workers are
# shut down even if tuneParams() fails part-way through.
# NOTE(review): `acc` is the optimisation target; with an imbalanced target
# (hence the SMOTE) a threshold-free measure such as `auc` may be preferable.
parallelMap::parallelStartSocket(
  4
  , level = "mlr.tuneParams"
)
gbm.tune <- tryCatch(
  tuneParams(
    learner = gbm.learner
    , task = trainTask
    , resampling = gbm.cv
    , measures = acc
    , par.set = gbm.par
    , control = gbm.tune.ctl
  )
  , finally = parallelMap::parallelStop()
)
# Check CV acc: best score found and the configuration that produced it
gbm.tune$y
gbm.tune$x
# Apply the tuned hyper-parameters to the learner
gbm.ps <- setHyperPars(
  learner = gbm.learner
  , par.vals = gbm.tune$x
)
# Train gbm ----
# Fit the tuned learner on the training task. Training on testTask would fit
# the model to the very rows it is later evaluated on.
gbm.train <- train(gbm.ps, trainTask)
# Learning curve of the tuned learner on the training data
plotLearningCurve(
  generateLearningCurveData(
    gbm.ps
    , trainTask
  )
)
# Predict ----
# Predict directly on the real test rows (Init_Acct removed, target column
# kept so `truth` is filled in). This data frame never goes through smote(),
# so prediction row i is test.df row i.
gbm.pred <- predict(
  gbm.train
  , newdata = test.df %>% dplyr::select(-Init_Acct)
)
plotResiduals(gbm.pred)
# Create submission file ----
# Attach the account ID by position, guarding the 1:1 row correspondence.
stopifnot(nrow(gbm.pred$data) == nrow(test.df))
gbm.submit <- data.frame(
  Init_Acct = test.df$Init_Acct
  , gbm.pred$data
)
head(gbm.submit, 5)
table(gbm.submit$truth, gbm.submit$response)
# Confusion Matrix and performance summaries
calculateConfusionMatrix(gbm.pred)
calculateROCMeasures(gbm.pred)
# conf_mat_f1_func() / perf_plots_func() are user helpers not shown here
conf_mat_f1_func(gbm.pred)
perf_plots_func(Model = gbm.pred)

数据看起来像这样:

glimpse(train.df)
Observations: 33,031
Variables: 17
$ Init_Acct         <chr> "12345678", "87654321", "81734650", "11223344", "1422...
$ Init_LOS          <dbl> 2, 2, 5, 1, 12, 3, 16, 9, 3, 14, 1, 1, 4, 7, 4, 1, 3,...
$ Init_LACE         <dbl> 2, 7, 7, 9, 8, 8, 11, 10, 8, 10, 5, 4, 8, 8, 4, 5, 3,...
$ READMIT_FLAG      <fct> N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, Y,...
$ Init_Hosp_Pvt     <fct> PRIVATE, HOSPITALIST, HOSPITALIST, HOSPITALIST, PRIVA...
$ Age_at_Init_Admit <dbl> 37, 26, 56, 67, 51, 53, 48, 57, 92, 67, 72, 22, 60, 6...
$ Age_Bucket        <fct> 3, 2, 5, 6, 5, 5, 4, 5, 9, 6, 7, 2, 6, 6, 7, 6, 9, 5,...
$ Gender            <fct> F, M, F, M, M, F, M, F, M, M, M, F, M, F, F, F, F, M,...
$ Init_ROM          <dbl> 1, 1, 3, 4, 2, 3, 1, 3, 4, 1, 1, 1, 1, 1, 2, 1, 2, 4,...
$ Init_SOI          <dbl> 1, 1, 3, 4, 3, 3, 3, 3, 3, 2, 1, 1, 2, 2, 2, 2, 2, 4,...
$ Has_Diabetes      <fct> N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N,...
$ reduced_dispo     <fct> AHR, AHR, AHR, ATH, ATW, ATW, ATW, AHR, AHR, ATW, AHR...
$ reduced_hsvc      <fct> SUR, MED, MED, Other, MED, MED, MED, MED, MED, MED, M...
$ reduced_abucket   <fct> 3, 2, 5, 6, 5, 5, 4, 5, Other, 6, 7, 2, 6, 6, 7, 6, O...
$ reduced_spclty    <fct> Other, HOSIM, HOSIM, HOSIM, Other, HOSIM, HOSIM, HOSI...
$ reduced_lihn      <fct> Other, Medical, Pneumonia, Medical, Medical, Medical,...
$ discharge_month   <fct> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...

输出:

glimpse(gbm.submit)
Observations: 23,896
Variables: 5
$ id       <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,...
$ truth    <fct> Y, N, N, N, N, N, N, Y, N, N, N, N, N, N, Y, N, Y, N, N, N, N,...
$ prob.N   <dbl> 0.9150623, 0.7914781, 0.9661108, 0.9198683, 0.8502536, 0.94376...
$ prob.Y   <dbl> 0.08493774, 0.20852192, 0.03388919, 0.08013167, 0.14974644, 0....
$ response <fct> N, N, N, N, N, N, N, Y, N, N, N, N, N, N, N, N, Y, N, N, N, N,...

mlr 的 predict() 会保留行名,并在输出中额外生成一个 id 列,用于索引 task 数据中的行。您可以使用其中任意一种方式,将预测结果与原始样本 ID 关联起来。注意:如果对测试 task 执行了 smote(),其中新增的合成行没有对应的原始 ID,因此应只对训练数据做过采样。

设置

library(tidyverse)
library(mlr)
## Example data: iris plus a per-row sample ID ("Acct1", "Acct2", ...)
iris2 <- mutate(iris, Init_Acct = paste0("Acct", row_number()))
## gbm classifier that reports class probabilities
lrn <- makeLearner("classif.gbm", predict.type = "prob")

选项1:使用ID列索引原始数据

## Drop the custom column as in your original post
task <- makeClassifTask(data = select(iris2, -Init_Acct), target = "Species")
mdl <- train(lrn, task)
pred <- predict(mdl, task)
## pred$data$id is the row index into the task data, which has the same row
## order as iris2, so a row-number key on iris2 joins the IDs back on.
## The key is named explicitly (by = "id") rather than relying on an implicit
## natural join over whichever columns happen to share a name.
iris2 %>%
  mutate(id = row_number()) %>%
  select(Init_Acct, id) %>%
  inner_join(pred$data, by = "id") %>%
  select(-id)
#   Init_Acct  truth prob.setosa prob.versicolor prob.virginica response
# 1     Acct1 setosa   0.9998775    1.225043e-04   2.836942e-08   setosa
# 2     Acct2 setosa   0.9999652    3.468690e-05   1.118015e-07   setosa
# 3     Acct3 setosa   0.9999538    4.611200e-05   8.389636e-08   setosa

选项2:使用行名(rownames)

## Move the sample IDs into the row names; mlr carries the row names through
## to pred$data.
task <- makeClassifTask(
  data = column_to_rownames(iris2, var = "Init_Acct"),
  target = "Species"
)
mdl <- train(lrn, task)
pred <- predict(mdl, task)
## Turn the row names back into an explicit ID column and drop mlr's index
pred$data %>%
  rownames_to_column(var = "Init_Acct") %>%
  select(-id)
#     Init_Acct      truth  prob.setosa prob.versicolor prob.virginica   response
# 1       Acct1     setosa 9.999266e-01    7.331226e-05   6.889259e-08     setosa
# 2       Acct2     setosa 9.999751e-01    2.462816e-05   3.154618e-07     setosa
# 3       Acct3     setosa 9.999656e-01    3.421543e-05   1.449155e-07     setosa

相关内容

  • 没有找到相关文章

最新更新