r - mlr3: survival distribution estimates at specific time points



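I am looking for a faster way to extract predicted survival distributions from mlr3/mlr3proba.
The prediction step is very time consuming, especially for data sets with hundreds of observations and no ties in the time variable.
Is there an option to evaluate the individual distributions only at predefined time points, instead of evaluating each whole distribution every time?
If that is not possible, is there an option similar to the ntime argument of randomForestSRC::rfsrc?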

Here is an example using survivalmodels::akritas, where the estimate at a single time point takes about 10 minutes:

pacman::p_load("survival", "mltools", "paradox", "mlr3misc", "mlr3tuning",
               "devtools", "mlr3extralearners", "mlr3proba", "mlr3learners",
               "survivalmodels", "mlr3pipelines", "tictoc", "casebase", "distr6")
dat <- survival::rotterdam[,-c(1,2,12,13)]
length(unique(dat$dtime)) # 2215 unique times
set.seed(220311) 
sample.train <- sample(nrow(dat), nrow(dat)*.2)
dat_train <- dat[sample.train, ]
length(unique(dat_train$dtime)) # 558 unique times
sample.test <- c(1:nrow(dat))[which(!c(1:nrow(dat)) %in% sample.train)]
dat_test <- dat[sample.test, ]
length(unique(dat_test$dtime)) # 1875 unique times

task = mlr3proba::TaskSurv$new(id = "dat_train", backend = dat_train,
                               time = "dtime", event = "death")
search_space <- ps(
  lambda = p_dbl(lower = 0, upper = 0.25))
learner.dh <- lrn("surv.akritas", reverse = FALSE)
learner.dh$encapsulate = c(train = "evaluate")
at <- AutoTuner$new(
  learner = learner.dh,
  search_space = search_space,
  resampling = rsmp("cv", folds = 5),
  measure = msr("surv.cindex"),
  terminator = trm("evals", n_evals = 10), # n_evals kept very low, just for the example
  tuner = tnr("random_search")
)
tic()
at$train(task)
toc() #807.46  sec elapsed
tic()
pred.S_t2638 <- 1 - as.numeric(at$predict_newdata(dat_test)$distr$cdf(2638))
toc() #559.5 sec elapsed
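To make "evaluating only at predefined time points" concrete outside mlr3proba, here is a minimal sketch that uses only the survival package. It swaps in a Cox model for Akritas purely for illustration, the covariates age and nodes and the three evaluation times are arbitrary choices, and nothing here reflects how mlr3proba works internally.

```r
library(survival)

# A split like the one in the example above
dat <- survival::rotterdam[, -c(1, 2, 12, 13)]
set.seed(220311)
sample.train <- sample(nrow(dat), floor(nrow(dat) * .2))
dat_train <- dat[sample.train, ]
dat_test <- dat[-sample.train, ]

# Cox model as a stand-in for Akritas (assumption, chosen only for illustration)
fit <- coxph(Surv(dtime, death) ~ age + nodes, data = dat_train)

# Predicted survival curves for every test row ...
curves <- survfit(fit, newdata = dat_test)

# ... evaluated only at the requested times; extend = TRUE keeps times
# that lie beyond the last observed time in the training data
eval_times <- c(365, 1825, 2638)
S <- summary(curves, times = eval_times, extend = TRUE)$surv

dim(S)  # one row per requested time, one column per test observation
```

With an mlr3proba prediction, the analogous step is the distr$survival(2638) call shown in the answer below.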

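Hi, sorry this took so long, but it is now fixed on CRAN: install.packages(c("distr6", "mlr3proba")). Unfortunately Akritas is still slow: my C++ implementation contains four loops, which is not great, and I will try to come up with a better solution in the future. The actual prediction part, however, now takes <1s. See the rfsrc test below, which also includes notes to myself on where the remaining bottlenecks in my packages are.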

library(paradox)
library(mlr3extralearners)
library(mlr3tuning)
library(tictoc)
dat <- survival::rotterdam[, -c(1, 2, 12, 13)]
set.seed(220311)
sample.train <- sample(nrow(dat), nrow(dat) * .2)
dat_train <- dat[sample.train, ]
sample.test <- c(1:nrow(dat))[which(!c(1:nrow(dat)) %in% sample.train)]
dat_test <- dat[sample.test, ]

task = mlr3proba::TaskSurv$new(
id = "dat_train", backend = dat_train,
time = "dtime", event = "death"
)
learner = lrn("surv.rfsrc", ntree = to_tune(50, 200))
at <- AutoTuner$new(
learner = learner,
resampling = rsmp("cv", folds = 5),
measure = msr("surv.cindex"),
terminator = trm("evals", n_evals = 10),
tuner = tnr("random_search")
)
tic()
at$train(task)
toc()
#> 15.531 sec elapsed
tic()
pred <- at$predict_newdata(dat_test) # bottleneck is predict.rfsrc
toc()
#> 0.309 sec elapsed
tic()
distr <- pred$distr # bottleneck is support checks
toc()
#> 0.721 sec elapsed
tic()
pred.S_t2638 <- distr$survival(2638) # bottleneck is param6 transform
toc()
#> 0.274 sec elapsed

Created on 2022-03-25 by the reprex package (v2.0.1)
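As a rough illustration of why evaluating at a single time point can be cheap once the predicted curves are available, here is a base R sketch. It assumes the curves are stored as a matrix of step functions over a shared time grid; surv_at is a hypothetical helper written for this example, not a function from distr6 or mlr3proba, and the numbers are made up.

```r
# Hypothetical helper: evaluate right-continuous step survival curves at given times.
# `surv` has one row per observation and one column per time in the sorted `grid`.
surv_at <- function(surv, grid, times) {
  stopifnot(ncol(surv) == length(grid), !is.unsorted(grid))
  # Index of the last grid time <= each requested time (0 if before the grid starts)
  idx <- findInterval(times, grid)
  out <- matrix(1, nrow = nrow(surv), ncol = length(times),
                dimnames = list(rownames(surv), as.character(times)))
  hit <- idx > 0  # before the first grid time, S(t) stays at 1
  if (any(hit)) {
    out[, hit] <- surv[, idx[hit], drop = FALSE]
  }
  out
}

# Toy curves for two observations on a grid of four times (made-up values)
grid <- c(100, 500, 1000, 3000)
surv <- rbind(a = c(0.95, 0.80, 0.60, 0.30),
              b = c(0.99, 0.90, 0.85, 0.70))

res <- surv_at(surv, grid, c(50, 2638))
res
```

The lookup is a single findInterval call plus a column subset, so its cost depends on the number of requested times rather than on the number of unique event times in the data.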
