r语言 - 使用多个列值作为函数的输入来计算data.table中的新列



我使用数据。表处理包含多个列的数据集。我需要使用其中的几个列值来为每行计算一个特定的新列。我知道我可以在一些简单的函数中使用SDcols功能。但是,当我想使用我自己的函数时,它会更棘手,因为它对列值的处理方式不同。下面是我的例子:

这就是数据的方式。表看起来像:

Training Age  Client1 Stim1.0 Stim1.1  Client2 Stim2.0 Stim2.1 Choice Val.00 Val.01
1:        0   1  absence       0       0  absence       0       0      2      0      0
2:        0   2  absence       0       0  absence       0       0      2      0      0
3:        0   3 Object 1       1       1 Object 2       2       2      1      0      0
4:        0   4 Object 2       2       1 Object 2       2       1      1      0      0
5:        0   5  absence       0       0  absence       0       0      2      0      0
6:        0   6  absence       0       0  absence       0       0      2      0      0
Val.02 alpha.0 Val.10 Val.11 Val.12 alpha.1 V25
1: 0.0000   0.005      0 0.0000 0.0000   0.005  NA
2: 0.0000   0.005      0 0.0000 0.0000   0.005  NA
3: 0.0025   0.005      0 0.0000 0.0025   0.005  NA
4: 0.0050   0.005      0 0.0025 0.0025   0.005  NA
5: 0.0050   0.005      0 0.0025 0.0025   0.005  NA
6: 0.0050   0.005      0 0.0025 0.0025   0.005  NA

函数使用以Stim开头的列的值来选择在计算新值时必须包含哪些以Val开头的列。

StimVal列的数量较低时,例如分别为2和3,我可以使用fcase

来解决它
rawData[,`:=`(Val.Client1=fcase(Stim1.0==0,Val.00,
Stim1.0==1,Val.01,Stim1.0==2,Val.02)+
fcase(Stim1.1==0,Val.10,
Stim1.1==1,Val.11,Stim1.1==2,Val.12),
Val.Clien2=fcase(Stim2.0==0,Val.00,
Stim2.0==1,Val.01,Stim2.0==2,Val.02)+
fcase(Stim2.1==0,Val.10,
Stim2.1==1,Val.11,Stim2.1==2,Val.12))]

但是,我使用的不同数据集的列数不同。因此,我想独立于列数对其进行编码。

我已经设法使用的组合使其工作。SDcols以这种方式应用

numSti<-2,numFeat<-2 # parameters to know the number of columns to expect
rawData[,Val.Client1:=apply(.SD,MARGIN = 1,FUN = function(x){
# I use apply to get a vector with alll the relevant values
x<-as.numeric(x) # for some reason I must force it to be numeric 
Stim1.tmp<-x[1:numSti]+1 # Choose the relevant values for the Stim columns
vals<-x[(numSti*2+1): (numSti*2+numSti*(1+numFeat))] # choose the relevant values for the Val columns
locVal<-Stim1.tmp+(numFeat+1)*(0:(numSti-1)) # map the Stim to the Val columns
return(sum(vals[locVal])) # sum over the chosen values. 
}),.SDcols=patterns("Stim.|Val.")]

这段代码给出了正确的计算。但是它太慢了!你能帮我找到一个更快的解决办法吗?

根据@jblood94的请求:dput(rawData)

输出
as.data.table(structure(list(Age = 1:6, Client1 = c(2L, 2L, 0L, 1L, 2L, 2L), 
Stim1.0 = c(0L, 0L, 1L, 2L, 0L, 0L), Stim1.1 = c(0L, 0L, 
1L, 1L, 0L, 0L), Client2 = c(2L, 2L, 1L, 1L, 2L, 2L), Stim2.0 = c(0L, 
0L, 2L, 2L, 0L, 0L), Stim2.1 = c(0L, 0L, 2L, 1L, 0L, 0L), 
Choice = c(2L, 2L, 1L, 1L, 2L, 2L), Val.00 = c(0, 0, 0, 0, 
0, 0), Val.01 = c(0, 0, 0, 0, 0, 0), Val.02 = c(0, 0, 0.0025, 
0.005, 0.005, 0.005), alpha.0 = c(0.005, 0.005, 0.005, 0.005, 
0.005, 0.005), Val.10 = c(0, 0, 0, 0, 0, 0), Val.11 = c(0, 
0, 0, 0.0025, 0.0025, 0.0025), Val.12 = c(0, 0, 0.0025, 0.0025, 
0.0025, 0.0025), alpha.1 = c(0.005, 0.005, 0.005, 0.005, 
0.005, 0.005), V25 = c(NA, NA, NA, NA, NA, NA)), row.names = c(NA, 
-6L), class = c("data.table", "data.frame")))

