Why does the Hugging Face tokenizer return only one `input_ids` instead of three?

I am trying to tokenize the SQuAD dataset by following the Hugging Face tutorial:

from datasets import load_dataset
from transformers import RobertaTokenizer
from transformers import logging
logging.set_verbosity_error()

dataset = load_dataset('squad')
checkpoint = 'roberta-base'
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)

def tokenize_function(example):
    return tokenizer(example['question'], example['context'], [d['text'][0] for d in example['answers']], truncation=True) 

tokenized_datasets = dataset['train'].map(tokenize_function, batched=True)

But when I print

tokenized_datasets

I get

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers', 'input_ids', 'attention_mask'],
    num_rows: 87599
})

But shouldn't this return three input_ids: one for the question, one for the context, and one for the answer?

Isn't that what this line of code:

tokenizer(example['question'], example['context'], [d['text'][0] for d in example['answers']], truncation=True)

is supposed to do, as shown in the course?

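The tokenizer's __call__ method (see the documentation) accepts a large number of parameters. Because you passed only truncation by keyword, every other argument is assigned by position. With the parameter order in effect when this answer was written, that means you are actually calling: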

tokenizer(text=example['question'], text_pair=example['context'], add_special_tokens=[d['text'][0] for d in example['answers']], truncation=True) 
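
The positional binding described above can be checked without downloading a model. The following is a minimal sketch, assuming a hypothetical stand-in function `fake_call` whose leading parameters follow the order given in this answer; it is not the real tokenizer, and the real parameter list is longer and may differ between library releases.

```python
import inspect


def fake_call(text, text_pair=None, add_special_tokens=True,
              padding=False, truncation=False):
    """Hypothetical stand-in mimicking the leading parameters described above."""
    return None


# Bind the arguments the same way the question's call does:
# three positional values plus one keyword.
bound = inspect.signature(fake_call).bind(
    "question", "context", ["answer"], truncation=True
)
binding = dict(bound.arguments)

for name, value in binding.items():
    print(f"{name} = {value!r}")
```

The third positional value lands in `add_special_tokens`, not in a third text slot. To see the order for your installed version, `inspect.signature(tokenizer.__call__)` on the real tokenizer lists the actual parameters.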

After running your code, the example with id 5733be284776f41900661182 becomes:

{'id': '5733be284776f41900661182', 
'title': 'University_of_Notre_Dame', 
'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.', 
'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?', 
'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}, 
'input_ids': [0, 3972, 2661, 222, 5, 9880, 2708, 2346, 2082, 11, 504, 4432, 11, 226, 2126, 10067, 1470, 116, 2, 2, 37848, 37471, 28108, 6, 5, 334, 34, 10, 4019, 2048, 4, 497, 1517, 5, 4326, 6919, 18, 1637, 31346, 16, 10, 9030, 9577, 9, 5, 9880, 2708, 4, 29261, 11, 760, 9, 5, 4326, 6919, 8, 2114, 24, 6, 16, 10, 7621, 9577, 9, 4845, 19, 3701, 62, 33161, 19, 5, 7875, 22, 39043, 1459, 1614, 1464, 13292, 4977, 845, 4130, 7, 5, 4326, 6919, 16, 5, 26429, 2426, 9, 5, 25095, 6924, 4, 29261, 639, 5, 32394, 2426, 16, 5, 7461, 26187, 6, 10, 19035, 317, 9, 9621, 8, 12456, 4, 85, 16, 10, 24633, 9, 5, 11491, 26187, 23, 226, 2126, 10067, 6, 1470, 147, 5, 9880, 2708, 2851, 13735, 352, 1382, 7, 6130, 6552, 625, 3398, 208, 22895, 853, 1827, 11, 504, 4432, 4, 497, 5, 253, 9, 5, 1049, 1305, 36, 463, 11, 10, 2228, 516, 14, 15230, 149, 155, 19638, 8, 5, 2610, 25336, 238, 16, 10, 2007, 6, 2297, 7326, 9577, 9, 2708, 4, 2], 
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

The input_ids are the concatenation of text and text_pair:

tokenizer.decode([0, 3972, 2661, 222, 5, 9880, 2708, 2346, 2082, 11, 504, 4432, 11, 226, 2126, 10067, 1470, 116, 2, 2, 37848, 37471, 28108, 6, 5, 334, 34, 10, 4019, 2048, 4, 497, 1517, 5, 4326, 6919, 18, 1637, 31346, 16, 10, 9030, 9577, 9, 5, 9880, 2708, 4, 29261, 11, 760, 9, 5, 4326, 6919, 8, 2114, 24, 6, 16, 10, 7621, 9577, 9, 4845, 19, 3701, 62, 33161, 19, 5, 7875, 22, 39043, 1459, 1614, 1464, 13292, 4977, 845, 4130, 7, 5, 4326, 6919, 16, 5, 26429, 2426, 9, 5, 25095, 6924, 4, 29261, 639, 5, 32394, 2426, 16, 5, 7461, 26187, 6, 10, 19035, 317, 9, 9621, 8, 12456, 4, 85, 16, 10, 24633, 9, 5, 11491, 26187, 23, 226, 2126, 10067, 6, 1470, 147, 5, 9880, 2708, 2851, 13735, 352, 1382, 7, 6130, 6552, 625, 3398, 208, 22895, 853, 1827, 11, 504, 4432, 4, 497, 5, 253, 9, 5, 1049, 1305, 36, 463, 11, 10, 2228, 516, 14, 15230, 149, 155, 19638, 8, 5, 2610, 25336, 238, 16, 10, 2007, 6, 2297, 7326, 9577, 9, 2708, 4, 2])

Output:

<s>To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?</s></s>Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.</s>

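This is the usual way to handle extractive question-answering tasks. In that setting the answer is not treated as an input; it is only needed as the target (that is, the start and end positions the model has to predict).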
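
To illustrate what "the answer as a target" means, here is a minimal sketch that converts a character-level answer span into start and end token positions. It assumes a toy regex tokenizer (`tokenize_with_offsets`, a hypothetical helper) and a shortened context, rather than the RoBERTa tokenizer and the full SQuAD passage.

```python
import re


def tokenize_with_offsets(text):
    """Toy tokenizer: returns (token, char_start, char_end) triples."""
    return [(m.group(), m.start(), m.end())
            for m in re.finditer(r"\w+|[^\w\s]", text)]


def char_span_to_token_span(offsets, answer_start, answer_text):
    """Map a character span to the indices of its first and last token."""
    answer_end = answer_start + len(answer_text)
    start_token = end_token = None
    for i, (_, start, end) in enumerate(offsets):
        if start_token is None and start <= answer_start < end:
            start_token = i
        if start < answer_end <= end:
            end_token = i
    return start_token, end_token


# Shortened context for illustration; SQuAD stores answer_start as a
# character offset into the context, which we emulate with str.find.
context = "The Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858."
answer_text = "Saint Bernadette Soubirous"
answer_start = context.find(answer_text)

offsets = tokenize_with_offsets(context)
start_token, end_token = char_span_to_token_span(offsets, answer_start, answer_text)
answer_tokens = [tok for tok, _, _ in offsets[start_token:end_token + 1]]
print(start_token, end_token, answer_tokens)
```

With a real fast tokenizer, the character offsets would typically come from `return_offsets_mapping=True` instead of the toy helper, and the positions would also have to account for the question tokens and special tokens placed before the context.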

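Edit: OP clarified the question in the comments and wants to know how to return the input_ids for all three text entities: question, context, and answer. All that needs to change is for tokenize_function to encode each entity independently and return a dictionary: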

from datasets import load_dataset
from transformers import RobertaTokenizer

dataset = load_dataset('squad')
checkpoint = 'roberta-base'
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)

def tokenize_function(example):
    question_o = tokenizer(example['question'], truncation=True)
    context_o = tokenizer(example['context'], truncation=True)
    answer_o = tokenizer([d['text'][0] for d in example['answers']], truncation=True)

    # One pair of columns (ids + mask) per entity
    return {
        "question_input_ids": question_o.input_ids,
        "question_attention_mask": question_o.attention_mask,
        "context_input_ids": context_o.input_ids,
        "context_attention_mask": context_o.attention_mask,
        "answer_input_ids": answer_o.input_ids,
        "answer_attention_mask": answer_o.attention_mask,
    }

tokenized_datasets = dataset['train'].map(tokenize_function, batched=True)