huggingface 变形金刚:encode_plus 中的截断策略
huggingface transformers: truncation strategy in encode_plus
huggingface 的变形金刚库中的 encode_plus
允许截断输入序列。两个相关参数:truncation
和 max_length
。我将配对的输入序列传递给 encode_plus
并且需要以“切断”的方式简单地截断输入序列,即,如果整个序列由输入 text
和 text_pair
比 max_length
长,应该从右边相应地截断它。
似乎这两种截断策略都不允许这样做,而是 longest_first
从最长的序列(可以是文本或 text_pair 中删除标记,但不仅仅是从右边或序列的结尾,例如,如果文本比 text_pair 长,这似乎会首先从文本中删除标记),only_first
和 only_second
仅从第一个或第二个中删除标记(因此,也不只是从末尾开始),并且 do_not_truncate
根本不会截断。还是我误解了这一点,实际上 longest_first
可能就是我要找的东西?
否 longest_first
与 cut from the right
不同。当您将截断策略设置为 longest_first
时,标记器将在每次需要删除标记时比较 text
和 text_pair
的长度,并从最长的标记中删除一个标记。例如,这可能意味着它将首先从 text_pair
中删除 3 个标记,并将从 text
和 text_pair
中删除需要交替删除的其余标记。一个例子:
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
seq1 = 'This is a long uninteresting text'
seq2 = 'What could be a second sequence to the uninteresting text'
print(len(tokenizer.tokenize(seq1)))
print(len(tokenizer.tokenize(seq2)))
print(tokenizer(seq1, seq2))
print(tokenizer(seq1, seq2, truncation= True, max_length = 15))
print(tokenizer.decode(tokenizer(seq1, seq2, truncation= True, max_length = 15)['input_ids']))
输出:
9
13
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 1037, 2117, 5537, 2000, 1996, 4895, 18447, 18702, 3436, 3793, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 102, 2054, 2071, 2022, 1037, 2117, 5537, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
[CLS] this is a long unint [SEP] what could be a second sequence [SEP]
据我从你的问题中可以看出你实际上是在寻找 only_second
因为它从右边切入(即 text_pair
):
print(tokenizer(seq1, seq2, truncation= 'only_second', max_length = 15))
输出:
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
当您尝试 text
输入的长度达到指定的 max_length 时,它会抛出异常。我认为这是正确的,因为在这种情况下,它不再是序列对输入。
以防万一 only_second
不符合您的要求,您可以简单地创建自己的截断策略。以手工为例only_second
:
tok_seq1 = tokenizer.tokenize(seq1)
tok_seq2 = tokenizer.tokenize(seq2)
maxLengthSeq2 = myMax_len - len(tok_seq1) - 3 #number of special tokens for bert sequence pair
if len(tok_seq2) > maxLengthSeq2:
tok_seq2 = tok_seq2[:maxLengthSeq2]
input_ids = [tokenizer.cls_token_id]
input_ids += tokenizer.convert_tokens_to_ids(tok_seq1)
input_ids += [tokenizer.sep_token_id]
token_type_ids = [0]*len(input_ids)
input_ids += tokenizer.convert_tokens_to_ids(tok_seq2)
input_ids += [tokenizer.sep_token_id]
token_type_ids += [1]*(len(tok_seq2)+1)
attention_mask = [1]*len(input_ids)
print(input_ids)
print(token_type_ids)
print(attention_mask)
输出:
[101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
encode_plus
允许截断输入序列。两个相关参数:truncation
和 max_length
。我将配对的输入序列传递给 encode_plus
并且需要以“切断”的方式简单地截断输入序列,即,如果整个序列由输入 text
和 text_pair
比 max_length
长,应该从右边相应地截断它。
似乎这两种截断策略都不允许这样做,而是 longest_first
从最长的序列(可以是文本或 text_pair 中删除标记,但不仅仅是从右边或序列的结尾,例如,如果文本比 text_pair 长,这似乎会首先从文本中删除标记),only_first
和 only_second
仅从第一个或第二个中删除标记(因此,也不只是从末尾开始),并且 do_not_truncate
根本不会截断。还是我误解了这一点,实际上 longest_first
可能就是我要找的东西?
否 longest_first
与 cut from the right
不同。当您将截断策略设置为 longest_first
时,标记器将在每次需要删除标记时比较 text
和 text_pair
的长度,并从最长的标记中删除一个标记。例如,这可能意味着它将首先从 text_pair
中删除 3 个标记,并将从 text
和 text_pair
中删除需要交替删除的其余标记。一个例子:
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
seq1 = 'This is a long uninteresting text'
seq2 = 'What could be a second sequence to the uninteresting text'
print(len(tokenizer.tokenize(seq1)))
print(len(tokenizer.tokenize(seq2)))
print(tokenizer(seq1, seq2))
print(tokenizer(seq1, seq2, truncation= True, max_length = 15))
print(tokenizer.decode(tokenizer(seq1, seq2, truncation= True, max_length = 15)['input_ids']))
输出:
9
13
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 1037, 2117, 5537, 2000, 1996, 4895, 18447, 18702, 3436, 3793, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 102, 2054, 2071, 2022, 1037, 2117, 5537, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
[CLS] this is a long unint [SEP] what could be a second sequence [SEP]
据我从你的问题中可以看出你实际上是在寻找 only_second
因为它从右边切入(即 text_pair
):
print(tokenizer(seq1, seq2, truncation= 'only_second', max_length = 15))
输出:
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
当您尝试 text
输入的长度达到指定的 max_length 时,它会抛出异常。我认为这是正确的,因为在这种情况下,它不再是序列对输入。
以防万一 only_second
不符合您的要求,您可以简单地创建自己的截断策略。以手工为例only_second
:
tok_seq1 = tokenizer.tokenize(seq1)
tok_seq2 = tokenizer.tokenize(seq2)
maxLengthSeq2 = myMax_len - len(tok_seq1) - 3 #number of special tokens for bert sequence pair
if len(tok_seq2) > maxLengthSeq2:
tok_seq2 = tok_seq2[:maxLengthSeq2]
input_ids = [tokenizer.cls_token_id]
input_ids += tokenizer.convert_tokens_to_ids(tok_seq1)
input_ids += [tokenizer.sep_token_id]
token_type_ids = [0]*len(input_ids)
input_ids += tokenizer.convert_tokens_to_ids(tok_seq2)
input_ids += [tokenizer.sep_token_id]
token_type_ids += [1]*(len(tok_seq2)+1)
attention_mask = [1]*len(input_ids)
print(input_ids)
print(token_type_ids)
print(attention_mask)
输出:
[101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]