我有一个有3个通道的torch张量,我希望它是1个通道(所有其他维度都应该保持不变(。所以如果我当前的维度是torch.Size([6, 3, 512, 512])
,我希望它是torch.Size([6, 1, 512, 512])
我该怎么做?
这能解决您的问题吗?
a = torch.ones(6, 3, 512, 512)
b = a[:, 0:1, :, :]
print(b.size()) # torch.Size([6, 1, 512, 512])