Hugginface Transformers Bert Tokenizer语言 - 找出哪些文档被截断



我使用Huggingface的Transforms库来创建一个基于Bert的文本分类模型。为此,我对我的文档进行标记,并将截断设置为true,因为我的文档长度超过允许的长度(512)。

我怎样才能知道有多少文档实际上被截断了?我不认为长度(512)是文档的字符数或单词数,因为Tokenizer准备文档作为模型的输入。文档发生了什么,是否有一种直接的方法来检查它是否被截断?

这是我用来标记文档的代码。

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-multilingual-cased") 
model = BertForSequenceClassification.from_pretrained("distilbert-base-multilingual-cased", num_labels=7)
train_encoded =  tokenizer(X_train, padding=True, truncation=True, return_tensors="pt")

如果你对我的代码或问题有任何疑问,请随时提问。

你的假设是正确的!

长度大于512的任何字符(假设您使用的是"distilbert-base-multilingual-cased")截断为truncation=True.

一个快速的解决方案是不截断和计算大于模型最大输入长度的示例:


train_encoded_no_trunc =  tokenizer(X_train, padding=True, truncation=False, return_tensors="pt")
count=0 
for doc in train_encoded_no_trunc.input_ids:
if(doc>0).sum()> tokenizer.model_max_length: 
count+=1
print("number of truncated docs: ",count)

最新更新