Huggingface Transformers Bert Tokenizer - 找出哪些文档被截断
Hugginface Transformers Bert Tokenizer - Find out which documents get truncated
我正在使用 Huggingface 的 Transforms 库来创建基于 Bert 的文本分类模型。为此,我标记了我的文档,并将截断设置为 true,因为我的文档比允许的 (512) 长。
我如何才能知道有多少文档实际被截断了?我不认为长度 (512) 是文档的字符数或字数,因为 Tokenizer 准备文档作为模型的输入。文档发生了什么变化,是否有直接的方法来检查它是否被截断?
这是我用来标记文档的代码。
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-multilingual-cased")
model = BertForSequenceClassification.from_pretrained("distilbert-base-multilingual-cased", num_labels=7)
train_encoded = tokenizer(X_train, padding=True, truncation=True, return_tensors="pt")
如果您对我的代码或问题有任何疑问,请随时提问。
你的假设是正确的!
长度大于 512 的任何内容(假设您使用的是“distilbert-base-multilingual-cased”)都会被 truncation=True
.
截断
一个快速的解决方案是不截断和计算大于模型最大输入长度的示例:
train_encoded_no_trunc = tokenizer(X_train, padding=True, truncation=False, return_tensors="pt")
count=0
for doc in train_encoded_no_trunc.input_ids:
if(doc>0).sum()> tokenizer.model_max_length:
count+=1
print("number of truncated docs: ",count)
我正在使用 Huggingface 的 Transforms 库来创建基于 Bert 的文本分类模型。为此,我标记了我的文档,并将截断设置为 true,因为我的文档比允许的 (512) 长。
我如何才能知道有多少文档实际被截断了?我不认为长度 (512) 是文档的字符数或字数,因为 Tokenizer 准备文档作为模型的输入。文档发生了什么变化,是否有直接的方法来检查它是否被截断?
这是我用来标记文档的代码。
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-multilingual-cased")
model = BertForSequenceClassification.from_pretrained("distilbert-base-multilingual-cased", num_labels=7)
train_encoded = tokenizer(X_train, padding=True, truncation=True, return_tensors="pt")
如果您对我的代码或问题有任何疑问,请随时提问。
你的假设是正确的!
长度大于 512 的任何内容(假设您使用的是“distilbert-base-multilingual-cased”)都会被 truncation=True
.
一个快速的解决方案是不截断和计算大于模型最大输入长度的示例:
train_encoded_no_trunc = tokenizer(X_train, padding=True, truncation=False, return_tensors="pt")
count=0
for doc in train_encoded_no_trunc.input_ids:
if(doc>0).sum()> tokenizer.model_max_length:
count+=1
print("number of truncated docs: ",count)