我正在编写的一些R代码中有一行速度很慢。它使用apply命令在4维数组中应用logSumExp。我想知道有没有办法加快速度!
Reprex:(这可能需要10秒或更长时间才能运行(
library(microbenchmark)
library(matrixStats)
array4d <- array( runif(5*500*50*5 ,-1,0),
dim = c(5, 500, 50, 5) )
microbenchmark(
result <- apply(array4d, c(1,2,3), logSumExp)
)
非常感谢您的建议!
rowSums
是应用程序的一个不太通用的版本,它在相加时针对速度进行了优化,因此可以用于加快计算速度。如果在NA
和NaN
之间的计算中保持差异很重要,请注意帮助文件?rowSums
中的警告。
library(microbenchmark)
library(matrixStats)
array4d <- array( runif(5*500*50*5 ,-1,0),
dim = c(5, 500, 50, 5) )
microbenchmark(
result <- apply(array4d, c(1,2,3), logSumExp),
result2 <- log(rowSums(exp(array4d), dims=3))
)
# Unit: milliseconds
# expr min lq mean median uq max neval
# result <- apply(array4d, c(1, 2, 3), logSumExp) 249.4757 274.8227 305.24680 297.30245 328.90610 405.5038 100
# result2 <- log(rowSums(exp(array4d), dims = 3)) 31.8783 32.7493 35.20605 33.01965 33.45205 133.3257 100
all.equal(result, result2)
#TRUE
这导致我的电脑速度提高了9倍
@Miff的另一个很好的解决方案是,由于产生了无穷大的数据集,导致我的代码与某些数据集崩溃,我最终发现这是由于下溢问题,可以通过使用"logSumExp技巧"来避免:https://www.xarg.org/2016/06/the-log-sum-exp-trick-in-machine-learning/
从@Miff的代码和Rapply()
函数中获得灵感,我制作了一个新函数,在避免下溢问题的同时提供更快的计算。然而,并没有@Miff的解决方案那么快。张贴以防它帮助其他
apply_logSumExp <- function (X) {
MARGIN <- c(1, 2, 3) # fixing the margins as have not tested other dims
dl <- length(dim(X)) # get length of dim
d <- dim(X) # get dim
dn <- dimnames(X) # get dimnames
ds <- seq_len(dl) # makes sequences of length of dims
d.call <- d[-MARGIN] # gets index of dim not included in MARGIN
d.ans <- d[MARGIN] # define dim for answer array
s.call <- ds[-MARGIN] # used to define permute
s.ans <- ds[MARGIN] # used to define permute
d2 <- prod(d.ans) # length of results object
newX <- aperm(X, c(s.call, s.ans)) # permute X such that dims omitted from calc are first dim
dim(newX) <- c(prod(d.call), d2) # voodoo. Preserves ommitted dim dimension but collapses the rest into 1
maxes <- colMaxs(newX)
ans <- maxes + log(colSums(exp( sweep(newX, 2, maxes, "-"))) )
ans <- array(ans, d.ans)
return(ans)
}
> microbenchmark(
+ res1 <- apply(array4d, c(1,2,3), logSumExp),
+ res2 <- log(rowSums(exp(array4d), dims=3)),
+ res3 <- apply_logSumExp(array4d)
+ )
Unit: milliseconds
expr min lq mean median uq max
res1 <- apply(array4d, c(1, 2, 3), logSumExp) 176.286670 213.882443 247.420334 236.44593 267.81127 486.41072
res2 <- log(rowSums(exp(array4d), dims = 3)) 4.664907 5.821601 7.588448 5.97765 7.47814 30.58002
res3 <- apply_logSumExp(array4d) 12.119875 14.673011 19.635265 15.20385 18.30471 90.59859
neval cld
100 c
100 a
100 b