相关代码:
from transformers import (
AdamW,
MT5ForConditionalGeneration,
AutoTokenizer,
get_linear_schedule_with_warmup
)
tokenizer = AutoTokenizer.from_pretrained('google/byt5-small', use_fast=True)
model=MT5ForConditionalGeneration.from_pretrained("working/result/",
return_dict=True)
def generate(text):
model.eval()
# print(model)
# input_ids = tokenizer.encode("WebNLG:{} </s>".format(text),
# return_tensors="pt")
input_ids = tokenizer.batch_encode_plus(
[text], max_length=512, pad_to_max_length=True, return_tensors="pt"
).to(device)
source_ids = input_ids["input_ids"].squeeze()
print(tokenizer.decode(source_ids))
print(type(input_ids.input_ids))
input_ids.input_ids.to(device)
print(input)
outputs = model.generate(input_ids.input_ids)
print(outputs)
print(outputs[0])
return tokenizer.decode(outputs[0])
调用上面的函数
input_str = "Title: %s Category: %s" % ("10 Min Quick Food Recipe","Food")
input_str = "Title: %s Category: %s" % ("I am marathon runner and going to run 21km on 4th dec in Thane","Fitness")
print(input_str)
print(generate(input_str))
输出:
Title: I am marathon runner and going to run 21km on 4th dec in Thane Category: Fitness
Title: I am marathon runner and going to run 21km on 4th dec in Thane Category: Fitness</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
<class 'torch.Tensor'>
<bound method Kernel.raw_input of <ipykernel.ipkernel.IPythonKernel object at 0x7ff645eed970>>
tensor([[ 0, 259, 266, 259, 3659, 390, 259, 262, 48580, 288,
259, 262, 38226, 5401, 259, 1]], device='cuda:0')
tensor([ 0, 259, 266, 259, 3659, 390, 259, 262, 48580, 288,
259, 262, 38226, 5401, 259, 1], device='cuda:0')
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In [30], line 5
2 input_str = "Title: %s Category: %s" % ("I am marathon runner and going to run 21km on 4th dec in Thane","Fitness")
4 print(input_str)
----> 5 print(generate(input_str))
Cell In [29], line 18, in generate(text)
16 print(outputs)
17 print(outputs[0])
---> 18 return tokenizer.decode(outputs[0])
File ~/T5/t5_venv/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:3436, in PreTrainedTokenizerBase.decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
3433 # Convert inputs to python lists
3434 token_ids = to_py_obj(token_ids)
-> 3436 return self._decode(
3437 token_ids=token_ids,
3438 skip_special_tokens=skip_special_tokens,
3439 clean_up_tokenization_spaces=clean_up_tokenization_spaces,
3440 **kwargs,
3441 )
File ~/T5/t5_venv/lib/python3.8/site-packages/transformers/tokenization_utils.py:949, in PreTrainedTokenizer._decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, spaces_between_special_tokens, **kwargs)
947 current_sub_text.append(token)
948 if current_sub_text:
--> 949 sub_texts.append(self.convert_tokens_to_string(current_sub_text))
951 if spaces_between_special_tokens:
952 text = " ".join(sub_texts)
File ~/T5/t5_venv/lib/python3.8/site-packages/transformers/models/byt5/tokenization_byt5.py:243, in ByT5Tokenizer.convert_tokens_to_string(self, tokens)
241 tok_string = token.encode("utf-8")
242 else:
--> 243 tok_string = bytes([ord(token)])
244 bstring += tok_string
245 string = bstring.decode("utf-8", errors="ignore")
ValueError: bytes must be in range(0, 256)
我试图将max_length参数更改为256,但似乎无法使其工作。任何线索高度赞赏。提前感谢。
明白了。我犯了一个愚蠢的错误。我尝试了不同的预训练标记器和T5模型。
在训练期间,我使用了google/mt5-base
,但在推理期间,我使用了google/byt5-small
,这就产生了这个问题。改回google/mt5-base
修复问题。现在推理工作正常。