矢量化 R for 循环的有效方法



如何对以下 R 代码进行矢量化以减少计算时间?

q = matrix(0,n,p)
for(u in 1 : n){
q1 <- matrix(0,p,1)
for(iprime in 1 : n){
for(i in 1 : n){
if(cause[iprime]==1 & cause[i]>1 & (time[i]<time[u]) & (time[u] <= time[iprime])){
q1 = q1 + (covs[i,] - S1byS0hat[iprime,])*G[iprime]/G[i]*expz[i]/S0hat[iprime]
}
}
}
q[u,] = q1/(m*m)
}

以下值可以用作示例:

n = 2000
m = 500
p=3
G = runif(n)
time = runif(n,0.01,5)
cause = c(rep(0,600),rep(1,1000),rep(2,400))
covs = matrix(rnorm(n*p),n,p)
S1byS0hat = matrix(rnorm(n*p),n,p)
S0hat = rnorm(n)
expz = rnorm(n)

对解决方案进行基准测试:

coeff <- 10
n = 20 * coeff
m = 500
p = 3
G = runif(n)
time = runif(n, 0.01, 5)
cause = c(rep(0, 6 * coeff), rep(1, 10 * coeff), rep(2, 4 * coeff))
covs = matrix(rnorm(n * p), n, p)
S1byS0hat = matrix(rnorm(n * p), n, p)
S0hat = rnorm(n)
expz = rnorm(n)
system.time({
q = matrix(0,n,p)
for(u in 1 : n){
q1 <- matrix(0,p,1)
for(iprime in 1 : n){
for(i in 1 : n){
if(cause[iprime]==1 & cause[i]>1 & (time[i]<time[u]) & (time[u] <= time[iprime])){
q1 = q1 + (covs[i,] - S1byS0hat[iprime,])*G[iprime]/G[i]*expz[i]/S0hat[iprime]
}
}
}
q[u,] = q1/(m*m)
}
})

在我的计算机上需要 9 秒(使用coeff = 10而不是 100,我们可以稍后为其他解决方案增加它(。


第一个解决方案是预先计算一些东西:

q2 = matrix(0, n, p)
c1 <- G / S0hat
c2 <- expz / G
for (u in 1:n) {
q1 <- rep(0, p)
ind_iprime <- which(cause == 1 & time[u] <= time)
ind_i <- which(cause > 1 & time < time[u])
for (iprime in ind_iprime) {
for (i in ind_i) {
q1 = q1 + (covs[i, ] - S1byS0hat[iprime, ]) * c1[iprime] * c2[i]
}
}
q2[u, ] = q1
}
q2 <- q2 / (m * m)

对于系数 =10,这需要 0.3 秒,对于系数 = 100,这需要 6 分钟。


然后,您可以矢量化至少一个循环:

q3 <- matrix(0, n, p)
c1 <- G / S0hat
c2 <- expz / G
covs_c2 <- sweep(covs, 1, c2, '*')
S1byS0hat_c1 <- sweep(S1byS0hat, 1, c1, '*')
for (u in 1:n) {
q1 <- rep(0, p)
ind_iprime <- which(cause == 1 & time[u] <= time)
ind_i <- which(cause > 1 & time < time[u])
for (iprime in ind_iprime) {
q1 <- q1 + colSums(covs_c2[ind_i, , drop = FALSE]) * c1[iprime] - 
S1byS0hat_c1[iprime, ] * sum(c2[ind_i])
}
q3[u, ] <- q1
}
q3 <- q3 / (m * m)

这只需要 15 秒。


如果您关心进一步的性能,一个好的策略可能是在 Rcpp 中重新编码,尤其是为了避免大量内存分配。

最新更新