PyTorch自定义转发函数不适用于DataParallel



编辑:我尝试过PyTorch 1.6.0和1.7.1,两者都给了我相同的错误。

我有一个模型,允许用户在不同的架构a和B之间轻松切换。两种架构的正向功能也不同,所以我有以下模型类:

p.S.我只是用一个非常简单的例子来演示我的问题,实际的模型要复杂得多

class Net(nn.Module):
def __init__(self, condition):
super().__init__()
self.linear = nn.Linear(10, 1)

if condition == 'A':
self.forward = self.forward_A
elif condition == 'B':
self.linear2 = nn.Linear(10, 1)
self.forward = self.forward_B

def forward_A(self, x):
return self.linear(x)

def forward_B(self, x1, x2):
return self.linear(x1) + self.linear2(x2)

它在单个GPU的情况下运行良好。然而,在多GPU的情况下,它给我带来了一个错误。

device= 'cuda:0'
x = torch.randn(8,10).to(device)
model = Net('B')
model = model.to(device)
model = nn.DataParallel(model)
model(x, x)

RuntimeError:预期所有张量都在同一设备上,但找到至少有两个设备,cuda:0和cuda:1!(检查参数时方法wrapper_admm(中的参数mat1

如何使此模型类与nn.DataParallel一起工作?

您正在强制输入x和模型在'cuda:0'设备上,但在多个GPU上工作时,不应指定任何特定设备
尝试:

x = torch.randn(8,10)  
model = Net('B')
model =  nn.DataParallel(model, device-ids=[0, 1]).cuda()  # assuming 2 GPUs
pred = model(x, x)

如果有两个包装器,每个包装器都用自己的前向函数调用这个模型,那么这个问题就会消失。

你还需要使用nn。DataParallel而不是nn。单元

相关内容

最新更新