Pytorch - 高效元素乘法



>我有一个 [100x3] 的 3D 点张量

我有一个权重为 [100x1] 的向量,需要将其元素乘以 X,Y,Z 坐标。

目前,我

正在创建一个新的向量 W,在我进行元素明智乘法之前,我将重复的 [100x3] 元素堆叠成 [100x3] 张量。

我需要这样做很多次,这太慢了,而且占用大量内存。有没有更好的方法?

PyTorch 中的标准乘法 ( * ( 已经是元素化的。此外,它还广播。所以

import torch
xyz = torch.randn(100, 3)
w = torch.randn(100, 1)
multiplied = xyz * w

只会做这个伎俩。

最新更新