我有以下代码:
import numpy as np
def suborder(x, y):
pos = np.in1d(x, y, assume_unique=True)
return x[pos]
x
和y
是 1d numpy 整数数组,y
的元素是x
中元素的子集,并且两个数组都没有重复。结果是y
的元素,按照它们在x
中出现的顺序。代码给出了我想要的结果。但是中间阵列pos
的大小与x
相同,并且在许多用例中y
比x
小得多。有没有办法在不分配中间数组pos
的情况下更直接地获得结果以节省一些内存?
x
未排序。在我的例子中,它的元素是对象的 id,是值 0->len(x),但顺序未指定,并且按分配给每个对象的分数顺序排序。suborder
的目的是对具有相同分数顺序的子集进行排序。
x
大约有1000万个元素;我有许多不同的y
值,有些接近x
的大小,一直到只有少数几个元素。
编辑:我对对象的一组分数进行argsort
得到了x
。我曾设想过最好对所有分数进行一次排序,然后使用该排序对子集施加顺序。实际上,取scores[y]
,然后argsort
它并按该顺序(对于每个y
)获取y
元素可能更好。
解决方案 1
由于项目在range(0, len(x))
并且都是唯一的(即排列),因此您只能预分配一个大小为len(x)
(RAM 中的len(x)*4
字节)的缓冲区。策略是在对x
进行排序后首先构建一次反向索引:
idx = np.array(len(x), dtype=np.int32) # Can be reused after each sort of `x`
idx[x] = np.arange(len(x), dtype=np.int32) # Can be filled chunk-by-chunk in a loop if memory matters
然后,您需要过滤y
数组,以便所有值都range(0, len(x))
.如果已经是这种情况,请跳过此步骤。可以使用yFilt = y[np.logical_and(y >= 0, y < len(x))]
完成该操作。由于y
可能很大,因此您可以逐块执行此操作。更简单、更快、更节省内存的解决方案是使用 Numba 即时过滤y
。
然后,您需要计算x[np.sorted(idx[yFilt])]
以重新排序y
的项目,例如在x
中。可以使用以下代码就地完成此操作:
# Should not allocate any temporary arrays
idx.take(yFilt, out=yFilt)
yFilt.sort()
x.take(yFilt, out=yFilt)
之后,yFilt
现在像x
中的项目一样订购。请注意,您可以改变y
以便不执行任何临时数组分配(尽管这意味着在此操作后y
代码中的其他内容不会使用它)。
此重新排序算法O(Ny log Ny)
Ny = len(y)
.预计算在O(Nx)
时间内运行,Nx = len(x)
.它需要4 (Nx + Ny)
字节空间用于异地实现,4 Nx
字节用于不执行分配以重新排序y
的就地版本。
解决方案 2
如果前面的解决方案占用了太多内存,那么尽管计算量要大得多,但此解决方案应该是不错的解决方案。它仅使用O(8 Ny)
字节(就地实现O(4 Ny)
),并在O(Nx log Ny)
时间内运行。请注意,输出数组可以预分配一次(并且只能在以后填充),以避免 GC/分配器出现任何问题。
这个想法是在y
的排序+过滤版本中对x
的每个值执行二叉搜索。值在输出数组中动态追加。这个解决方案要求Numba或Cython快速(尽管可以使用块和np.searchsorted
编写复杂的纯Numpy实现)。
import numba as nb
# `out` can be preallocated and passed in parameter to
# avoid allocations in hot loops
@nb.njit('int32[:](int32[:], int32[:])')
def orderLike(x, y):
sorted = np.sort(y) # Use y.sort() for an in-place implementation
out = np.empty(len(y), np.int32)
cur = 0
for v in x:
pos = np.searchsorted(sorted, v)
if pos < len(y) and sorted[pos] == v: # Found
out[cur] = v
cur += 1
return out[:cur]
in1d
开头为:
if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:
...
mask = np.zeros(len(ar1), dtype=bool)
for a in ar2:
mask |= (ar1 == a)
return mask
换句话说,它对y
的每个元素进行相等性测试。 如果您的大小差异不是那么大,那么它使用不同的方法,一种基于连接数组并执行argsort
的方法。
我可以想象使用np.flatnonzero(ar1==a)
来获取等效的索引,并将它们连接起来。 但这将维护y
秩序。