我使用Huggingface的Transforms库来创建一个基于Bert的文本分类模型。为此,我对我的文档进行标记,并将截断设置为true,因为我的文档长度超过允许的长度(512)。
我怎样才能知道有多少文档实际上被截断了?我不认为长度(512)是文档的字符数或单词数,因为Tokenizer准备文档作为模型的输入。文档发生了什么,是否有一种直接的方法来检查它是否被截断?
这是我用来标记文档的代码。
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-multilingual-cased")
model = BertForSequenceClassification.from_pretrained("distilbert-base-multilingual-cased", num_labels=7)
train_encoded = tokenizer(X_train, padding=True, truncation=True, return_tensors="pt")
如果你对我的代码或问题有任何疑问,请随时提问。
你的假设是正确的!
长度大于512的任何字符(假设您使用的是"distilbert-base-multilingual-cased")截断为truncation=True
.
一个快速的解决方案是不截断和计算大于模型最大输入长度的示例:
train_encoded_no_trunc = tokenizer(X_train, padding=True, truncation=False, return_tensors="pt")
count=0
for doc in train_encoded_no_trunc.input_ids:
if(doc>0).sum()> tokenizer.model_max_length:
count+=1
print("number of truncated docs: ",count)