在numpy中获取唯一行位置的更快方法是什么



我有一个唯一行列表和另一个更大的数据数组(例如称为test_rows)。 我想知道是否有一种更快的方法来获取数据中每个唯一行的位置。 我能想到的最快的方法是...

import numpy

uniq_rows = numpy.array([[0, 1, 0],
[1, 1, 0],
[1, 1, 1],
[0, 1, 1]])
test_rows = numpy.array([[0, 1, 1],
[0, 1, 0],
[0, 0, 0],
[1, 1, 0],
[0, 1, 0],
[0, 1, 1],
[0, 1, 1],
[1, 1, 1],
[1, 1, 0],
[1, 1, 1],
[0, 1, 0],
[0, 0, 0],
[1, 1, 0]])
# this gives me the indexes of each group of unique rows
for row in uniq_rows.tolist():
print row, numpy.where((test_rows == row).all(axis=1))[0]

这打印...

[0, 1, 0] [ 1  4 10]
[1, 1, 0] [ 3  8 12]
[1, 1, 1] [7 9]
[0, 1, 1] [0 5 6]

有没有更好或更数字(不确定这个词是否存在)的方法可以做到这一点? 我正在寻找一个 numpy 组函数,但找不到它。 基本上,对于任何传入的数据集,我需要最快的方法来获取该数据集中每个唯一行的位置。 传入数据集并不总是具有每个唯一行或相同的数字。

编辑: 这只是一个简单的例子。 在我的应用程序中,数字不仅仅是零和一,它们可以是 0 到 32000 之间的任何位置。 uniq 行的大小可能在 4 到 128 行之间,test_rows 的大小可能在数十万行之间。

Numpy

从 numpy的 1.13 版开始,您可以使用 numpy.unique,就像np.unique(test_rows, return_counts=True, return_index=True, axis=1)

熊猫

df = pd.DataFrame(test_rows)
uniq = pd.DataFrame(uniq_rows)

优衣空

0   1   2
0   0   1   0
1   1   1   0
2   1   1   1
3   0   1   1

或者,您可以从传入的数据帧自动生成唯一行

uniq_generated = df.drop_duplicates().reset_index(drop=True)

收益 率

0   1   2
0   0   1   1
1   0   1   0
2   0   0   0
3   1   1   0
4   1   1   1

然后寻找它

d = dict()
for idx, row in uniq.iterrows():
d[idx] = df.index[(df == row).all(axis=1)].values

这与您的where方法大致相同

d

{0: array([ 1,  4, 10], dtype=int64),
1: array([ 3,  8, 12], dtype=int64),
2: array([7, 9], dtype=int64),
3: array([0, 5, 6], dtype=int64)}

这里有很多解决方案,但我添加一个带有香草 numpy 的解决方案。 在大多数情况下,numpy 会比列表推导和字典更快,尽管如果使用大型数组,数组广播可能会导致内存出现问题。

np.where((uniq_rows[:, None, :] == test_rows).all(2))

非常简单,嗯? 这将返回唯一行索引的元组和相应的测试行。

(array([0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 3]),
array([ 1,  4, 10,  3,  8, 12,  7,  9,  0,  5,  6]))

工作原理:

(uniq_rows[:, None, :] == test_rows)

使用数组广播将test_rows的每个元素与uniq_rows中的每一行进行比较。 这将生成一个 4x13x3 数组。all用于确定哪些行相等(所有比较都返回 true)。 最后,where返回这些行的索引。

