我有一个大小为 x={4,2,C,H,W} 的张量。我需要将其重塑为 y={8,C,H,W},但我想确保图像以正确的顺序存储,所以说 x[1,0,:,:,:] 的图像必须等于 y[2,C,H,W]。我知道我可以为此使用视图功能,但我不确定如何正确使用它。
目前我正在这样做
feat_imgs_all = feat_imgs_all.view(
rgb.shape[0], rgb.shape[1], feat_imgs_all.shape[1],
feat_imgs_all.shape[2], feat_imgs_all.shape[3])
这看起来真的很笨拙,有没有办法我可以只喂前两个形状,然后 pytorch 找出其余的?
您可以使用flatten
和end_dim
参数轻松完成此操作,请参阅documentation
:
import torch
a = torch.randn(4, 2, 32, 64, 64)
flattened = a.flatten(end_dim=1)
torch.all(flattened[2, ...] == a[1, 0, ...]) # True
view
也可以使用,如下所示,尽管它不太易读也不太令人愉快:
import torch
a = torch.randn(4, 2, 32, 64, 64)
flattened = a.view(-1, *a.shape[2:])
torch.all(flattened[2, ...] == a[1, 0, ...]) # True as well