TypeError in torch.argmax() when trying to find the tokens with the highest `start` score


import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

question = '''Why was the student group called "the Methodists?"'''

paragraph = ''' The movement which would become The United Methodist Church began in the mid-18th century within the Church of England.
            A small group of students, including John Wesley, Charles Wesley and George Whitefield, met on the Oxford University campus.
            They focused on Bible study, methodical study of scripture and living a holy life.
            Other students mocked them, saying they were the "Holy Club" and "the Methodists", being methodical and exceptionally detailed in their Bible study, opinions and disciplined lifestyle.
            Eventually, the so-called Methodists started individual societies or classes for members of the Church of England who wanted to live a more religious life. '''
encoding = tokenizer.encode_plus(text=question,text_pair=paragraph)

inputs = encoding['input_ids']  #Token embeddings
sentence_embedding = encoding['token_type_ids']  #Segment embeddings
tokens = tokenizer.convert_ids_to_tokens(inputs) #input tokens

start_scores, end_scores = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

start_index = torch.argmax(start_scores)


Exception has occurred: TypeError
argmax(): argument 'input' (position 1) must be Tensor, not str
  File "D:\bert\", line 33, in <module>
    start_index = torch.argmax(start_scores)


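BertForQuestionAnswering returns a QuestionAnsweringModelOutput object.

Because you unpack the output of BertForQuestionAnswering into start_scores, end_scores, Python iterates over the returned QuestionAnsweringModelOutput object. That object behaves like a dictionary, so iterating over it yields its keys, the strings 'start_logits' and 'end_logits', rather than the tensors. start_scores therefore holds the string 'start_logits', and passing it to torch.argmax() raises the TypeError.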
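The unpacking behaviour can be reproduced without loading the model. The sketch below uses a plain `OrderedDict` as a stand-in for `QuestionAnsweringModelOutput` (an assumption made for illustration; the real class also offers attribute access), and the lists are placeholder values rather than real model scores.

```python
from collections import OrderedDict

# Stand-in for QuestionAnsweringModelOutput: a dict-like object with
# 'start_logits' and 'end_logits' keys. The lists are placeholders,
# not real model scores.
fake_output = OrderedDict(
    start_logits=[0.1, 2.5, 0.3],
    end_logits=[0.2, 0.1, 3.0],
)

# Tuple unpacking iterates over the object, and iterating a dict yields keys.
start_scores, end_scores = fake_output

print(type(start_scores).__name__, start_scores)
print(type(end_scores).__name__, end_scores)

# Looking the values up by key returns the actual scores.
start_logits = fake_output["start_logits"]
```

Both unpacked names end up bound to strings, which is the situation `torch.argmax()` complains about in the traceback.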


outputs = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

start_index = torch.argmax(outputs.start_logits)
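The end of the answer span can be found the same way from `outputs.end_logits`, after which the tokens between the two indices can be joined into text. The helper below is a minimal sketch, not part of transformers: `tokens_to_answer` is a hypothetical name, the example tokens are hand-written rather than real tokenizer output, and it assumes WordPiece continuation pieces are marked with a leading `##`.

```python
def tokens_to_answer(tokens, start_index, end_index):
    """Join WordPiece tokens from start_index to end_index (inclusive)."""
    answer = ""
    for token in tokens[start_index:end_index + 1]:
        if token.startswith("##"):
            # Continuation piece: glue it onto the previous token.
            answer += token[2:]
        elif answer:
            # New word: separate it from the previous one with a space.
            answer += " " + token
        else:
            answer = token
    return answer


# Hand-written example tokens (illustrative only).
example_tokens = ["[CLS]", "being", "method", "##ical", "[SEP]"]
example_answer = tokens_to_answer(example_tokens, 1, 3)
print(example_answer)
```

With the code from the question, the call would look like `tokens_to_answer(tokens, int(start_index), int(end_index))`, where `end_index = torch.argmax(outputs.end_logits)`.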

Hugging Face transformers also provides a simple, high-level way to run the model, as shown in this guide:

from transformers import pipeline

nlp = pipeline('question-answering', model=model, tokenizer=tokenizer)
print(nlp(question=question, context=paragraph, topk=5))

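topk lets you select several of the highest-scoring answers. In more recent versions of transformers this argument is named top_k, and topk is deprecated.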