Python: BERT Model Pooling Error - mean() received an invalid combination of arguments - got (str, int)

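I am writing code to train a BERT model on my dataset. When I run the code, it throws an error in the average pooling layer. I cannot work out what is causing this error.

Model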

import torch
import torch.nn as nn
import transformers

class BERTBaseUncased(nn.Module):
    def __init__(self, bert_path):
        super(BERTBaseUncased, self).__init__()
        self.bert_path = bert_path
        self.bert = transformers.BertModel.from_pretrained(self.bert_path)
        self.bert_drop = nn.Dropout(0.3)
        self.out = nn.Linear(768 * 2, 1)

    def forward(
            self,
            ids,
            mask,
            token_type_ids
    ):
        o1, _ = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids)
        
        apool = torch.mean(o1, 1)
        mpool, _ = torch.max(o1, 1)
        cat = torch.cat((apool, mpool), 1)

        bo = self.bert_drop(cat)
        p2 = self.out(bo)
        return p2

Error

Exception in device=TPU:0: mean() received an invalid combination of arguments - got (str, int), but expected one of:
 * (Tensor input, *, torch.dtype dtype)
 * (Tensor input, tuple of names dim, bool keepdim, *, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of ints dim, bool keepdim, *, torch.dtype dtype, Tensor out)

Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 228, in _start_fn
    fn(gindex, *args)
  File "<ipython-input-12-94e926c1f4df>", line 4, in _mp_fn
    a = _run()
  File "<ipython-input-5-ef9fa564682f>", line 146, in _run
    train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler=scheduler)
  File "<ipython-input-5-ef9fa564682f>", line 22, in train_loop_fn
    token_type_ids=token_type_ids
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 577, in __call__
    result = self.forward(*input, **kwargs)
  File "<ipython-input-11-9196e0d23668>", line 73, in forward
    apool = torch.mean(o1, 1)
TypeError: mean() received an invalid combination of arguments - got (str, int), but expected one of:
 * (Tensor input, *, torch.dtype dtype)
 * (Tensor input, tuple of names dim, bool keepdim, *, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of ints dim, bool keepdim, *, torch.dtype dtype, Tensor out)

I am trying to run this on a Kaggle TPU. How can I fix this?

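Since one of the 3.x releases of transformers, models return task-specific output objects (which behave like dictionaries) instead of plain tuples. Unpacking such an object with o1, _ = self.bert(...) iterates over its keys, so o1 ends up being a string rather than a tensor, which is why mean() reports (str, int). You can force the model to return a tuple by specifying return_dict=False:
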
o1, _ = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids,
            return_dict=False)

Or use the returned BaseModelOutputWithPoolingAndCrossAttentions object directly:

o = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids)
# you can list the other available fields with o.keys()
o1 = o.last_hidden_state