填充和注意力掩码在GPT语言模型中的批量输入中不能按预期工作



以下代码没有批处理:

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()
context=torch.tensor([tokenizer.encode("This is")])
output, past = model(context)
token = torch.argmax(output[..., -1, :])
print(tokenizer.decode(token.item()))
output: ' a'

这很好用。现在,我将其扩展到批量设置:

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()
context=[torch.tensor(tokenizer.encode("This is ")),torch.tensor(tokenizer.encode("Hello How are "))]
context=pad_sequence(context,batch_first=True)
mask=torch.tensor([[1,1,0],[1,1,1]])
output, past = model(context,attention_mask=mask)
token = torch.argmax(output[..., -1, :],dim=1)
tokenizer.decode(token)
output: 'n you'

这里n是该批的第一上下文的下一个令牌,而you是该批第二上下文的下个令牌。但是第一个上下文的预期下一个令牌是a,因为所有设置都是相同的。此外,如果您将第二个上下文减少为2个令牌,您将在此批处理设置中获得a。很明显,模型无法理解填充。此外,注意力面罩也不起作用。因为在填充之后,序列CCD_ 5的下一个令牌是0(零(。并且根据注意力掩码([1,1,0](,应该避免这个零,并且应该只关注令牌thisis。这种注意力掩蔽不起作用的证据是:

  • 使用注意力掩码[1,1,1],这意味着即使在填充零的时候也要注意,你会得到相同的输出即CCD_ 9。

  • 使用字符串this is!。这里CCD_ 11在词汇表矩阵中具有零索引。再次得到相同的输出,即n

只有时间,才有可能获得理想的输出是没有批量设置和注意力掩码的(现在看来,这无关紧要,因为它无论如何都没有效果(

然后我发现了这个,建议使用pad_token。所以我用了如下:

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
from torch.nn.utils.rnn import pad_sequence  
tokenizer = GPT2Tokenizer.from_pretrained("gpt2",pad_token="<PAD>")
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()
context=[torch.tensor(tokenizer.encode("This is <PAD> ")),torch.tensor(tokenizer.encode("Hello How are"))]
context=torch.stack(context)
print(context)
mask=torch.tensor([[1,1,0],[1,1,1]])
output, past = model(context,attention_mask=mask)
token = torch.argmax(output[..., -1, :],dim=1)
tokenizer.decode(token)
output: 'The you'

这里CCD_ 13是该批的第一上下文的下一个令牌,而CCD_。这也不起作用。因为对于第一个上下文不期望The

如何在gpt/gpt2模型的批量设置中使用可变长度序列?

我不确定这是否有帮助,但您不需要实现自己的注意力掩蔽和填充。Transformers库提供encode_plus((和batch_encode_plus。结果显示为Python字典。

最新更新