如何在Pytorch中交换三维



我有一个a= torch.randn(28, 28, 8),我想交换张量的维度,并将第三个维度移动到第一个位置,第一个移动到第二个位置,第二个移到第三个位置。我使用了b = a.transpose(2, 0, 1),但收到了以下错误:

TypeError: transpose() received an invalid combination of arguments - got (int, int, int), but expected one of:
* (name dim0, name dim1)
* (int dim0, int dim1)

我是否应该多次使用转置,每次只交换两个维度?有什么方法可以让我一次交换所有东西吗?

谢谢。

您可以使用Pytorch的permute()函数一次交换所有

>>>a = torch.randn(28, 28, 8)
>>>b = a.permute(2, 0, 1)
>>>b.shape
torch.Size([8, 28, 28])

使用permute:

b = a.permute(2, 0, 1)