根据给定的维度从PyTorch张量中分割和提取值



我有一个大小为torch.Size([32, 32, 3, 3])的张量A,我想将其拆分并从中提取大小为torch.Size([16, 16, 3, 3])的张量B。张量可以是1d或4d,并且必须根据给定的新张量维度进行拆分。我已经能够生成目标维度,但无法从源张量中分割和提取值。我尝试过torch.narrow,但它只需要3个参数,在很多情况下我需要4个。torch.split将dim作为int,因此张量仅沿一维分裂。但我想把它分成多个维度。

您有多个选项:

  • 多次使用.split
  • 多次使用.narrow
  • 使用切片

例如:

t = torch.rand(32, 32, 3, 3)
t0, t1 = t.split((16, 16), 0)
print(t0.shape, t1.shape)
>>> torch.Size([16, 32, 3, 3]) torch.Size([16, 32, 3, 3])
t00, t01 = t0.split((16, 16), 1)
print(t00.shape, t01.shape)
>>> torch.Size([16, 16, 3, 3]) torch.Size([16, 16, 3, 3])
t00_alt, t01_alt = t[:16, :16, :, :], t[16:, 16:, :, :]
print(t00_alt.shape, t01_alt.shape)
>>> torch.Size([16, 16, 3, 3]) torch.Size([16, 16, 3, 3])

最新更新