pytorch重复三维



我在文档上使用此示例

In [42]: x = torch.tensor([1,2,3])

In [45]: x.repeat(4,2)
Out[45]: tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])
In [46]: x.repeat(4,2).shape 
Out[46]: torch.Size([4, 6])

到目前为止,一切都很好。

但是为什么在第三维度上只重复1次会将第三维度扩展为3(而不是1(?

[On the doc]
>>> x.repeat(4, 2, 1).size()
torch.Size([4, 2, 3])

仔细检查。

In [43]: x.repeat(4,2,1)
Out[43]:
tensor([[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]]])

为什么它会这样?

它将大小([3](张量沿着第一个dim只扩展一次。(4,2,1(是要重复(3,(张量的次数。最后的张量是(4,2,3(,因为你在最后一个轴上重复(3,(一次,在倒数第二个轴上两次,在第一个轴上4次。

x = torch.tensor([1, 2, 3])
x.shape
torch.Size([3])

然后,

xx = x.repeat(4,2,1)
xx.shape
torch.Size([4, 2, 3])

最新更新