Numpy获取掩码中每列最后两个元素的行索引



我有一个形状为(M, N)的布尔掩码。掩模中的每一列可以具有不同数量的True元素,但保证至少具有两个。我想尽可能有效地找到最后两个这样的元素的行索引。

如果我只想要一个元素,我可以做一些类似(M - 1) - np.argmax(mask[::-1, :], axis=0)的事情。然而,这并不能帮助我获得倒数第二的索引。

我提出了一个使用np.wherenp.nonzero:的迭代解决方案

M = 4
N = 3
mask = np.array([
[False, True, True],
[True, False, True],
[True, False, True],
[False, True, False]
])
result = np.zeros((2, N), dtype=np.intp)
for col in range(N):
result[:, col] = np.flatnonzero(mask[:, col])[-2:]

这将创建预期的result:

array([[1, 0, 1],
[2, 3, 2]], dtype=int64)

我想避免最后一个循环。是否存在上述内容的合理矢量化形式?我正在寻找特别的两行,它们总是保证存在。不需要针对任意元素计数的通用解决方案。

一个argsort可以做到这一点-

In [9]: np.argsort(mask,axis=0,kind='stable')[-2:]
Out[9]: 
array([[1, 0, 1],
[2, 3, 2]])

另一个带有cumsum-

c = mask.cumsum(0)
out = np.where((mask & (c>=c[-1]-1)).T)[1].reshape(-1,2).T

专门针对两行,单向使用argmax-

c = mask.copy()
idx = len(c)-c[::-1].argmax(0)-1
c[idx,np.arange(len(idx))] = 0
idx2 = len(c)-c[::-1].argmax(0)-1
out = np.vstack((idx2,idx))

最新更新