BERT 如何建模 select 标签排序?

How does the BERT model select the label ordering?

我正在为分类任务训练 BertForSequenceClassification。我的数据集由 'contains adverse effect' (1) 和 'does not contain adverse effect' (0) 组成。数据集包含所有 1,然后是 0(数据未打乱)。为了进行培训,我洗牌了我的数据并获得了逻辑。据我了解,logits 是 softmax 之前的概率分布。一个示例 logit 是 [-4.673831, 4.7095485]。第一个值是否对应于标签 1(包含 AE),因为它首先出现在数据集中,或者标签 0。任何帮助将不胜感激。

第一个值对应于标签 0,第二个值对应于标签 1。BertForSequenceClassification 所做的是将 pooler 的输出馈送到线性层(在我将在这个答案中忽略的 dropout 之后) ).让我们看下面的例子:

from torch import nn
from transformers import BertModel, BertTokenizer
t = BertTokenizer.from_pretrained('bert-base-uncased')
m = BertModel.from_pretrained('bert-base-uncased')
i = t.encode_plus('This is an example.', return_tensors='pt')
o = m(**i)
print(o.pooler_output.shape)

输出:

torch.Size([1, 768])

pooled_output 是形状为 [batch_size,hidden_size] 的张量,表示输入序列的语境化(即应用注意力)[CLS] 标记。该张量被馈送到线性层以计算序列的 logits

classificationLayer = nn.Linear(768,2)
logits = classificationLayer(o.pooler_output)

当我们对这些 logits 进行归一化时,我们可以看到线性层预测我们的输入应该属于标签 1:

print(nn.functional.softmax(logits,dim=-1))

输出(会有所不同,因为线性层是随机初始化的):

tensor([[0.1679, 0.8321]], grad_fn=<SoftmaxBackward>)

线性层应用线性变换:y=xA^T+b 并且您已经可以看到线性层不知道您的标签。它 'only' 有一个大小为 [2,768] 的权重矩阵来产生大小为 [1,2] 的对数(即:第一行对应于第一个值,第二行对应于第二个值):

import torch:

logitsOwnCalculation = torch.matmul(o.pooler_output,  classificationLayer.weight.transpose(0,1))+classificationLayer.bias
print(nn.functional.softmax(logitsOwnCalculation,dim=-1))

输出:

tensor([[0.1679, 0.8321]], grad_fn=<SoftmaxBackward>)

BertForSequenceClassification 模型通过应用 CrossEntropyLoss 进行学习。当某个 class (在您的情况下为标签)的 logits 仅略微偏离预期时,此损失函数会产生较小的损失。这意味着 CrossEntropyLoss 可以让您的模型了解到第一个 logit 在输入 does not contain adverse effect 时应该很高,或者在输入 contains adverse effect 时应该很小。您可以使用以下内容检查我们的示例:

loss_fct = nn.CrossEntropyLoss()
label0 = torch.tensor([0]) #does not contain adverse effect
label1 = torch.tensor([1]) #contains adverse effect
print(loss_fct(logits, label0))
print(loss_fct(logits, label1))

输出:

tensor(1.7845, grad_fn=<NllLossBackward>)
tensor(0.1838, grad_fn=<NllLossBackward>)