如何从PyTorch中的平坦斑块张量重建图像?

  • 本文关键字:张量 重建 图像 PyTorch pytorch
  • 更新时间 :
  • 英文 :


我有一个形状为(196,256)的PyTorch张量。这对应于由196个补丁组成的224x224图像,其中左上角的补丁是由原始张量的第一行组成的——一个16x16的补丁。紧挨着它右边的补丁是由原始张量的第二行构成的…等等......我如何将这个张量重新排列成224x224的原始图像?

我尝试了patches.view(224,224),但它混合了行,因此不对应于原始图像。

尝试使用Tensor.as_strided。例如,如果我们有一个6x6张量img它原来看起来像

tensor([[ 0,  1,  2,  3,  4,  5],
[ 6,  7,  8,  9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]])

,但表示为9个形状为2x2的斑块,排列在一个张量patches中,如

tensor([[ 0,  1,  6,  7],
[ 2,  3,  8,  9],
[ 4,  5, 10, 11],
[12, 13, 18, 19],
[14, 15, 20, 21],
[16, 17, 22, 23],
[24, 25, 30, 31],
[26, 27, 32, 33],
[28, 29, 34, 35]])

我们可以用

patches.as_strided((3, 2, 3, 2), (3 * 2 * 2, 2, 2 * 2, 1)).reshape_as(img)

得到img。在您的例子中,正确的代码应该是

patches.as_strided((14, 16, 14, 16), (14 * 16 * 16, 16, 16 * 16, 1)).reshape_as(img)

最新更新