When training a Transformer on multiple GPUs, the mask's shape gets divided by the number of GPUs. Why?

  • Tags: GPU, Transformer, nlp, pytorch, transformer-model


I am training a Transformer on multiple GPUs, but I've run into a problem. I'm using PyTorch, and my setup looks like this:

model = Transformer(
    src_tokens=src_tokens, tgt_tokens=tgt_tokens, dim_model=dim_model, num_heads=num_heads,
    num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, dropout_p=0.1)
model = nn.DataParallel(model, device_ids=device_ids)
model.to(device)

The training loop is as follows:

def train_loop(model, opt, loss_fn, dataloader):
    model.train()
    total_loss = 0
    for X, y in dataloader:
        X, y = X.t().to(device), y.t().to(device)
        y_input = y[:, :-1]
        y_expected = y[:, 1:]
        sequence_length = y_input.size(1)
        src_pad_mask = create_pad_mask(X, 1)
        tgt_pad_mask = create_pad_mask(y_input, 1)
        tgt_mask = get_tgt_mask(sequence_length)
        pred = model(X, y_input, tgt_mask=tgt_mask, src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask)
        # Permute pred to have batch size first again
        pred = pred.permute(1, 2, 0)
        loss = loss_fn(pred, y_expected)
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.detach().item()
    return total_loss / len(dataloader)

My model.py looks like this:

class Transformer(nn.Module):
    """
    Model from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1
    """
    # Constructor
    def __init__(
        self,
        src_tokens,
        tgt_tokens,
        dim_model,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        dropout_p,
    ):
        super().__init__()
        # INFO
        self.model_type = "Transformer"
        self.dim_model = dim_model
        # LAYERS
        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, max_len=5000
        )
        self.src_embedding = nn.Embedding(src_tokens, dim_model)
        self.tgt_embedding = nn.Embedding(tgt_tokens, dim_model)
        self.transformer = nn.Transformer(
            d_model=dim_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout_p,
        )
        self.out = nn.Linear(dim_model, tgt_tokens)

    def forward(self, src, tgt, tgt_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        # Src size must be (batch_size, src sequence length)
        # Tgt size must be (batch_size, tgt sequence length)
        # Embedding + positional encoding - Out size = (batch_size, sequence length, dim_model)
        src = self.src_embedding(src) * math.sqrt(self.dim_model)
        tgt = self.tgt_embedding(tgt) * math.sqrt(self.dim_model)
        src = self.positional_encoder(src)
        tgt = self.positional_encoder(tgt)
        # We could use the parameter batch_first=True, but our KDL version doesn't support it yet, so we permute
        # to obtain size (sequence length, batch_size, dim_model),
        src = src.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)
        print('src_pad_mask: ' + str(src_pad_mask.shape) + '  tgt_pad_mask: ' + str(tgt_pad_mask.shape) + '  tgt_mask: ' + str(tgt_mask.shape))
        print('++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
        # Transformer blocks - Out size = (sequence length, batch_size, num_tokens)
        transformer_out = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_pad_mask,
                                           tgt_key_padding_mask=tgt_pad_mask)
        out = self.out(transformer_out)
        return out

I get this error:

root@3b:/koi/transformer-multi# python train.py
src_pad_mask: torch.Size([1, 4784])  tgt_pad_mask: torch.Size([1, 3225])  tgt_mask: torch.Size([538, 3225])
++++++++++++++++++++++++++++++++++++++++++++++++++++++++
src_pad_mask: torch.Size([1, 4784])  tgt_pad_mask: torch.Size([1, 3225])  tgt_mask: torch.Size([538, 3225])
++++++++++++++++++++++++++++++++++++++++++++++++++++++++
src_pad_mask: torch.Size([1, 4784])  tgt_pad_mask: torch.Size([1, 3225])  tgt_mask: torch.Size([538, 3225])
++++++++++++++++++++++++++++++++++++++++++++++++++++++++
src_pad_mask: torch.Size([1, 4784])  tgt_pad_mask: torch.Size([1, 3225])  tgt_mask: torch.Size([538, 3225])
++++++++++++++++++++++++++++++++++++++++++++++++++++++++
src_pad_mask: torch.Size([1, 4784])  tgt_pad_mask: torch.Size([1, 3225])  tgt_mask: torch.Size([538, 3225])
++++++++++++++++++++++++++++++++++++++++++++++++++++++++
src_pad_mask: torch.Size([1, 4784])  tgt_pad_mask: torch.Size([1, 3225])  tgt_mask: torch.Size([535, 3225])
++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Traceback (most recent call last):
File "/koi/transformer-multi/train.py", line 160, in <module>
train_loss = train_loop(model, opt, loss_fn, trn_loader)
File "/koi/transformer-multi/train.py", line 121, in train_loop
pred = model(X, y_input, tgt_mask=tgt_mask, src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask)
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
output.reraise()
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/_utils.py", line 434, in reraise
raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/koi/transformer-multi/model.py", line 93, in forward
transformer_out = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_pad_mask,
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 142, in forward
output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 248, in forward
output = mod(output, memory, tgt_mask=tgt_mask,
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 451, in forward
x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 460, in _sa_block
x = self.self_attn(x, x, x,
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1003, in forward
attn_output, attn_output_weights = F.multi_head_attention_forward(
File "/root/anaconda3/envs/koi/lib/python3.9/site-packages/torch/nn/functional.py", line 5011, in multi_head_attention_forward
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
RuntimeError: The shape of the 2D attn_mask is torch.Size([538, 3225]), but should be (3225, 3225).

I have tested this several times, changing the number of GPUs each time: the mask's shape gets divided by the number of GPUs. I don't know how to fix this.

DataParallel splits every tensor argument along its first dimension on the assumption that it is the batch dimension, so it splits tgt_mask into 6 tensors. There is a discussion about this, but I'm not sure whether there is a solution yet: https://discuss.pytorch.org/t/avoid-tensor-split-with-nn-dataparallel/18293
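Because tgt_mask is [3225, 3225] and there are 6 GPUs, the scatter produces five [538, 3225] chunks and one [535, 3225] chunk, which is exactly what the printed shapes above show. A minimal CPU-only sketch of that split, using torch.chunk (which mirrors how DataParallel chunks tensors along dim 0; the sizes 3225 and 6 are taken from the log):

import torch

# DataParallel scatters every tensor argument along dim 0, one chunk per GPU.
tgt_mask = torch.zeros(3225, 3225)        # the full causal mask
chunks = torch.chunk(tgt_mask, 6, dim=0)  # 6 GPUs -> 6 chunks

for c in chunks:
    print(c.shape)
# prints torch.Size([538, 3225]) five times, then torch.Size([535, 3225]):
# each replica receives only a slice of the mask, hence the
# "2D attn_mask is torch.Size([538, 3225]), but should be (3225, 3225)" error.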

You can repeat tgt_mask from [3225, 3225] to [N, 3225, 3225] (N being the number of GPUs), and then reshape it from [1, 3225, 3225] back to [3225, 3225] inside forward before passing it to the transformer model.
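A rough sketch of that workaround, reusing the names from the question (device_ids, get_tgt_mask, sequence_length); it is an untested illustration rather than a drop-in patch:

# In train_loop: replicate the 2D mask once per GPU, so that DataParallel's
# split along dim 0 hands every replica one full copy.
num_gpus = len(device_ids)                               # e.g. 6
tgt_mask = get_tgt_mask(sequence_length)                 # [L, L]
tgt_mask = tgt_mask.unsqueeze(0).repeat(num_gpus, 1, 1)  # [num_gpus, L, L]
pred = model(X, y_input, tgt_mask=tgt_mask,
             src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask)

# In Transformer.forward: each replica now receives a [1, L, L] slice;
# drop the leading dim before calling nn.Transformer, which expects a
# 2D [L, L] attention mask.
if tgt_mask is not None and tgt_mask.dim() == 3:
    tgt_mask = tgt_mask.squeeze(0)                       # back to [L, L]
transformer_out = self.transformer(src, tgt, tgt_mask=tgt_mask,
                                   src_key_padding_mask=src_pad_mask,
                                   tgt_key_padding_mask=tgt_pad_mask)

Since every replica receives an identical copy of the mask, keeping only the first slice inside forward does not change the computation.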
