我有一个形状为(196,256)的PyTorch张量。这对应于由196个补丁组成的224x224图像,其中左上角的补丁是由原始张量的第一行组成的——一个16x16的补丁。紧挨着它右边的补丁是由原始张量的第二行构成的…等等......我如何将这个张量重新排列成224x224的原始图像?
我尝试了patches.view(224,224),但它混合了行,因此不对应于原始图像。
尝试使用Tensor.as_strided
。例如,如果我们有一个6x6张量img
它原来看起来像
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]])
,但表示为9个形状为2x2的斑块,排列在一个张量patches
中,如
tensor([[ 0, 1, 6, 7],
[ 2, 3, 8, 9],
[ 4, 5, 10, 11],
[12, 13, 18, 19],
[14, 15, 20, 21],
[16, 17, 22, 23],
[24, 25, 30, 31],
[26, 27, 32, 33],
[28, 29, 34, 35]])
我们可以用
patches.as_strided((3, 2, 3, 2), (3 * 2 * 2, 2, 2 * 2, 1)).reshape_as(img)
得到img
。在您的例子中,正确的代码应该是
patches.as_strided((14, 16, 14, 16), (14 * 16 * 16, 16, 16 * 16, 1)).reshape_as(img)