r-正在寻找跨多维数组实现logSumExp的更快方法



我正在编写的一些R代码中有一行速度很慢。它使用apply命令在4维数组中应用logSumExp。我想知道有没有办法加快速度!

Reprex:(这可能需要10秒或更长时间才能运行(

library(microbenchmark)
library(matrixStats)
array4d <- array( runif(5*500*50*5 ,-1,0),
dim = c(5, 500, 50, 5) )
microbenchmark(
result <- apply(array4d, c(1,2,3), logSumExp)
)

非常感谢您的建议!

rowSums是应用程序的一个不太通用的版本,它在相加时针对速度进行了优化,因此可以用于加快计算速度。如果在NANaN之间的计算中保持差异很重要,请注意帮助文件?rowSums中的警告。

library(microbenchmark)
library(matrixStats)
array4d <- array( runif(5*500*50*5 ,-1,0),
dim = c(5, 500, 50, 5) )
microbenchmark(
result <- apply(array4d, c(1,2,3), logSumExp),
result2 <- log(rowSums(exp(array4d), dims=3))
)

# Unit: milliseconds
#                                            expr      min       lq      mean    median        uq      max neval
# result <- apply(array4d, c(1, 2, 3), logSumExp) 249.4757 274.8227 305.24680 297.30245 328.90610 405.5038   100
# result2 <- log(rowSums(exp(array4d), dims = 3))  31.8783  32.7493  35.20605  33.01965  33.45205 133.3257   100
all.equal(result, result2)
#TRUE

这导致我的电脑速度提高了9倍

@Miff的另一个很好的解决方案是,由于产生了无穷大的数据集,导致我的代码与某些数据集崩溃,我最终发现这是由于下溢问题,可以通过使用"logSumExp技巧"来避免:https://www.xarg.org/2016/06/the-log-sum-exp-trick-in-machine-learning/

从@Miff的代码和Rapply()函数中获得灵感,我制作了一个新函数,在避免下溢问题的同时提供更快的计算。然而,并没有@Miff的解决方案那么快。张贴以防它帮助其他

apply_logSumExp <- function (X) {
MARGIN <- c(1, 2, 3) # fixing the margins as have not tested other dims
dl <- length(dim(X)) # get length of dim
d <- dim(X) # get dim
dn <- dimnames(X) # get dimnames
ds <- seq_len(dl) # makes sequences of length of dims
d.call <- d[-MARGIN]    # gets index of dim not included in MARGIN
d.ans <- d[MARGIN]  # define dim for answer array
s.call <- ds[-MARGIN] # used to define permute
s.ans <- ds[MARGIN]     # used to define permute
d2 <- prod(d.ans)   # length of results object

newX <- aperm(X, c(s.call, s.ans)) # permute X such that dims omitted from calc are first dim
dim(newX) <- c(prod(d.call), d2) # voodoo. Preserves ommitted dim dimension but collapses the rest into 1

maxes <- colMaxs(newX)
ans <- maxes + log(colSums(exp( sweep(newX, 2, maxes, "-"))) )
ans <- array(ans, d.ans)

return(ans)
}
> microbenchmark(
+     res1 <- apply(array4d, c(1,2,3), logSumExp),
+     res2 <- log(rowSums(exp(array4d), dims=3)),
+     res3 <- apply_logSumExp(array4d)
+ )
Unit: milliseconds
expr        min         lq       mean    median        uq       max
res1 <- apply(array4d, c(1, 2, 3), logSumExp) 176.286670 213.882443 247.420334 236.44593 267.81127 486.41072
res2 <- log(rowSums(exp(array4d), dims = 3))   4.664907   5.821601   7.588448   5.97765   7.47814  30.58002
res3 <- apply_logSumExp(array4d)  12.119875  14.673011  19.635265  15.20385  18.30471  90.59859
neval cld
100   c
100 a  
100  b 

最新更新