我有一个a= torch.randn(28, 28, 8)
,我想交换张量的维度,并将第三个维度移动到第一个位置,第一个移动到第二个位置,第二个移到第三个位置。我使用了b = a.transpose(2, 0, 1)
,但收到了以下错误:
TypeError: transpose() received an invalid combination of arguments - got (int, int, int), but expected one of:
* (name dim0, name dim1)
* (int dim0, int dim1)
我是否应该多次使用转置,每次只交换两个维度?有什么方法可以让我一次交换所有东西吗?
谢谢。
您可以使用Pytorch的permute()
函数一次交换所有
>>>a = torch.randn(28, 28, 8)
>>>b = a.permute(2, 0, 1)
>>>b.shape
torch.Size([8, 28, 28])
使用permute
:
b = a.permute(2, 0, 1)