我试图在给定内核大小的2D张量中找到最大点,但在所有值都是一致的特殊情况下,我遇到了问题。例如,给定以下示例,我想将每个点标记为最大点:
+---+---+---+---+
| 5 | 5 | 5 | 5 |
+---+---+---+---+
| 5 | 5 | 5 | 5 |
+---+---+---+---+
| 5 | 5 | 5 | 5 |
+---+---+---+---+
| 5 | 5 | 5 | 5 |
+---+---+---+---+
如果我运行torch.nn.functional.max_pool2d,内核大小为3,步长为1,填充为1,我会得到以下标记:
+---+---+---+----+
| 0 | 0 | 1 | 2 |
+---+---+---+----+
| 0 | 0 | 1 | 2 |
+---+---+---+----+
| 4 | 4 | 5 | 6 |
+---+---+---+----+
| 8 | 8 | 9 | 10 |
+---+---+---+----+
我需要考虑哪些变化才能获得以下标记?
+----+----+----+----+
| 1 | 2 | 3 | 4 |
+----+----+----+----+
| 5 | 6 | 7 | 8 |
+----+----+----+----+
| 9 | 10 | 11 | 12 |
+----+----+----+----+
| 13 | 14 | 15 | 16 |
+----+----+----+----+
您可以执行以下操作:
a = torch.ones(4,4)
indices = (a == torch.max(a).item()).nonzero()
这样做的目的是返回具有最大值的2D坐标的[16,2]
大小的张量,即[0,0], [0,1], .., [3,3]
。torch.max
部分应该很容易理解,nonzero()
考虑(a == torch.max(a).item())
给出的布尔张量,取False
为0,并返回非零索引。希望这能有所帮助!
如果你想要2d
形状的索引@ccl已经给了你答案,但对于1d
索引,你可以首先使用torch.flatten
张量使x
为1d,然后使用torch.nonzero
获得索引,最后转换为相同的形状。
x = torch.ones(4,4) * 5
(x.flatten() == x.flatten().max()).nonzero().reshape(x.shape) + 1
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16]])