Type errors with BERT example

I am new to BERT QA models and I am trying to follow the example in this article. The problem is that when I run the code attached to the example, it produces a type error: TypeError: argmax(): argument 'input' (position 1) must be Tensor, not str.

Here is the code I tried to run:

import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

#Model
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

#Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

question = '''SAMPLE QUESTION'''

paragraph = '''SAMPLE PARAGRAPH'''
            
encoding = tokenizer.encode_plus(text=question, text_pair=paragraph, add_special_tokens=True)

inputs = encoding['input_ids']  #Token embeddings
sentence_embedding = encoding['token_type_ids']  #Segment embeddings
tokens = tokenizer.convert_ids_to_tokens(inputs) #input tokens

start_scores, end_scores = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

start_index = torch.argmax(start_scores)

end_index = torch.argmax(end_scores)

answer = ' '.join(tokens[start_index:end_index+1])

The problem occurs on line 13 of this code (the torch.argmax(start_scores) call), where I try to get the maximum element of start_scores: the error says the argument is not a tensor. When I print that variable, it shows the string 'start_logits'. Does anyone know a solution to this problem?

So, after consulting the BERT Documentation, we determined that the model returns an output object holding several attributes, not just a tuple of start and end scores. Tuple-unpacking that object assigns its attribute names (the strings 'start_logits' and 'end_logits') rather than the tensors themselves, which is exactly what torch.argmax complains about. A quick check of the return value is shown below, followed by the changes we applied to the code.
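
As an illustration (out is just a throwaway name here; inputs and sentence_embedding are the variables built in the question's code), inspecting the return value shows that it is not a tuple at all. In recent transformers versions it is a QuestionAnsweringModelOutput, a dictionary-like container of tensors:

out = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

print(type(out))         # e.g. transformers.modeling_outputs.QuestionAnsweringModelOutput
print(list(out.keys()))  # ['start_logits', 'end_logits']

# Tuple-unpacking this object iterates over its keys, which is why the original
# start_scores, end_scores = model(...) assigned strings instead of tensors.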


outputs = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

start_index = torch.argmax(outputs.start_logits)  # most likely start-token position

end_index = torch.argmax(outputs.end_logits)  # most likely end-token position

answer = ' '.join(tokens[start_index:end_index + 1])
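
For completeness, another way to keep the original tuple-style unpacking is to ask the model for a plain tuple with return_dict=False (a parameter current transformers versions accept). This is only a sketch of that variant, reusing the same inputs, sentence_embedding and tokens variables from above:

# Request a plain (start_logits, end_logits) tuple instead of an output object.
start_scores, end_scores = model(
    input_ids=torch.tensor([inputs]),
    token_type_ids=torch.tensor([sentence_embedding]),
    return_dict=False,
)

start_index = torch.argmax(start_scores)
end_index = torch.argmax(end_scores)
answer = ' '.join(tokens[start_index:end_index + 1])

Both variants select the same answer span; accessing outputs.start_logits and outputs.end_logits is simply the more explicit option.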

Always check the documentation first :D