numpy ndarray按行排序值,然后排除索引



我有一个numpy.ndarray如下:

from numpy import array
a = array( [[1,1,0.4], [1,1,0.3],[0.4,0.3,1]] )
array([[ 1. ,  1. ,  0.4],
       [ 1. ,  1. ,  0.3],
       [ 0.4,  0.3,  1. ]])

这是列:

dataidx = array( [1,2,3] )

我想按行上面的数组的值,然后指定相关的dataidx:

indices = np.argsort(-a, axis=1)
result = np.hstack((dataidx[:, None], dataidx[indices]))
print(result)
[[1 1 2 3]
 [2 1 2 3]
 [3 3 1 2]]

对于每一行,我如何根据下面的第一列排除dataidx?

[[1 2 3]
 [2 1 3]
 [3 1 2]]

这是一种方式 -

In [56]: m = result.shape[0]
In [57]: mask = np.c_[[True]*m,result[:,1:] != result[:,0,None]]
In [58]: result[mask].reshape(m,-1)
Out[58]: 
array([[1, 2, 3],
       [2, 1, 3],
       [3, 1, 2]])

这是另一个 -

In [105]: rm_idx = (result[:,1:] == result[:,0,None]).argmax(1)+1
In [106]: mask = np.ones(result.shape, dtype=bool)
In [107]: mask[np.arange(len(rm_idx)), rm_idx] = 0
In [108]: result[mask].reshape(result.shape[0],-1)
Out[108]: 
array([[1, 2, 3],
       [2, 1, 3],
       [3, 1, 2]])

相关内容

最新更新