r语言 - 对多列进行分层以进行交叉验证



我见过很多方法可以通过单个变量对样本进行分层以用于交叉验证。caret包通过createFolds()功能很好地做到这一点。默认情况下,似乎caret将进行分区,以便每个折叠具有大致相同的目标事件率。

我想做的是按目标率和时间分层。我找到了一个可以部分做到这一点的函数,它是splitstackshape包并使用stratified()函数。该函数的问题在于它返回单个样本,在给定条件下它不会将数据拆分为 k 组。

这里有一些要重现的虚拟数据。

set.seed(123)
time = rep(seq(1:10),100)
target = rbinom(n=100, size=1, prob=0.3)
data = as.data.frame(cbind(time,target))
table(data$time,data$target)
0  1
1  60 40
2  80 20
3  80 20
4  60 40
5  80 20
6  80 20
7  60 40
8  60 40
9  70 30
10 80 20

如您所见,目标事件率在一段时间内并不相同。时间 1 为 40%,时间 2 为 20%,依此类推。我想在创建用于交叉验证的折叠时保留这一点。如果我理解正确,插入符号将按整体事件率进行分区。

table(data$target)
0   1 
710 290 

总体上将保留 ~30% 的比率,但随着时间的推移,目标事件率不会保留。

我们可以得到一个这样的样本:

library(splitstackshape)
train.index <- stratified(data,c("target","time"),size=.2)

我需要再重复 4 次以进行 5 次交叉验证,并且需要这样做,以便一旦分配了一行就无法再次分配。我觉得应该已经为此设计了一个功能。有什么想法吗?

我知道这篇文章很旧,但我只是遇到了同样的问题,我找不到其他解决方案。如果其他人需要答案,这是我正在实施的解决方案。

library(data.table)
mystratified <- function(indt, group, NUM_FOLDS) {
indt <- setDT(copy(indt))
if (is.numeric(group)) 
group <- names(indt)[group]
temp_grp <- temp_ind <- NULL
indt[, `:=`(temp_ind, .I)]
indt[, `:=`(temp_grp, do.call(paste, .SD)), .SDcols = group]
samp_sizes <- indt[, .N, by = group]
samp_sizes[, `:=`(temp_grp, do.call(paste, .SD)), .SDcols = group]
inds <- split(indt$temp_ind, indt$temp_grp)[samp_sizes$temp_grp]
z = unlist(inds,use.names=F)
model_folds <- suppressWarnings(split(z, 1:NUM_FOLDS))
}

这基本上是对splitstackshape::stratified的重写.它的工作原理如下,为每个折叠提供验证 indeces 列表作为输出。

myfolds = mystratified(indt = data, group = colnames(data), NUM_FOLDS = 5)
str(myfolds)
List of 5
$ 1: int [1:200] 1 91 181 261 351 441 501 591 681 761 ...
$ 2: int [1:200] 41 101 191 281 361 451 541 601 691 781 ...
$ 3: int [1:200] 51 141 201 291 381 461 551 641 701 791 ...
$ 4: int [1:200] 61 151 241 301 391 481 561 651 741 801 ...
$ 5: int [1:200] 81 161 251 341 401 491 581 661 751 841 ...

因此,例如,每个折叠的训练和验证数据是:

# first fold
train = data[-myfolds[[1]],]
valid = data[myfolds[[1]],]
# second fold
train = data[-myfolds[[2]],]
valid = data[myfolds[[2]],]
# etc...

最新更新