如何替换numpy数组中每一行最小的N个元素?



我想将每行中最小的N个元素替换为0,并且生成的数组将尊重原始数组的相同顺序和形状。

具体来说,如果原始numpy数组为:

import numpy as np
x = np.array([[0,50,20],[2,0,10],[1,1,0]])

和N = 2,我希望结果如下:

x = np.array([[0,50,0],[0,0,10],[0,1,0]])

我尝试了以下操作,但在最后一行中它替换了3个元素而不是2个(因为它替换了1而不是1个)

import numpy as np
N = 2
x = np.array([[0,50,20],[2,0,10],[1,1,0]])
x_sorted = np.sort(x , axis = 1)
x_sorted[:,N:] = 0
replace = x_sorted.copy()
final = np.where(np.isin(x,replace),0,x)

注意,这是一个小的例子,我希望它适用于一个更大的矩阵。

谢谢你的时间!

使用numpy.argsort:

N = 2
x[x.argsort().argsort() < N] = 0

输出:

array([[ 0, 50,  0],
[ 0,  0, 10],
[ 0,  1,  0]])

使用numpy.argpartition查找N最小元素的索引,然后使用该索引替换值:

N = 2
idy = np.argpartition(x, N, axis=1)[:, :N]
x[np.arange(len(x))[:,None], idy] = 0
x
array([[ 0, 50,  0],
[ 0,  0, 10],
[ 1,  0,  0]])

请注意,如果存在关联,则根据所使用的算法可能无法确定哪些值被替换。

最新更新