masked_lm_labels参数如何在BertForMaskedLM中工作


from transformers import BertTokenizer, BertForMaskedLM
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids)
loss, prediction_scores = outputs[:2] 

这段代码来自huggingface transformers页面。https://huggingface.co/transformers/model_doc/bert.html#bertformaskedlm

我无法理解model中的masked_lm_labels=input_ids自变量。它是如何工作的?这是否意味着当input_ids通过时,它将自动屏蔽部分文本?

第一个参数是屏蔽输入,masked_lm_labels参数是所需输出。

应屏蔽input_ids。一般来说,如何进行掩蔽取决于你自己。在最初的BERT中,他们选择15%的代币和以下代币,或者

  • 使用[MASK]令牌;或
  • 使用随机令牌;或
  • 保持原始令牌不变

这会修改输入,因此您需要告诉模型什么是原始的非屏蔽输入,即masked_lm_labels参数。还要注意,您不希望只计算实际选择用于屏蔽的令牌的损失。其余的令牌应替换为索引-100

有关更多详细信息,请参阅文档。

相关内容

  • 没有找到相关文章

最新更新