Token indices sequence length is longer than the specified maximum sequence length for this model (28627 > 512)

I am using Hugging Face's DistilBERT model as the back end for a question-answering application. The text I use to train the model is one very large single text field. Even though the context is a single string, the punctuation is left in place as cues for BERT. When I run the application I get the "Token indices sequence length" error quoted above. I pass the text into the model with the tokenizer's encode_plus() method, and I have tried various mechanisms for truncating the input ids to a length <= 512. I am currently running on Windows 10, but I will also be porting the code to a Raspberry Pi 4 platform.
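For reference, here is a minimal sketch that reproduces the warning with the same tokenizer (the repeated sentence is just a stand-in for my real corpus; any text longer than 512 tokens triggers it):

from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', model_max_length=512)
# stand-in for the corpus: repeat a sentence until the text tokenizes to well over 512 ids
long_context = "Some 1,500 volcanoes are still considered potentially active around the world today. " * 200
# no truncation requested, so the tokenizer emits the full id sequence and logs:
# "Token indices sequence length is longer than the specified maximum sequence
#  length for this model (... > 512)"
print(len(tokenizer(long_context).input_ids))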

The code fails on this line:

start_scores, end_scores = model(torch.tensor([input_ids]), attention_mask=torch.tensor([attention_mask]))

I tried to perform the truncation on this line:

encoding = tokenizer.encode_plus(question, tokenizer(context, truncation=True).input_ids)
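If I understand the tokenizer correctly, this still overflows: the inner tokenizer(context, truncation=True) call does cap the context at 512 ids, but encode_plus then adds the question tokens and special tokens on top of them, so the combined pair is longer than 512 again. A quick check (using the tokenizer, question, and context from the full code below):

print(len(tokenizer(context, truncation=True).input_ids))  # 512
pair = tokenizer.encode_plus(question, tokenizer(context, truncation=True).input_ids)
print(len(pair["input_ids"]))  # 512 + question tokens + special tokens, i.e. > 512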

The full code is here:

from transformers import AutoTokenizer, DistilBertTokenizer, DistilBertForQuestionAnswering
import torch

# globals - set once, used everywhere
tokenizer = None
model = None
context = ''

def establishSettings():
    global tokenizer, model, context
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', return_token_type_ids=True, model_max_length=512)
    model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad', return_dict=False)
    # context = "Some 1,500 volcanoes are still considered potentially active around the world today 161 of those over 10 percent sit within the boundaries of the United States."
    # get the volcano corpus
    with open('volcanic.corpus', encoding="utf8") as file:
        context = file.read().replace('\n', '')
    print(len(tokenizer(context, truncation=True).input_ids))

def askQuestion(question):
    global tokenizer, model, context
    print("\nQuestion ", question)
    encoding = tokenizer.encode_plus(question, tokenizer(context, truncation=True).input_ids)
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
    start_scores, end_scores = model(torch.tensor([input_ids]), attention_mask=torch.tensor([attention_mask]))
    ans_tokens = input_ids[torch.argmax(start_scores): torch.argmax(end_scores) + 1]
    answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
    # all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    return answer_tokens

def main():
    # set the global items once
    establishSettings()
    # ask a question
    question = "How many potentially active volcanoes are there in the world today?"
    answer_tokens = askQuestion(question)
    print("answer_tokens: ", answer_tokens)
    if len(answer_tokens) == 0:
        answer = "Sorry, I don't have an answer for that one. Ask me another question about New Mexico volcanoes."
        print(answer)
    else:
        answer_tokens_to_string = tokenizer.convert_tokens_to_string(answer_tokens)
        print("\nFinal Answer : ")
        print(answer_tokens_to_string)

if __name__ == '__main__':
    main()

What is the best way to truncate input_ids to a length <= 512?

Edit: I changed this line:

encoding = tokenizer.encode_plus(question, tokenizer(context, truncation=True).input_ids)

to:

encoding = tokenizer.encode_plus(question, tokenizer(context, truncation=True, max_length=512).input_ids)
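For completeness, here is a minimal sketch of what I believe is the intended usage, based on the transformers documentation: pass the raw question and context strings to encode_plus together and let the tokenizer truncate only the context half of the pair, instead of feeding it a pre-tokenized id list:

encoding = tokenizer.encode_plus(question, context,
                                 truncation='only_second',  # truncate the context, never the question
                                 max_length=512)
input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]

Note that any answer lying beyond the first 512 tokens of the corpus is lost to this truncation; the documented workaround is to split the context into overlapping windows with stride= and return_overflowing_tokens=True and run the model over each window.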

Latest update