如何将4维PyTorch张量乘以1维张量



我正在尝试编写用于混合训练的函数。在这个网站上,我发现了一些代码,并适应了我以前的代码。但在原始代码中,只为批处理(64(生成一个随机变量。但我想要批量中每张图片的随机值。带有一个批次变量的代码:

def mixup_data(x, y, alpha=1.0):
lam = np.random.beta(alpha, alpha)
batch_size = x.size()[0]
index = torch.randperm(batch_size)
mixed_x = lam * x + (1 - lam) * x[index,:]
mixed_y = lam * y + (1 - lam) * y[index,:]
return mixed_x, mixed_y

输入的x和y来自pytorch DataLoader。x输入大小:torch.Size([64, 3, 256, 256])y输入大小:torch.Size([64, 3474])

此代码运行良好。然后我把它改成这个:

def mixup_data(x, y):
batch_size = x.size()[0]
lam = torch.rand(batch_size)
index = torch.randperm(batch_size)
mixed_x = lam[index] * x + (1 - lam[index]) * x[index,:]
mixed_y = lam[index] * y + (1 - lam[index]) * y[index,:]
return mixed_x, mixed_y

但它给出了一个错误:RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 3

我如何理解代码的工作原理是,它批量获取第一个图像,并乘以lam张量中的第一个值(64个值长(。我该怎么做?

您需要替换以下行:

lam = torch.rand(batch_size)

通过

lam = torch.rand(batch_size, 1, 1, 1)

使用当前代码,lam[index] * x乘法是不可能的,因为lam[index]的大小为torch.Size([64]),而x的大小为torch.Size([64, 3, 256, 256])。因此,您需要将lam[index]的大小设置为torch.Size([64, 1, 1, 1]),以便它可以广播。

为了应对以下声明:

mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]

我们可以在语句之前对lam张量进行整形。

lam = lam.reshape(batch_size, 1)
mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]

问题是相乘在一起的两个张量的大小不匹配。让我们以lam[index] * x为例。尺寸如下:

  • x:torch.Size([64, 3, 256, 256])
  • lam[index]:torch.Size([64])

为了将它们相乘在一起,它们应该具有相同的大小,其中lam[index]对每个批次的[3, 256, 256]使用相同的值,因为您希望将该批次中的每个元素乘以相同值,但每个批次的值不同。

lam[index].view(batch_size, 1, 1, 1).expand_as(x)
# => Size: torch.Size([64, 3, 256, 256])

.expand_as(x)重复奇异维度,使其具有与x相同的大小,有关详细信息,请参阅.expand()文档。

您不需要扩展张量,因为如果存在奇异维度,PyTorch会自动为您扩展张量。这就是所谓的广播:PyTorch-广播语义。因此,大小为torch.Size([64, 1, 1, 1])就足以将其与x相乘。

lam[index].view(batch_size, 1, 1, 1) * x

这同样适用于y,但具有torch.Size([64, 1])的大小,因为y具有torch.Size([64, 3474])的大小。

mixed_x = lam[index].view(batch_size, 1, 1, 1) * x + (1 - lam[index]).view(batch_size, 1, 1, 1) * x[index, :]
mixed_y = lam[index].view(batch_size, 1) * y + (1 - lam[index]).view(batch_size, 1) * y[index, :]

只是一个小小的补充,lam[index]只重新排列了lam的元素,但由于你是随机创建的,所以无论你是否重新排列它都没有任何区别。唯一重要的是xy被重新排列,就像在原始代码中一样。

最新更新