有没有办法获取在BERT中生成某个令牌的子串的位置?

Is there a way to get the location of the substring from which a certain token has been produced in BERT?

我正在向 BERT 模型(Hugging Face 库)输入句子。这些句子使用预训练的分词器进行分词。我知道您可以使用解码函数从标记返回到字符串。

string = tokenizer.decode(...)

然而,重建并不完美。如果您使用无大小写的预训练模型,大写字母会丢失。此外,如果分词器将一个词拆分为 2 个分词,则第二个分词将以“##”开头。例如,单词 'coronavirus' 被拆分为 2 个标记:'corona' 和“##virus”。

所以我的问题是:有没有办法获取创建每个标记的子字符串的索引? 例如,以字符串“Tokyo to report nearly 370 new coronavirus cases, set new single-day record”为例。第9个token是'virus'.

对应的token
['[CLS]', 'tokyo', 'to', 'report', 'nearly', '370', 'new', 'corona', '##virus', 'cases', ',', 'setting', 'new', 'single', '-', 'day', 'record', '[SEP]']

我想要一些东西告诉我标记“##virus”来自原始字符串中的 'virus' 子字符串,它位于原始字符串的索引 37 和 41 之间。

sentence = "Tokyo to report nearly 370 new coronavirus cases, setting new single-day record"
print(sentence[37:42]) # --> outputs 'virus

据我所知,他们没有 built-in 方法,但您可以自己创建一个:

import re
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

sentence = "Tokyo to report nearly 370 new coronavirus cases, setting new single-day record"

b = []
b.append(([101],))
for m in re.finditer(r'\S+', sentence):
  w = m.group(0)
  t = (tokenizer.encode(w, add_special_tokens=False), (m.start(), m.end()-1))

  b.append(t)

b.append(([102],))

b

输出:

[([101],),
 ([5522], (0, 4)),
 ([2000], (6, 7)),
 ([3189], (9, 14)),
 ([3053], (16, 21)),
 ([16444], (23, 25)),
 ([2047], (27, 29)),
 ([21887, 23350], (31, 41)),
 ([3572, 1010], (43, 48)),
 ([4292], (50, 56)),
 ([2047], (58, 60)),
 ([2309, 1011, 2154], (62, 71)),
 ([2501], (73, 78)),
 ([102],)]

我想更新答案。由于 HuggingFace 引入了他们(更快)的 Rust 编写的 Fast Tokenizers 版本,这项任务变得容易得多:

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
sentence = "Tokyo to report nearly 370 new coronavirus cases, setting new single-day record"

encodings = tokenizer(sentence, return_offsets_mapping=True)
for token_id, pos in zip(encodings['input_ids'], encodings['offset_mapping']):
    print(token_id, pos, sentence[pos[0]:pos[1]])



101 (0, 0) 
5522 (0, 5) Tokyo
2000 (6, 8) to
3189 (9, 15) report
3053 (16, 22) nearly
16444 (23, 26) 370
2047 (27, 30) new
21887 (31, 37) corona
23350 (37, 42) virus
3572 (43, 48) cases
1010 (48, 49) ,
4292 (50, 57) setting
2047 (58, 61) new
2309 (62, 68) single
1011 (68, 69) -
2154 (69, 72) day
2501 (73, 79) record
102 (0, 0) 

不仅如此,如果您向分词器提供单词列表(并设置 is_split_into_words=True)而不是常规字符串,那么可以轻松区分每个单词的第一个和结果分词(第一个值元组的值为零),这是标记分类任务非常普遍的需求。