也许这个用户函数会有所帮助:

fun <- function(data, vals) {
stimvals <- Map(function(V, levels) {
match(paste0(sub("Stim[0-9]+\.([0-9]+)", "Val.\1", V), levels),
names(data))
}, setNames(nm = vals), lapply(vals, function(z) data[[z]]))
Reduce(`+`, lapply(stimvals, function(z) as.data.frame(data)[cbind(seq_along(z), z)]))
}
stims <- grep("Stim.*", names(rawData), value = TRUE)
stims <- split(stims, sub("\..*", "", stims))
names(stims) <- sub(".*([0-9]+)$", "Val.Client\1", names(stims))
stims
# $Val.Client1
# [1] "Stim1.0" "Stim1.1"
# $Val.Client2
# [1] "Stim2.0" "Stim2.1"
rawData[, names(stims) := lapply(stims, fun, data = .SD)]
rawData
#      Age Client1 Stim1.0 Stim1.1 Client2 Stim2.0 Stim2.1 Choice Val.00 Val.01 Val.02 alpha.0 Val.10 Val.11 Val.12 alpha.1    V25 Val.Client1 Val.Client2
#    <int>   <int>   <int>   <int>   <int>   <int>   <int>  <int>  <num>  <num>  <num>   <num>  <num>  <num>  <num>   <num> <lgcl>       <num>       <num>
# 1:     1       2       0       0       2       0       0      2      0      0 0.0000   0.005      0 0.0000 0.0000   0.005     NA      0.0000      0.0000
# 2:     2       2       0       0       2       0       0      2      0      0 0.0000   0.005      0 0.0000 0.0000   0.005     NA      0.0000      0.0000
# 3:     3       0       1       1       1       2       2      1      0      0 0.0025   0.005      0 0.0000 0.0025   0.005     NA      0.0000      0.0050
# 4:     4       1       2       1       1       2       1      1      0      0 0.0050   0.005      0 0.0025 0.0025   0.005     NA      0.0075      0.0075
# 5:     5       2       0       0       2       0       0      2      0      0 0.0050   0.005      0 0.0025 0.0025   0.005     NA      0.0000      0.0000
# 6:     6       2       0       0       2       0       0      2      0      0 0.0050   0.005      0 0.0025 0.0025   0.005     NA      0.0000      0.0000

这是基于Stim*子级别(例如,Stim2.1.1)和相应Stim2.1中的动态查找Val*名称。也就是说,如果Stim2.1的值为0,那么它应该从Val.10中提取。根据相应的Val*列对数据进行索引,然后将数据分配回Stim*名称的第一个数字(Stim2.12)。

因此,上述stims变量的生成是关键:将相应的Stim#.#*变量组合在一起(它们将被子集/求和)并适当命名。

这似乎是工作。这个想法是基于Stim的值对Val列进行索引,将结果放入维度为c(nrow(rawData), numSti, 2)的数组中,然后对数组进行排序,以便与colSums沿着第二维求和。

numSti <- 2
numFeat <- 2
rawData[
,paste0("Val.Client", 1:2) := as.data.table(
colSums(
aperm(
array(
unlist(rawData[,.SD , .SDcols = grep("Val.", names(rawData), value = TRUE)])[
1:.N + (unlist(rawData[,.SD , .SDcols = grep("Stim", names(rawData), value = TRUE)]) + rep(c(0, numFeat + 1), each = .N))*.N
],
c(.N, numSti, 2)
),
c(2, 1, 3)
)
)
)
]
rawData[, Val.Client1:Val.Client2]
#>    Val.Client1 Val.Client2
#> 1:      0.0000      0.0000
#> 2:      0.0000      0.0000
#> 3:      0.0000      0.0050
#> 4:      0.0075      0.0075
#> 5:      0.0000      0.0000
#> 6:      0.0000      0.0000

对于客户端数量也可以推广:

numCli <- 2
rawData[
,paste0("Val.Client", 1:numCli) := as.data.table(
colSums(
aperm(
array(
unlist(rawData[,.SD , .SDcols = grep("Val.", names(rawData), value = TRUE)])[
1:.N + (
unlist(
rawData[,.SD , .SDcols = grep("Stim", names(rawData), value = TRUE)]
) + rep(
seq(0, by = numFeat + 1, length.out = numCli),
each = .N
)
)*.N
],
c(.N, numSti, numCli)
),
c(2, 1, 3)
)
)
)
]

最新更新