Pytorch跨多个维度的argmax



我有一个四维张量,我想在最后两个维度上得到argmax。torch.argmax只接受整数作为"暗淡";参数,而不是元组。

我怎样才能做到这一点?

这是我的想法,但我不知道如何匹配我的两个"索引"的尺寸。张量。original_array为形状[1,512,37,59]

max_vals, indices_r = torch.max(original_array, dim=2)
max_vals, indices_c = torch.max(max_vals, dim=2)
indices = torch.hstack((indices_r, indices_c))

正如其他人所提到的,最好将最后两个维度扁平化并应用argmax

original_array = torch.rand(1, 512, 37, 59)
original_flatten = original_array.view(1, 512, -1)
_, max_ind = original_flatten.max(-1)

. .你会得到最大值的线性索引。如果你想要2D指数的最大值,你可以执行"un扁平化";使用列数

的索引
# 59 is the number of columns for the (37, 59) part
torch.stack([max_ind // 59, max_ind % 59], -1)

这将给你一个(1, 512, 2),其中每个最后2个dim包含2D坐标。

您可以使用torch.flatten将最后两个维度平坦,并在其上应用torch.argmax:

>>> x = torch.rand(2,3,100,100)
>>> x.flatten(-2).argmax(-1)
tensor([[2660, 6328, 8166],
[5934, 5494, 9717]])

最新更新