从统一数据上的最大池中提取索引

  • 本文关键字:提取 索引 数据 pytorch
  • 更新时间 :
  • 英文 :


我试图在给定内核大小的2D张量中找到最大点,但在所有值都是一致的特殊情况下,我遇到了问题。例如,给定以下示例,我想将每个点标记为最大点:

+---+---+---+---+
| 5 | 5 | 5 | 5 |
+---+---+---+---+
| 5 | 5 | 5 | 5 |
+---+---+---+---+
| 5 | 5 | 5 | 5 |
+---+---+---+---+
| 5 | 5 | 5 | 5 |
+---+---+---+---+

如果我运行torch.nn.functional.max_pool2d,内核大小为3,步长为1,填充为1,我会得到以下标记:

+---+---+---+----+
| 0 | 0 | 1 |  2 |
+---+---+---+----+
| 0 | 0 | 1 |  2 |
+---+---+---+----+
| 4 | 4 | 5 |  6 |
+---+---+---+----+
| 8 | 8 | 9 | 10 |
+---+---+---+----+

我需要考虑哪些变化才能获得以下标记?

+----+----+----+----+
| 1  | 2  | 3  |  4 |
+----+----+----+----+
|  5 |  6 |  7 |  8 |
+----+----+----+----+
|  9 | 10 | 11 | 12 |
+----+----+----+----+
| 13 | 14 | 15 | 16 |
+----+----+----+----+

您可以执行以下操作:

a = torch.ones(4,4)
indices = (a == torch.max(a).item()).nonzero()

这样做的目的是返回具有最大值的2D坐标的[16,2]大小的张量,即[0,0], [0,1], .., [3,3]torch.max部分应该很容易理解,nonzero()考虑(a == torch.max(a).item())给出的布尔张量,取False为0,并返回非零索引。希望这能有所帮助!

如果你想要2d形状的索引@ccl已经给了你答案,但对于1d索引,你可以首先使用torch.flatten张量使x为1d,然后使用torch.nonzero获得索引,最后转换为相同的形状。

x = torch.ones(4,4) * 5
(x.flatten() == x.flatten().max()).nonzero().reshape(x.shape) + 1
tensor([[ 1,  2,  3,  4],
[ 5,  6,  7,  8],
[ 9, 10, 11, 12],
[13, 14, 15, 16]])

最新更新