使用 v1.13 的np.unique(从最新文档的source链接下载,https://github.com/numpy/numpy/blob/master/numpy/lib/arraysetops.py#L112-L247)

In [157]: aset.unique(test_rows, axis=0,return_inverse=True,return_index=True)
Out[157]: 
(array([[0, 0, 0],
[0, 1, 0],
[0, 1, 1],
[1, 1, 0],
[1, 1, 1]]),
array([2, 1, 0, 3, 7], dtype=int32),
array([2, 1, 0, 3, 1, 2, 2, 4, 3, 4, 1, 0, 3], dtype=int32))
In [158]: a,b,c=_
In [159]: c
Out[159]: array([2, 1, 0, 3, 1, 2, 2, 4, 3, 4, 1, 0, 3], dtype=int32)
In [164]: from collections import defaultdict
In [165]: dd = defaultdict(list)
In [166]: for i,v in enumerate(c):
...:     dd[v].append(i)
...:     
In [167]: dd
Out[167]: 
defaultdict(list,
{0: [2, 11],
1: [1, 4, 10],
2: [0, 5, 6],
3: [3, 8, 12],
4: [7, 9]})

或使用唯一行(作为可哈希元组)索引字典:

In [170]: dd = defaultdict(list)
In [171]: for i,v in enumerate(c):
...:     dd[tuple(a[v])].append(i)
...:     
In [172]: dd
Out[172]: 
defaultdict(list,
{(0, 0, 0): [2, 11],
(0, 1, 0): [1, 4, 10],
(0, 1, 1): [0, 5, 6],
(1, 1, 0): [3, 8, 12],
(1, 1, 1): [7, 9]})

这将完成这项工作:

import numpy as np
uniq_rows = np.array([[0, 1, 0],
[1, 1, 0],
[1, 1, 1],
[0, 1, 1]])
test_rows = np.array([[0, 1, 1],
[0, 1, 0],
[0, 0, 0],
[1, 1, 0],
[0, 1, 0],
[0, 1, 1],
[0, 1, 1],
[1, 1, 1],
[1, 1, 0],
[1, 1, 1],
[0, 1, 0],
[0, 0, 0],
[1, 1, 0]])
indices=np.where(np.sum(np.abs(np.repeat(uniq_rows,len(test_rows),axis=0)-np.tile(test_rows,(len(uniq_rows),1))),axis=1)==0)[0]
loc=indices//len(test_rows)
indices=indices-loc*len(test_rows)
res=[[] for i in range(len(uniq_rows))]
for i in range(len(indices)):
res[loc[i]].append(indices[i])
print(res)
[[1, 4, 10], [3, 8, 12], [7, 9], [0, 5, 6]]

这将适用于所有情况,包括并非uniq_rows中的所有行都存在于test_rows中的情况。但是,如果您以某种方式事先知道所有这些都存在,则可以更换零件

res=[[] for i in range(len(uniq_rows))]
for i in range(len(indices)):
res[loc[i]].append(indices[i])

仅使用行:

res=np.split(indices,np.where(np.diff(loc)>0)[0]+1)

从而完全避免循环。

不是很"numpythonic",但需要一点前期成本,我们可以用键作为行的元组和索引列表来制作一个字典:

test_rowsdict = {}
for i,j in enumerate(test_rows):
test_rowsdict.setdefault(tuple(j),[]).append(i)
test_rowsdict
{(0, 0, 0): [2, 11],
(0, 1, 0): [1, 4, 10],
(0, 1, 1): [0, 5, 6],
(1, 1, 0): [3, 8, 12],
(1, 1, 1): [7, 9]}

然后,您可以根据uniq_rows进行过滤,并快速查找字典:test_rowsdict[tuple(row)]

out = []
for i in uniq_rows:
out.append((i, test_rowsdict.get(tuple(i),[])))

对于您的数据,我只查找就得到了 16us,构建和查找得到了 66us,而您的 np.where 解决方案则得到了 95us。

方法 #1

这里有一种方法,不确定"NumPythonic-ness"的水平,尽管这是一个棘手的问题 -

def get1Ds(a, b): # Get 1D views of each row from the two inputs
# check that casting to void will create equal size elements
assert a.shape[1:] == b.shape[1:]
assert a.dtype == b.dtype
# compute dtypes
void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
# convert to 1d void arrays
a = np.ascontiguousarray(a)
b = np.ascontiguousarray(b)
a_void = a.reshape(a.shape[0], -1).view(void_dt).ravel()
b_void = b.reshape(b.shape[0], -1).view(void_dt).ravel()
return a_void, b_void
def matching_row_indices(uniq_rows, test_rows):
A, B = get1Ds(uniq_rows, test_rows)
validA_mask = np.in1d(A,B)
sidx_A = A.argsort()
validA_mask = validA_mask[sidx_A]    
sidx = B.argsort()
sortedB = B[sidx]
split_idx = np.flatnonzero(sortedB[1:] != sortedB[:-1])+1
all_split_indx = np.split(sidx, split_idx)
match_mask = np.in1d(B,A)[sidx]
valid_mask = np.logical_or.reduceat(match_mask, np.r_[0, split_idx])    
locations = [e for i,e in enumerate(all_split_indx) if valid_mask[i]]
return uniq_rows[sidx_A[validA_mask]], locations    

改进范围(性能):

  1. np.split可以替换为使用slicing进行拆分的 for 循环。
  2. np.r_可以用np.concatenate代替。

示例运行 -

In [331]: unq_rows, idx = matching_row_indices(uniq_rows, test_rows)
In [332]: unq_rows
Out[332]: 
array([[0, 1, 0],
[0, 1, 1],
[1, 1, 0],
[1, 1, 1]])
In [333]: idx
Out[333]: [array([ 1,  4, 10]),array([0, 5, 6]),array([ 3,  8, 12]),array([7, 9])]

方法#2

另一种击败前一个设置开销并利用其get1Ds的方法将是 -

A, B = get1Ds(uniq_rows, test_rows)
idx_group = []
for row in A:
idx_group.append(np.flatnonzero(B == row))

numpy_indexed包(免责声明:我是它的作者)是为了以优雅和高效的方式解决此类问题而创建的:

import numpy_indexed as npi
indices = np.arange(len(test_rows))
unique_test_rows, index_groups = npi.group_by(test_rows, indices)

如果你不关心所有行的索引,而只关心test_rows中存在的索引,npi 也有一堆简单的方法来解决这个问题; f.i:

subset_indices = npi.indices(unique_test_rows, unique_rows)

作为旁注;看看 npi 库中的示例可能会很有用;根据我的经验,大多数时候人们会问这样的问题,这些分组索引只是达到目的的一种手段,而不是计算的最终目标。使用 npi 中的功能,您可以更有效地实现最终目标,而无需显式计算这些索引。你愿意给你的问题提供更多的背景吗?

编辑:如果你的数组确实这么大,并且总是由少量带有二进制值的列组成,那么用以下编码包装它们可能会进一步提高效率:

def encode(rows):
return (rows * [[2**i for i in range(rows.shape[1])]]).sum(axis=1, dtype=np.uint8)

相关内容

最新更新