pytorch视图张量和降维



所以我有一个形状为[4,1,128,678]的4d张量,我想将其查看/重塑为[4,678,128]

我必须对多个张量这样做,其中最后的形状值678并不总是已知的,并且可能不同,所以[4,1,128,575]也应该转到[4,575,128]

你知道变换张量的最佳运算是什么吗?查看/重塑?如何?

感谢

您也可以使用(较少写入和IMO更清洁(:

# x.shape == (4, 1, 128, 678)
x.squeeze().permute(0, 2, 1)

如果你使用view,你会丢失维度信息(但也许这就是你想要的(,在这种情况下,它会是:

x.squeeze().view(4, -1, 128)

permute对张量进行重新排序,而shape只提供不同的视图,而没有重构底层内存。您可以在这个StackOverflow答案中看到这两个操作之间的区别。

使用einops,它可以一次完成所有操作,并验证已知尺寸:

from einops import reshape
y = rearrange(x, 'x 1 y z -> x z y', x=4, y=128)

最新更新