我目前正在处理一个基准测试问题,我愿意使用R的矢量化来加快计算速度,但是我真的不知道如何提高速度。非常感谢你的帮助。
function(n = 5, lower = 1, upper = 4, add = 1) {
result <- c(lower, upper)
for (i in 3:n) {
result <- append(result, result[[i - 1]] + result[[i - 2]] + add)
}
result
}
我的想法包括lapply/vapply以及某种递归。
不要在循环中使用append
。这被称为"培养一个对象"。results
对象每次迭代都会变大。这是出了名的低效,因为随着对象变大,你的计算机必须在内存中找到越来越大的地方来存储它,移动它并复制它。
相反,从一开始就将result
初始化为其完整长度。所有你不知道的值设置为NA
和填补他们的价值观。
# original
foo = function(n = 5, lower = 1, upper = 4, add = 1) {
result <- c(lower, upper)
for (i in 3:n) {
result <- append(result, result[[i - 1]] + result[[i - 2]] + add)
}
result
}
foo()
bar = function(n = 5, lower = 1, upper = 4, add = 1) {
# initialize to full length
result = integer(length = n)
# set first two entries
result[1:2] <- c(lower, upper)
for (i in 3:n) {
# fill in the rest of the blanks
result[i] <- result[i - 1] + result[i - 2] + add
}
result
}
## same result
identical(foo(), bar())
# [1] TRUE
## about 40x faster when n = 1000 (looking at the iterations per second)
bench::mark(foo(n = 1000), bar(n = 1000))
# # A tibble: 2 × 13
# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time result
# <bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl> <int> <dbl> <bch:tm> <list>
# 1 foo(n = 1000) 1.73ms 1.95ms 497. 3.86MB 39.3 177 14 356ms <dbl [1…
# 2 bar(n = 1000) 51.87µs 53.46µs 18439. 11.81KB 4.13 8936 2 485ms <dbl [1…
# # … with 3 more variables: memory <list>, time <list>, gc <list>
还请注意,对于矢量,您只需要单括号[
。使用双括号[[
从list
类对象中提取单个项。
首先,不要使用recursion
,这会降低您的性能。另外,您可以使用预分配的vector来存储更新后的值。下面是一个基准
# OP's solution
f <- function(n = 10, lower = 1, upper = 4, add = 1) {
result <- c(lower, upper)
for (i in 3:n) {
result <- append(result, result[[i - 1]] + result[[i - 2]] + add)
}
result
}
# A recursion implementation
f1 <- function(n = 10, lower = 1, upper = 4, add = 1) {
if (n <= 2) {
return(c(lower, upper)[1:n])
}
v <- Recall(n - 1)
c(v, sum(tail(v, 2)) + add)
}
# for-loop version with pre-allocated vector
f2 <- function(n = 10, lower = 1, upper = 4, add = 1) {
v <- numeric(n)
for (i in 1:n) {
if (i <= 2) {
v[i] <- c(lower, upper)[i]
} else {
v[i] <- v[i - 1] + v[i - 2] + add
}
}
v
}
你会看到
> microbenchmark(f(), f1(), f2())
Unit: microseconds
expr min lq mean median uq max neval
f() 10.5 11.0 150.894 11.60 12.30 13738.9 100
f1() 68.1 69.3 170.973 70.95 82.25 6796.3 100
f2() 2.7 2.9 163.506 3.20 3.80 15966.3 100