我有两个张量:
import torch
a = torch.randn((2,3,5))
b = torch.tensor([[2.0, 1.0, 2.0],[0.5, 1.0, 1.0]])
我想把a中最后一个维度的每个元素与b中相应的元素相乘。这意味着当a是:
tensor([[[ 1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]],
[[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]]])
结果应该是:
tensor([[[ 2, 4, 6, 8, 10],
[1, 2, 3, 4, 5],
[ 2, 4, 6, 8, 10]],
[[0.5, 1.0, 1.5, 2.0, 2.5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]]])
我该怎么做?
您只需要使两个张量可广播,这是基于NumPy中广播的概念。
粗略地说,当维度不匹配时,您希望在张量的shape
中有一个1
。
有几种方法:
- 整形,例如
a * b.reshape(b.shape + (1,))
- 使用
None
轴进行切片,例如a * b[..., None]
- 未挤压,例如
a * b.unsqueeze(-1)
虽然最灵活的是整形,但切片通常是最方便但相当明确的。
我所需要做的就是添加一个维度:
a * b.unsqueeze(-1)