Trying to use DistributedDataParallel on a GAN, but getting a runtime error about an in-place operation



I am trying to train a GAN on a machine with 3 GPUs using DistributedDataParallel (DDP). Everything works fine before I wrap my models in DDP, but once I wrap them I get the following runtime error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [128]] is at version 5;

I tried cloning every tensor involved in the gradient computation to get rid of the in-place operation (if there is one), but I cannot find it.
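
(Aside: the standard way to hunt down a hidden in-place operation is autograd's anomaly detection, which makes the backward error include a second traceback pointing at the forward call that produced the offending tensor. A debugging sketch; it slows training down considerably, so only enable it while looking for the bug:)

import torch

# Enable before the training loop: backward errors will then report
# which forward operation created the tensor that was later modified in place.
torch.autograd.set_detect_anomaly(True)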

The part of the code where the problem occurs is below:

import numpy as np
import torch
import torch.distributed as dist
from torch.autograd import Variable
from torch.nn.parallel import DistributedDataParallel as DDP

Tensor = torch.cuda.FloatTensor

# ----------
#  Training
# ----------
def train_gan(rank, world_size, opt):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    # Download the data on rank 0 only, then let every rank build its loader
    if rank == 0:
        get_dataloader(rank, opt)
    dist.barrier()
    print(f"Rank {rank}/{world_size} training process passed data download barrier.\n")
    dataloader = get_dataloader(rank, opt)
    # Loss function
    adversarial_loss = torch.nn.BCELoss()
    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)
    generator.to(rank)
    discriminator.to(rank)
    generator_d = DDP(generator, device_ids=[rank])
    discriminator_d = DDP(discriminator, device_ids=[rank])

    # Optimizers
    # Since we are computing the average of several batches at once (an effective batch size of
    # world_size * batch_size) we scale the learning rate to match.
    optimizer_G = torch.optim.Adam(generator_d.parameters(), lr=opt.lr * opt.world_size, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator_d.parameters(), lr=opt.lr * opt.world_size, betas=(opt.b1, opt.b2))
    losses = []
    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            # Adversarial ground truths
            valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False).to(rank)
            fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False).to(rank)
            # Configure input
            real_imgs = Variable(imgs.type(Tensor)).to(rank)
            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()
            # Sample noise as generator input
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))).to(rank)
            # Generate a batch of images
            gen_imgs = generator_d(z)
            # Loss measures generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator_d(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()
            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator_d(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator_d(gen_imgs.detach()), fake)
            d_loss = ((real_loss + fake_loss) / 2).to(rank)

            d_loss.backward()
            optimizer_D.step()
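
(setup() and get_dataloader() are helpers that are not shown; setup() follows the standard PyTorch DDP recipe and presumably looks something like this sketch:)

import os
import torch.distributed as dist

def setup(rank, world_size):
    # One process per GPU; all processes must agree on the rendezvous address
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    # NCCL is the usual backend for multi-GPU training
    dist.init_process_group("nccl", rank=rank, world_size=world_size)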

I ran into a similar error when trying to train a GAN with DistributedDataParallel, and I noticed the problem was coming from the BatchNorm layers in my discriminator.

Indeed, DistributedDataParallel synchronizes the batch-norm buffers (the running mean and variance) on every forward pass (see the documentation), which modifies those tensors in place. Since your training step runs several forward passes through the discriminator before the corresponding backward, the saved buffers end up at a newer version than autograd expects, which causes the problem.
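
This is exactly the version-counter mechanism the error message refers to: any tensor that autograd saved for backward and that is later mutated in place triggers it. A minimal single-process sketch of the mechanism (nothing DDP-specific):

import torch

x = torch.randn(4, requires_grad=True)
y = x * 2            # non-leaf tensor
z = y * y            # autograd saves y for the backward of this multiply
y.add_(1.0)          # in-place update bumps y's version counter
z.sum().backward()   # RuntimeError: ... modified by an inplace operation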

Converting my BatchNorm layers to SyncBatchNorm solved it for me:

discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
discriminator = DDP(discriminator, device_ids=[rank])

You will probably want to do the same when using DistributedDataParallel.

Alternatively, if you don't want to use SyncBatchNorm, you can set the broadcast_buffers parameter to False. I don't think you really want that, though, since it means your batch-norm statistics will no longer be synchronized across processes:

discriminator = DDP(discriminator, device_ids=[rank], broadcast_buffers=False)
