按不同数组中的出现次数对 numpy 数组的元素进行排序的效率



我有以下代码:

import numpy as np
def suborder(x, y):
pos = np.in1d(x, y, assume_unique=True)
return x[pos]

xy是 1d numpy 整数数组,y的元素是x中元素的子集,并且两个数组都没有重复。结果是y的元素,按照它们在x中出现的顺序。代码给出了我想要的结果。但是中间阵列pos的大小与x相同,并且在许多用例中yx小得多。有没有办法在不分配中间数组pos的情况下更直接地获得结果以节省一些内存?

x未排序。在我的例子中,它的元素是对象的 id,是值 0->len(x),但顺序未指定,并且按分配给每个对象的分数顺序排序。suborder的目的是对具有相同分数顺序的子集进行排序。

x大约有1000万个元素;我有许多不同的y值,有些接近x的大小,一直到只有少数几个元素。

编辑:我对对象的一组分数进行argsort得到了x。我曾设想过最好对所有分数进行一次排序,然后使用该排序对子集施加顺序。实际上,取scores[y],然后argsort它并按该顺序(对于每个y)获取y元素可能更好。

解决方案 1

由于项目在range(0, len(x))并且都是唯一的(即排列),因此您只能预分配一个大小为len(x)(RAM 中的len(x)*4字节)的缓冲区。策略是在对x进行排序后首先构建一次反向索引:

idx = np.array(len(x), dtype=np.int32)      # Can be reused after each sort of `x`
idx[x] = np.arange(len(x), dtype=np.int32)  # Can be filled chunk-by-chunk in a loop if memory matters

然后,您需要过滤y数组,以便所有值都range(0, len(x)).如果已经是这种情况,请跳过此步骤。可以使用yFilt = y[np.logical_and(y >= 0, y < len(x))]完成该操作。由于y可能很大,因此您可以逐块执行此操作。更简单、更快、更节省内存的解决方案是使用 Numba 即时过滤y

然后,您需要计算x[np.sorted(idx[yFilt])]以重新排序y的项目,例如在x中。可以使用以下代码就地完成此操作:

# Should not allocate any temporary arrays
idx.take(yFilt, out=yFilt)
yFilt.sort()
x.take(yFilt, out=yFilt)

之后,yFilt现在像x中的项目一样订购。请注意,您可以改变y以便不执行任何临时数组分配(尽管这意味着在此操作后y代码中的其他内容不会使用它)。

此重新排序算法O(Ny log Ny)Ny = len(y).预计算在O(Nx)时间内运行,Nx = len(x).它需要4 (Nx + Ny)字节空间用于异地实现,4 Nx字节用于不执行分配以重新排序y的就地版本。

<小时 />

解决方案 2

如果前面的解决方案占用了太多内存,那么尽管计算量要大得多,但此解决方案应该是不错的解决方案。它仅使用O(8 Ny)字节(就地实现O(4 Ny)),并在O(Nx log Ny)时间内运行。请注意,输出数组可以预分配一次(并且只能在以后填充),以避免 GC/分配器出现任何问题。

这个想法是在y的排序+过滤版本中对x的每个值执行二叉搜索。值在输出数组中动态追加。这个解决方案要求Numba或Cython快速(尽管可以使用块和np.searchsorted编写复杂的纯Numpy实现)。

import numba as nb
# `out` can be preallocated and passed in parameter to 
# avoid allocations in hot loops
@nb.njit('int32[:](int32[:], int32[:])')
def orderLike(x, y):
sorted = np.sort(y)  # Use y.sort() for an in-place implementation
out = np.empty(len(y), np.int32)
cur = 0
for v in x:
pos = np.searchsorted(sorted, v)
if pos < len(y) and sorted[pos] == v: # Found
out[cur] = v
cur += 1
return out[:cur]

in1d开头为:

if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:
...
mask = np.zeros(len(ar1), dtype=bool)
for a in ar2:
mask |= (ar1 == a)
return mask

换句话说,它对y的每个元素进行相等性测试。 如果您的大小差异不是那么大,那么它使用不同的方法,一种基于连接数组并执行argsort的方法。

我可以想象使用np.flatnonzero(ar1==a)来获取等效的索引,并将它们连接起来。 但这将维护y秩序。

最新更新