r语言 - 优化map2函数中的多项式到data.table



我们想使用多项式定律将航班数量分布在3个不同的级别(V1, V2, V3)上。但是下面的代码对于1000万行代码来说花费了很多时间。是否有优化map2的方法?可能是通过数据表?请注意,我们使用map2逐行应用proba向量。

开始数据
| number_total_flights| 
|:------------------ :|
| 3                   |
| 4                   |
| 5                   |

预期结果
| number_total_flights | V1  | V2  | V3  |
|:-------------------- |:---:|:---:|----:|
| 3                    | 0   | 0   | 3   |
| 4                    | 2   | 1   | 1   |
| 5                    | 1   | 1   | 3   |  

数据
library(dplyr)
library(data.table)
library(purrr)
base <- structure(list(
number_total_flights = c(3L, 4L, 5L)), row.names = c(NA, 3L), class = "data.frame")
proba <- list(
structure(c(0.1, 0.4, 0.5), .Dim = c(1L, 3L)),
structure(c(0.5, 0.2, 0.3), .Dim = c(1L, 3L)),
structure(c(0.2, 0.2, 0.6), .Dim = c(1L, 3L)))

功能治疗
# Calling by map2
distrib_for_each_level <- function(nb_flights, prob){
level <- t(rmultinom(n=1, size=nb_flights, prob=prob))
}
# Function using map2
adding_levels <- function (base, proba){
list_levels <- map2(base$number_total_flights, proba, distrib_for_each_year) %>%
map(as.data.frame) %>% rbindlist()
base <- base %>% cbind(list_levels)
} 
结果

base_with_levels <- adding_levels(base, proba)

与其遍历1000万个案例,不如遍历类别,这样会更快。你可以用rbinom(),它是矢量化的。其思想是多项式的第一个类别的结果是一个具有类别概率的二项;第二类的结果是使用剩余概率和剩余计数等的二项。

我看不懂你的代码(tidyverse代码很聪明,但难以读懂!),所以我将发布全新的代码来说明。这需要几秒钟来计算:

n <- 10000000 # number of cases
m <- 3        # number of categories
probs <- matrix(runif(n*m), n, m)
probs <- probs/rowSums(probs) # Fake multinomial probabilities
counts <- rpois(n, 3)            # Fake multinomial counts
result <- matrix(NA, n, m)    # Result matrix
for (i in 1:(m-1)) {
prob <- probs[, i]/rowSums(probs[, i:m]) # probability of next column
count <- counts - rowSums(result, na.rm = TRUE) # remaining count
result[, i] <- rbinom(n, count, prob)
}
result[, m] <- counts - rowSums(result, na.rm = TRUE)
head(result)
#>      [,1] [,2] [,3]
#> [1,]    0    3    0
#> [2,]    0    3    0
#> [3,]    1    0    1
#> [4,]    1    0    1
#> [5,]    1    2    0
#> [6,]    0    0    0

创建于2022-12-06 with reprex v2.0.2

您可能可以通过更改rowsum调用来只添加正确的列而不是忽略NAs来加快速度,但是我很懒。

最新更新