简易变压器的负载池层



我有一个经过微调的简单转换器表示模型。现在我只想以pickle格式保存池层的权重,并将其放在我正在设计的另一个自定义自动编码器的池层中。如何使用pytorch和python来做到这一点?

每个PyTorch模块旁边都有一个名为state_dict的对象,它允许将任何参数映射到其对应的张量变量(此处详细介绍(。使用此实用程序,您可以轻松地保存和加载参数,但请记住,您必须事先确定要在语义(从机器学习的角度(和语法(形状兼容性和…(上做什么!下面的实现将用名称中的单词pooling替换任何参数,并使用我们之前保存的模型中的相应变量。

finetuned_model = BertLMHeadModel.from_pretrained('bert-base-cased')
torch.save(finetuned_model.state_dict(), "finetuned_model.pth")
finetuned_model_state_dict = torch.load("finetuned_model.pth")
new_model = BertLMHeadModel.from_pretrained('bert-base-cased')
new_model_state_dict = new_model.state_dict()
for key, value in new_model_state_dict.items():
if key.find('pooling')!=-1:
new_model_state_dict.update({key: value})

最新更新