如何在pytorch中共享两个模型的共同部分



我在用pytorch实现模型时遇到了问题。

我想建立两个模型,其中一些是共享的,并像这个一样共享编码器部分

Model1: input_1 -> encoder -> decoder_1 -> ouput_1
Model2: input_2 -> encoder -> decoder_2 -> ouput_2  

我想做的是让这两个模型一起使用编码器部分,但解码器部分不一样。我查阅了有关参数共享的信息,但它似乎与这里的要求有所不同。

我自己的想法是建立一个包括encode、decoder_1和decoder_2的模型,然后根据输入选择使用哪个解码器。

我不确定这种方法,如果可能的话,你能举几个简单的例子来使用两个模型的公共部分吗?

您可以执行以下操作:

import torch.nn as nn
class SharedModel(nn.Module):
def __init__(self, mode):
super(SharedModel, self).__init__()
self.mode = mode # use 1 or 2
self.encoder = ...
self.decoder_1 = ...
self.decoder_2 = ...

def forward(self, x):
x = self.encoder(x)
if self.mode == 1:
x = self.decoder_1(x)
elif self.mode == 2:
x = self.decoder_2(x)
else:
raise ValueError("Unkown mode.")
return x

我不会制作一个SharedModel,宁愿选择两个,但共享encoder部分。

import torch
class Model(torch.nn.Module):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder: torch.nn.Module = encoder
self.decoder: torch.nn.Module

def forward(self, x):
return self.decoder(self.encoder)
encoder = ...
decoder1 = ...
decoder2 = ...
first = Model(encoder, decoder1)
second = Model(encoder, decoder2)

您也可以拆分Model类,但本质上,将编码器作为参数传递给构造函数,它将在任意多个模型之间共享。

不需要if、自定义mode和其他变通方法。此外,它是独立于输入的,没有办法意外地传递错误的输入并获得结果。

最新更新