我试图从Pytorch中表示图像的张量中提取亮度,因此我需要将大小为3的矢量(对于三个RGB值权重(与表示图像的3xNxN张量逐元素相乘,以便最终获得NxN矩阵,其中张量的三个通道已与矢量中给定的权重相加。
我想Pytorch操作可以帮助我在没有循环的情况下做到这一点,但我还没有找到它们。
您必须重塑3
维RGB矢量,以便像这样广播到3xNxN
:
rgb = rgb.reshape(-1, 1, 1)
因此它将具有(3, 1, 1)
形状
现在您可以将其与原始图像相乘,并沿第一个维度求和:
result = torch.sum(rgb * image, dim=0)