Huggingface Transformers Bert Tokenizer - 找出哪些文档被截断

Hugginface Transformers Bert Tokenizer - Find out which documents get truncated

我正在使用 Huggingface 的 Transforms 库来创建基于 Bert 的文本分类模型。为此,我标记了我的文档,并将截断设置为 true,因为我的文档比允许的 (512) 长。

我如何才能知道有多少文档实际被截断了?我不认为长度 (512) 是文档的字符数或字数,因为 Tokenizer 准备文档作为模型的输入。文档发生了什么变化,是否有直接的方法来检查它是否被截断?

这是我用来标记文档的代码。

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-multilingual-cased") 
model = BertForSequenceClassification.from_pretrained("distilbert-base-multilingual-cased", num_labels=7)
train_encoded =  tokenizer(X_train, padding=True, truncation=True, return_tensors="pt")

如果您对我的代码或问题有任何疑问,请随时提问。

你的假设是正确的!

长度大于 512 的任何内容(假设您使用的是“distilbert-base-multilingual-cased”)都会被 truncation=True.

截断

一个快速的解决方案是不截断和计算大于模型最大输入长度的示例:


train_encoded_no_trunc =  tokenizer(X_train, padding=True, truncation=False, return_tensors="pt")

count=0 

for doc in train_encoded_no_trunc.input_ids:
    if(doc>0).sum()> tokenizer.model_max_length: 
        count+=1
print("number of truncated docs: ",count)