Numpy Argmax 在具有多个括号的数组中



我在将 argmax 应用于具有多个括号的数组时遇到问题。 在现实生活中,我得到这个是 pytorch 张量的结果。 在这里我可以举一个例子:

a = np.array([[1.0, 1.1],[2.1,2.0]])
np.argmax(a,axis=1)
array([1, 0])

这是正确的。但:

a = np.array([[[1.0, 1.1]],[[2.1,2.0]]])
np.argmax(a,axis=1)
array([[0, 0],
[0, 0]])

它没有给我我所期望的。 考虑到实际上我有这个级别的内括号:

a = np.array([[[[1.0, 1.1]]],[[[2.1,2.0]]]])

使用.squeeze()和负索引。

a = np.array([[[[1.0, 1.1]]], [[[2.1, 2.0]]]])
np.argmax(a, axis = -1).squeeze()
array([1, 0], dtype=int32)

一个可能的解决方案是递增轴值:

a = np.array([[[[1.0, 1.1]]],[[[2.1,2.0]]]])
np.argmax(a,axis=3)
array([[[1]],
[[0]]])

但我仍然有内括号。

最新更新