如何使用PyTorch从3D张量中删除元素



我有一个形状为torch.Size([4, 161, 325])的张量。如何移除dim=2上的第一个元素,使得到的张量具有torch.Size([4, 161, 324])的形状?

您可以使用简单的切片,

>>>a = torch.randn(4, 161, 325)
>>>b = a[:, :, 1:]
>>>b.shape
torch.Size([4, 161, 324])

进行切片

t = torch.rand(4,161,325)
t = t[..., 1:]            # or t = t[Ellipsis, 1:] Here, Ellipsis indicate rest of dims
t.shape
torch.Size([4, 161, 324])

最新更新