为什么 Transformer 的 BERT(用于序列分类)输出严重依赖于最大序列长度填充?

Why does Transformer's BERT (for sequence classification) output depend heavily on maximum sequence length padding?

我正在使用 Transformer 的 RobBERT(RoBERTa 的荷兰语版本)进行序列分类 - 针对荷兰书评数据集的情感分析进行了训练。

我想测试它在类似数据集(以及情感分析)上的效果如何,所以我为一组文本片段做了注释并检查了它的准确性。当我检查哪种句子被错误分类时,我注意到一个独特句子的输出在很大程度上取决于我在标记化时给出的填充长度。请参阅下面的代码。

from transformers import RobertaTokenizer, RobertaForSequenceClassification
import torch.nn.functional as F
import torch


model = RobertaForSequenceClassification.from_pretrained("pdelobelle/robBERT-dutch-books", num_labels=2)
tokenizer = RobertaTokenizer.from_pretrained("pdelobelle/robBERT-dutch-books", do_lower_case=True)

sent = 'De samenwerking gaat de laatste tijd beter'
max_seq_len = 64


test_token = tokenizer(sent,
                        max_length = max_seq_len,
                        padding = 'max_length',
                        truncation = True,
                        return_tensors = 'pt'
                        )

out = model(test_token['input_ids'],test_token['attention_mask'])

probs = F.softmax(out[0], dim=1).detach().numpy()

对于给定的示例文本,用英语翻译为“The collaboration has been improving lately”,根据 max_seq_len,分类输出存在巨大差异。即,对于 max_seq_len = 64probs 的输出是:

[[0.99149346 0.00850648]]

而对于 max_seq_len = 9,是包括 cls 标记在内的实际长度:

[[0.00494814 0.9950519 ]]

谁能解释为什么会出现这种巨大的分类差异?我认为注意掩码确保输出中没有差异,因为填充到最大序列长度。

这是因为您的比较不正确造成的。句子 De samenwerking gaat de laatste tijd beter 实际上有 16 个标记(特殊标记为 +2)而不是 9 个。您只计算了不一定是标记的单词。

print(tokenizer.tokenize(sent))
print(len(tokenizer.tokenize(sent)))

输出:

['De', 'Ġsam', 'en', 'wer', 'king', 'Ġga', 'at', 'Ġde', 'Ġla', 'at', 'ste', 'Ġt', 'ij', 'd', 'Ġbe', 'ter']
16

当您将序列长度设置为 9 时,您会将句子截断为:

tokenizer.decode(tokenizer(sent,
                         max_length = 9,
                         padding = 'max_length',
                         truncation = True,
                         return_tensors = 'pt', 
                         add_special_tokens=False
                         )['input_ids'][0])

输出:

'De samenwerking gaat de la'

最后证明,将max_length设置为52时的输出也是[[0.99149346 0.00850648]]。