如何正确使用 Pytorch 的视图功能?



我有一个大小为 x={4,2,C,H,W} 的张量。我需要将其重塑为 y={8,C,H,W},但我想确保图像以正确的顺序存储,所以说 x[1,0,:,:,:] 的图像必须等于 y[2,C,H,W]。我知道我可以为此使用视图功能,但我不确定如何正确使用它。

目前我正在这样做

feat_imgs_all = feat_imgs_all.view(
rgb.shape[0], rgb.shape[1], feat_imgs_all.shape[1], 
feat_imgs_all.shape[2], feat_imgs_all.shape[3])

这看起来真的很笨拙,有没有办法我可以只喂前两个形状,然后 pytorch 找出其余的?

您可以使用flattenend_dim参数轻松完成此操作,请参阅documentation

import torch
a = torch.randn(4, 2, 32, 64, 64)
flattened = a.flatten(end_dim=1)
torch.all(flattened[2, ...] == a[1, 0, ...]) # True

view也可以使用,如下所示,尽管它不太易读也不太令人愉快:

import torch
a = torch.randn(4, 2, 32, 64, 64)
flattened = a.view(-1, *a.shape[2:])
torch.all(flattened[2, ...] == a[1, 0, ...]) # True as well

最新更新