我在将 argmax 应用于具有多个括号的数组时遇到问题。 在现实生活中,我得到这个是 pytorch 张量的结果。 在这里我可以举一个例子:
a = np.array([[1.0, 1.1],[2.1,2.0]])
np.argmax(a,axis=1)
array([1, 0])
这是正确的。但:
a = np.array([[[1.0, 1.1]],[[2.1,2.0]]])
np.argmax(a,axis=1)
array([[0, 0],
[0, 0]])
它没有给我我所期望的。 考虑到实际上我有这个级别的内括号:
a = np.array([[[[1.0, 1.1]]],[[[2.1,2.0]]]])
使用.squeeze()
和负索引。
a = np.array([[[[1.0, 1.1]]], [[[2.1, 2.0]]]])
np.argmax(a, axis = -1).squeeze()
array([1, 0], dtype=int32)
一个可能的解决方案是递增轴值:
a = np.array([[[[1.0, 1.1]]],[[[2.1,2.0]]]])
np.argmax(a,axis=3)
array([[[1]],
[[0]]])
但我仍然有内括号。