我得到了一个numpy矩阵,我想得到每行中最大值的索引。例如
[[1,2,3],[1,3,2],[3,2,1]]
将返回
[0,1,2]
但是,当每行中的最大值超过1时,numpy.argmax
将只返回最小的索引。例如
[[0,0,0],[0,0,0],[0,0,0]]
将返回
[0,0,0]
我可以将默认值(最小索引(更改为其他值吗?例如,当有相等的最大值时,返回1
或None
,则上述结果为
[1,1,1]
or
[None, None, None]
如果我能在TensorFlow中做到这一点,那会更好。
谢谢!
您可以使用np.partition
两个查找两个最大值并检查它们是否相等,然后将其用作np.where
中的掩码以设置默认值:
In [228]: a = np.array([[1, 2, 3, 2], [3, 1, 3, 2], [3, 5, 2, 1]])
In [229]: twomax = np.partition(a, -2)[:, -2:].T
In [230]: default = -1
In [231]: argmax = np.where(twomax[0] != twomax[1], np.argmax(a, -1), default)
In [232]: argmax
Out[232]: array([ 2, -1, 1])
一个方便的"default"值是-1,因为argmax
本身不会返回该值。None
不适合整数数组。屏蔽数组也是一种选择,但我没有走那么远。以下是NumPy实现
def my_argmax(a):
rows = np.where(a == a.max(axis=1)[:, None])[0]
rows_multiple_max = rows[:-1][rows[:-1] == rows[1:]]
my_argmax = a.argmax(axis=1)
my_argmax[rows_multiple_max] = -1
return my_argmax
使用示例:
import numpy as np
a = np.array([[0, 0, 0], [4, 5, 3], [3, 4, 4], [6, 2, 1]])
my_argmax(a) # array([-1, 1, -1, 0])
说明:where
选择每行中所有最大元素的索引。如果一行具有多个最大值,则行号将在rows
数组中出现多次。由于该数组已经排序,因此通过比较连续元素来检测这种重复。这标识了具有多个最大值的行,之后它们在NumPy的argmax
方法的输出中被屏蔽。