改进r中函数的基准



我目前正在处理一个基准测试问题,我愿意使用R的矢量化来加快计算速度,但是我真的不知道如何提高速度。非常感谢你的帮助。

function(n = 5, lower = 1, upper = 4, add = 1) {
result <- c(lower, upper)
for (i in 3:n) {
result <- append(result, result[[i - 1]] + result[[i - 2]] + add)
}
result
}

我的想法包括lapply/vapply以及某种递归。

不要在循环中使用append。这被称为"培养一个对象"。results对象每次迭代都会变大。这是出了名的低效,因为随着对象变大,你的计算机必须在内存中找到越来越大的地方来存储它,移动它并复制它。

相反,从一开始就将result初始化为其完整长度。所有你不知道的值设置为NA和填补他们的价值观。

# original
foo = function(n = 5, lower = 1, upper = 4, add = 1) {
result <- c(lower, upper)
for (i in 3:n) {
result <- append(result, result[[i - 1]] + result[[i - 2]] + add)
}
result
}
foo()
bar = function(n = 5, lower = 1, upper = 4, add = 1) {
# initialize to full length
result = integer(length = n)
# set first two entries
result[1:2] <- c(lower, upper)
for (i in 3:n) {
# fill in the rest of the blanks
result[i] <- result[i - 1] + result[i - 2] + add
}
result
}
## same result
identical(foo(), bar())
# [1] TRUE

## about 40x faster when n = 1000 (looking at the iterations per second)
bench::mark(foo(n = 1000), bar(n = 1000))
# # A tibble: 2 × 13
#   expression         min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result  
#   <bch:expr>    <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>  
# 1 foo(n = 1000)   1.73ms   1.95ms      497.    3.86MB    39.3    177    14      356ms <dbl [1…
# 2 bar(n = 1000)  51.87µs  53.46µs    18439.   11.81KB     4.13  8936     2      485ms <dbl [1…
# # … with 3 more variables: memory <list>, time <list>, gc <list>

还请注意,对于矢量,您只需要单括号[。使用双括号[[list类对象中提取单个项。

首先,不要使用recursion,这会降低您的性能。另外,您可以使用预分配的vector来存储更新后的值。下面是一个基准

# OP's solution
f <- function(n = 10, lower = 1, upper = 4, add = 1) {
result <- c(lower, upper)
for (i in 3:n) {
result <- append(result, result[[i - 1]] + result[[i - 2]] + add)
}
result
}

# A recursion implementation
f1 <- function(n = 10, lower = 1, upper = 4, add = 1) {
if (n <= 2) {
return(c(lower, upper)[1:n])
}
v <- Recall(n - 1)
c(v, sum(tail(v, 2)) + add)
}
# for-loop version with pre-allocated vector 
f2 <- function(n = 10, lower = 1, upper = 4, add = 1) {
v <- numeric(n)
for (i in 1:n) {
if (i <= 2) {
v[i] <- c(lower, upper)[i]
} else {
v[i] <- v[i - 1] + v[i - 2] + add
}
}
v
}

你会看到

> microbenchmark(f(), f1(), f2())
Unit: microseconds
expr  min   lq    mean median    uq     max neval
f() 10.5 11.0 150.894  11.60 12.30 13738.9   100
f1() 68.1 69.3 170.973  70.95 82.25  6796.3   100
f2()  2.7  2.9 163.506   3.20  3.80 15966.3   100

相关内容

  • 没有找到相关文章

最新更新