将每个张量乘以另一个张量的值



我有两个张量:

import torch
a = torch.randn((2,3,5))
b = torch.tensor([[2.0, 1.0, 2.0],[0.5, 1.0, 1.0]])

我想把a中最后一个维度的每个元素与b中相应的元素相乘。这意味着当a是:

tensor([[[ 1, 2, 3, 4,  5],
[1, 2,  3,  4,  5],
[1, 2,  3,  4,  5]],
[[1, 2,  3,  4,  5],
[1, 2,  3,  4,  5],
[1, 2,  3,  4,  5]]])

结果应该是:

tensor([[[ 2, 4, 6, 8,  10],
[1, 2,  3,  4,  5],
[ 2, 4, 6, 8,  10]],
[[0.5, 1.0,  1.5,  2.0,  2.5],
[1, 2,  3,  4,  5],
[1, 2,  3,  4,  5]]])

我该怎么做?

您只需要使两个张量可广播,这是基于NumPy中广播的概念。

粗略地说,当维度不匹配时,您希望在张量的shape中有一个1

有几种方法:

  • 整形,例如a * b.reshape(b.shape + (1,))
  • 使用None轴进行切片,例如a * b[..., None]
  • 未挤压,例如a * b.unsqueeze(-1)

虽然最灵活的是整形,但切片通常是最方便但相当明确的。

我所需要做的就是添加一个维度:

a * b.unsqueeze(-1)

最新更新