我有一个nn。forward
函数有两个输入的模块。在函数内部,我将其中一个输入x1
乘以一组可训练参数,然后将它们与另一个输入x2
连接起来。
class ConcatMe(nn.Module):
def __init__(self, pad_len, emb_size):
super(ConcatMe, self).__init__()
self.W = nn.Parameter(torch.randn(pad_len, emb_size).to(DEVICE), requires_grad=True)
self.emb_size = emb_size
def forward(self, x1: Tensor, x2: Tensor):
cat = self.W * torch.reshape(x2, (1, -1, 1))
return torch.cat((x1, cat), dim=-1)
根据我的理解,一个人应该能够在PyTorch的nn中编写操作。像我们处理批大小为1的输入一样的模块。出于某种原因,情况并非如此。我得到一个错误,表明PyTorch仍在考虑batch_size。
x1 = torch.randn(100,2,512)
x2 = torch.randint(10, (2,1))
concat = ConcatMe(100, 512)
concat(x1, x2)
-----------------------------------------------------------------------------------
File "/home/my/file/path.py, line 0, in forward
cat = self.W * torch.reshape(x2, (1, -1, 1))
RuntimeError: The size of tensor a (100) must match the size of tensor b (2) at non-singleton dimension 1
我做了一个for循环来修补这个问题,如下所示:
class ConcatMe(nn.Module):
def __init__(self, pad_len, emb_size):
super(ConcatMe, self).__init__()
self.W = nn.Parameter(torch.randn(pad_len, emb_size).to(DEVICE), requires_grad=True)
self.emb_size = emb_size
def forward(self, x1: Tensor, x2: Tensor):
batch_size = x2.shape[0]
cat = torch.ones(x1.shape).to(DEVICE)
for i in range(batch_size):
cat[:, i, :] = self.W * x2[i]
return torch.cat((x1, cat), dim=-1)
但我觉得有一个更优雅的解决方案。这和我在n。module中创建参数有关系吗?如果是这样,我能实现什么不需要for循环的解决方案呢?
根据我的理解,应该能够在PyTorch的
nn.Module
s中编写操作,就像我们对批量大小为1的输入一样。
我不确定你从哪里得到这个假设,它肯定是不是正确的-相反:你总是需要以一种方式来写它们,它们可以处理任意批处理维度的一般情况。
从你的第二个实现来看,你似乎在尝试将两个维度不兼容的张量相乘。为了解决这个问题你需要定义
self.W = torch.nn.Parameter(torch.randn(pad_len, 1, emb_size), requires_grad=True)
为了更好地理解这类事情,学习广播会有所帮助。