我正在尝试通过矩阵中NA
值在每列中的唯一位置对矩阵的行进行分组。
例如,使用以下矩阵:
1, 2, NA, 3 NA
2, 5, NA, 4, 5
3, 2, 1, 0, 7
5, 3, NA, 9, 3
0, 2, 1, 4, 6
答案是:
1, 2, 3, 2, 3
表示有 3 个不同的组,即第 2 行和第 4 行在同一组中。
问题是我无法想出一种快速的方法来实现这一目标。这是我目前的实现:
mat <- matrix(rnorm(10000*100), ncol=100)
mat[sample(length(mat), nrow(mat))] <- NA
getNAgroups <- function(x) {
allnas <- t(!is.na(x))
nacases <- unique(allnas, MARGIN=2)
groups <- numeric(nrow(x))
for(i in 1:ncol(nacases)) {
groups[colMeans(allnas == nacases[,i]) == 1] <- i
}
groups
}
对于我所想到的目的来说,这有点太慢了:
system.time(getNAgroups(mat))
user system elapsed
7.672 1.686 9.386
以下是在 NA 位置列表上使用匹配的一种方法:
mat <- matrix(c(1, 2, NA, 3, NA,
2, 5, NA, 4, 5,
3, 2, 1, 0, 7,
5, 3, NA, 9, 3,
0, 2, 1, 4, 6), 5, byrow = TRUE)
categ <- apply(is.na(mat), 1, which)
match(categ, unique(categ))
我们可以逐行paste
这些值并对其进行match
以获得唯一索引。
vals <- apply(is.na(mat), 1, toString)
match(vals, unique(vals))
#[1] 1 2 3 2 3
如果性能是问题,我会尝试以下代码:
library(dplyr)
getNAgrps = function(df){
df = df %>%
dplyr::mutate(NAgrp = '')
lapply(1:nrow(df),function(i){
df$NAgrp[i] <<- paste0(which(is.na(df[i,])),collapse=",")
})
return(df)
}
此函数将输入作为数据帧。要将矩阵转换为数据帧,请执行以下操作:
library(dplyr)
dat = as_data_frame(mat)
性能如下:
> system.time(getNAgrps(mat))
user system elapsed
0.005 0.000 0.006
让我知道它是否有效。
注意:这将返回字符向量,而不是给出整数组,其中 NA 的位置用逗号分隔。
如果您不介意顺序,可以使用interaction
查找组。
tt <- interaction(as.data.frame(is.na(mat)), drop = TRUE)
unclass(tt)
#[1] 3 2 1 2 1
或者一种性能更高的方法是使用sweep
和rowSums
但最多只能工作 30 列。
tt <- is.na(mat)
tt <- rowSums(sweep(tt, 2, cumprod(rep(2L,ncol(tt))), "*"))
match(tt, unique(tt))
#[1] 1 2 3 2 3
或者您可以使用bit
库,它不是更快,但遵循之前的想法并且适用于许多行,并且在内存有限时可能会有所帮助。
library("bit")
tt <- apply(is.na(mat), 1, as.bit)
match(tt, unique(tt))
#[1] 1 2 3 2 3
#For many columns
tt <- apply(apply(is.na(mat), 1, as.bit), 2, paste, collapse=" ")
match(tt, unique(tt))
#[1] 1 2 3 2 3
或者,packBits
可以像这样使用:
tt <- is.na(mat)
tt <- cbind(tt, matrix(TRUE, nrow(tt), ncol=(8 - ncol(tt) %% 8)))
tt <- packBits(t(tt))
tt <- split(tt, rep(seq_len(nrow(mat)), each=length(tt)/nrow(mat)))
match(tt, unique(tt))
#[1] 1 2 3 2 3
或使用来自PKI
或encryptr
raw2hex
的更高性能的版本。
library(PKI) #or library(encryptr)
tt <- is.na(mat)
tt <- cbind(tt, matrix(TRUE, nrow(tt), ncol=(8 - ncol(tt) %% 8)))
tt <- raw2hex(packBits(t(tt)))
tt <- matrix(tt, ncol = nrow(mat))
tt <- apply(tt, 2, paste, collapse="")
match(tt, unique(tt))
[1] 1 2 3 2 3
set.seed(42)
mat <- matrix(rnorm(10000*100), ncol=100)
mat[sample(length(mat), nrow(mat))] <- NA
getNAgroups_Orig <- function(x) {
allnas <- t(!is.na(x))
nacases <- unique(allnas, MARGIN=2)
groups <- numeric(nrow(x))
for(i in 1:ncol(nacases)) {
groups[colMeans(allnas == nacases[,i]) == 1] <- i
}
groups
}
getNAgroups_GKi <- function(mat) {
tt <- is.na(mat)
tt <- rowSums(sweep(tt, 2, cumprod(rep(2L,ncol(tt))), "*"))
match(tt, unique(tt))
}
getNAgroups_Clemsang <- function(mat) {
categ <- apply(is.na(mat), 1, which)
match(categ, unique(categ))
}
getNAgroups_RonakShah <- function(mat) {
vals <- apply(is.na(mat), 1, toString)
match(vals, unique(vals))
}
library("bit")
getNAgroups_bit <- function(mat) {
tt <- apply(apply(is.na(mat), 1, as.bit), 2, paste, collapse=" ")
match(tt, unique(tt))
}
getNAgroups_GKi2 <- function(mat) {
tt <- is.na(mat)
tt <- cbind(tt, matrix(TRUE, nrow(tt), ncol=(8 - ncol(tt) %% 8)))
tt <- packBits(t(tt))
tt <- split(tt, rep(seq_len(nrow(mat)), each=length(tt)/nrow(mat)))
match(tt, unique(tt))
}
library(PKI) #or library(encryptr)
getNAgroups_GKi3 <- function(mat) {
tt <- is.na(mat)
tt <- cbind(tt, matrix(TRUE, nrow(tt), ncol=(8 - ncol(tt) %% 8)))
tt <- raw2hex(packBits(t(tt)))
tt <- matrix(tt, ncol = nrow(mat))
tt <- apply(tt, 2, paste, collapse="")
match(tt, unique(tt))
}
system.time(getNAgroups_Orig(mat))
# User System verstrichen
# 6.928 1.316 8.244
system.time(getNAgroups_GKi(mat)) ###IS NOT WORKING CORRECT DUE TO TOO MANY COLUMNS
# User System verstrichen
# 0.016 0.000 0.016
system.time(getNAgroups_Clemsang(mat))
# User System verstrichen
# 0.045 0.004 0.049
system.time(getNAgroups_RonakShah(mat))
# User System verstrichen
# 0.347 0.000 0.347
system.time(getNAgroups_bit(mat))
# User System verstrichen
# 0.239 0.000 0.240
system.time(getNAgroups_GKi2(mat))
# User System verstrichen
# 0.119 0.000 0.119
system.time(getNAgroups_GKi3(mat))
# User System verstrichen
# 0.046 0.000 0.046