为三维numpy数组的每一层查找最大元素的列索引



我有一个3D NumPy数组arr。这里有一个例子:

>>> arr
array([[[0.05, 0.05, 0.9 ],
[0.4 , 0.5 , 0.1 ],
[0.7 , 0.2 , 0.1 ],
[0.1 , 0.2 , 0.7 ]],
[[0.98, 0.01, 0.01],
[0.2 , 0.3 , 0.95],
[0.33, 0.33, 0.34],
[0.33, 0.33, 0.34]]])

对于立方体的每一层(即,对于每个矩阵(,我想找到包含矩阵中最大数字的列的索引。例如,让我们以第一层为例:

>>> arr[0]
array([[0.05, 0.05, 0.9 ],
[0.4 , 0.5 , 0.1 ],
[0.7 , 0.2 , 0.1 ],
[0.1 , 0.2 , 0.7 ]])

这里,最大的元素是0.9,它可以在第三列(即索引2(上找到。相反,在第二层中,可以在第一列上找到最大值(最大值为0.98,列索引为0(。

上一个例子的预期结果是:

array([2, 0])

以下是我迄今为止所做的:

tmp = arr.max(axis=-1)
argtmp = arr.argmax(axis=-1)
indices = np.take_along_axis(
argtmp,
tmp.argmax(axis=-1).reshape((arr.shape[0], -1)),
1,
).reshape(-1)

上面的代码是有效的,但我想知道它是否可以进一步简化,因为从我的角度来看,它似乎太复杂了。

在应用argmax:之前查找每列中的最大值

arr.max(-2).argmax(-1)

将列减少为单个最大值不会改变哪一列的值最大。由于您不关心行索引,这为您省去了很多麻烦。

最新更新