如何在训练中组合两个不同形状的pytorch张量



目前,我的模型给出了3个输出张量。我希望他们中的两个更加合作。我想使用self.dropout1(hs(和self.dropout 2(cls_hs(的组合来通过self.entity_out线性层。问题是提到两个张量的形状不同。

当前代码

class NLUModel(nn.Module):
def __init__(self, num_entity, num_intent, num_scenarios):
super(NLUModel, self).__init__()
self.num_entity = num_entity
self.num_intent = num_intent
self.num_scenario = num_scenarios
self.bert = transformers.BertModel.from_pretrained(config.BASE_MODEL)
self.dropout1 = nn.Dropout(0.3)
self.dropout2 = nn.Dropout(0.3)
self.dropout3 = nn.Dropout(0.3)
self.entity_out = nn.Linear(768, self.num_entity)
self.intent_out = nn.Linear(768, self.num_intent)
self.scenario_out = nn.Linear(768, self.num_scenario)
def forward(self, ids, mask, token_type_ids):
out = self.bert(input_ids=ids, attention_mask=mask,
token_type_ids=token_type_ids)
hs, cls_hs = out['last_hidden_state'], out['pooler_output']
entity_hs = self.dropout1(hs)
intent_hs = self.dropout2(cls_hs)
scenario_hs = self.dropout3(cls_hs)
entity_hs = self.entity_out(entity_hs)
intent_hs = self.intent_out(intent_hs)
scenario_hs = self.scenario_out(scenario_hs)
return entity_hs, intent_hs, scenario_hs

所需

def forward(self, ids, mask, token_type_ids):
out = self.bert(input_ids=ids, attention_mask=mask,
token_type_ids=token_type_ids)
hs, cls_hs = out['last_hidden_state'], out['pooler_output']
entity_hs = self.dropout1(hs)
intent_hs = self.dropout2(cls_hs)
scenario_hs = self.dropout3(cls_hs)
entity_hs = self.entity_out(concat(entity_hs, intent_hs)) # Concatination
intent_hs = self.intent_out(intent_hs)
scenario_hs = self.scenario_out(scenario_hs)
return entity_hs, intent_hs, scenario_hs

假设我成功地连接了。。。反向传播会起作用吗?

entity_hs(last_hidden_state(的形状是[batch_size,sequence_length,hidden_size],intent_hs(pool_output(的形状只是[batch_size,hidden_size],将它们放在一起可能没有意义。这取决于你想做什么。

如果出于某种原因,您想要获得输出[batch_size,sequence_length,channels],您可以平铺intent_hs张量:

intent_hs = torch.tile(intent_hs[:, None, :], (1, sequence_lenght, 1))
... = torch.cat([entity_hs, intent_hs], dim=2) 

如果你想得到[batch_size,channels],你可以减少entity_hs张量,例如通过平均:

entity_hs = torch.mean(entity_hs, dim=1) 
... = torch.cat([entity_hs, intent_hs], dim=1) 

是的,反向传递将通过串联(以及其他部分(传播梯度。

最新更新