如何对不同维度的张量进行元素乘

  • 本文关键字:张量进 元素 python pytorch
  • 更新时间 :
  • 英文 :


我有一个张量expanded_mask,它的大小为torch.Size([1, 208]),另一个张量为inputs,它的尺寸为torch.Size([1, 208, 161])

我想将expanded_maskinput元素相乘,使得第三维度的所有161个元素都与expanded_mask的208个元素相乘。

按照乔达的回答,我试过了:

masked_inputs = expanded_mask.unsqueeze(2) * inputs

inputs为:

tensor([1.8851e-02, 4.4921e-02, 7.5260e-02, 3.8994e-02, 3.5651e-02, 3.0457e-02,
1.2933e-02, 2.5496e-02, 2.3260e-04, 2.4903e-03, 6.5678e-03, 1.0501e-02,
1.2387e-02, 1.9434e-03, 1.0831e-03, 6.5691e-03, 5.3792e-03, 9.1925e-03,
1.8146e-03, 4.9215e-03, 1.4623e-03, 9.4454e-03, 1.0504e-03, 3.3749e-03,
2.1361e-03, 8.0782e-03, 1.7916e-03, 1.1577e-03, 1.1246e-04, 2.2520e-03,
2.2255e-03, 2.1072e-03, 9.8782e-03, 2.2909e-03, 2.9957e-03, 5.8540e-03,
1.1067e-02, 9.0582e-03, 5.6360e-03, 6.3841e-03, 5.9298e-03, 1.9501e-04,
2.7967e-03, 3.5786e-03, 9.2363e-03, 8.3934e-03, 8.8185e-04, 5.4591e-03,
2.2451e-04, 2.2307e-03, 2.4871e-03, 3.6736e-03, 1.3842e-04, 2.7455e-03,
6.2199e-03, 1.1924e-02, 9.5953e-03, 1.6939e-03, 4.1919e-04, 9.3509e-05,
1.8351e-03, 6.3350e-04, 1.1076e-03, 1.5472e-03, 1.2104e-03, 3.1803e-04,
8.6507e-04, 3.0083e-03, 2.8435e-03, 1.6740e-03, 8.1023e-05, 7.5767e-04,
9.1442e-04, 2.0204e-03, 1.3987e-03, 3.7729e-03, 5.2012e-04, 2.0367e-03,
1.5177e-03, 1.6948e-03, 9.5833e-04, 1.2050e-03, 1.8356e-03, 9.4503e-04,
4.8612e-04, 1.6844e-04, 1.2222e-04, 1.7526e-03, 2.6397e-04, 1.3026e-03,
1.0704e-03, 3.6407e-04, 1.3135e-03, 2.6665e-03, 1.8639e-03, 3.0385e-05,
1.0212e-03, 7.6236e-04, 1.7878e-03, 2.4298e-03, 7.2158e-05, 1.2488e-03,
2.1347e-03, 3.9256e-03, 3.1436e-03, 3.1648e-03, 3.4657e-03, 1.3746e-03,
1.6927e-03, 1.0794e-03, 8.8152e-04, 1.1757e-04, 3.2254e-04, 4.1866e-04,
9.2787e-04, 2.0020e-03, 1.4813e-03, 1.1912e-03, 2.4577e-03, 2.2247e-03,
1.7862e-03, 1.7460e-03, 1.4388e-03, 4.3175e-04, 6.7808e-04, 2.6875e-04,
3.6475e-04, 8.7643e-04, 3.6790e-04, 2.1274e-04, 6.3725e-04, 2.0949e-03,
2.4069e-03, 1.7348e-03, 1.0026e-03, 1.2451e-03, 4.7888e-04, 5.9790e-04,
1.4343e-03, 4.0900e-03, 1.0176e-03, 5.5178e-04, 2.0624e-03, 1.2878e-03,
6.9607e-04, 4.3259e-04, 1.8573e-03, 7.5521e-04, 5.2949e-04, 3.4758e-04,
4.7898e-04, 7.5599e-04, 6.0631e-04, 1.7585e-03, 1.8156e-03, 3.2421e-04,
8.9446e-04, 7.2131e-04, 6.2817e-04, 1.0827e-03, 2.0211e-03],
device='cuda:0')

expanded_mask是:

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0',
grad_fn=<AsStridedBackward>) 

则CCD_ 10为:

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
device='cuda:0', grad_fn=<SelectBackward>)

看起来1的没有被相乘。

使用广播的另一种方式:

import torch
mask = torch.tensor([[1, 0, 1]])
inputs = torch.randn(1, 3, 2)
masked = inputs * mask[..., None]
print(mask)
print(inputs)
print(masked)

结果:

tensor([[1, 0, 1]])
tensor([[[ 2.2820,  2.7476],
[-0.1738, -0.5703],
[ 0.7077, -0.6384]]])
tensor([[[ 2.2820,  2.7476],
[-0.0000, -0.0000],
[ 0.7077, -0.6384]]])

省略号运算符表示所有维度,然后None在末尾添加一个维度。

您可以在此处依赖广播语义。我们首先在expanded_mask上使用Tensor.unsqueeze(2),在末端添加一个酉维数,使其成为大小为[1, 154, 1]的张量。然后,乘法运算将隐式地使用类似numpy的广播语义来将inputs的161个信道中的每一个与expanded_mask相乘。

所以最后的结果是

expanded_mask.unsqueeze(2) * inputs

最新更新