Type errors with BERT example
I'm new to BERT QA models and I'm trying to follow the example in this article. The problem is that when I run the code attached to the example, it raises a type error: TypeError: argmax(): argument 'input' (position 1) must be Tensor, not str.
Here is the code I tried to run:
import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
#Model
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
#Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
question = '''SAMPLE QUESTION'''
paragraph = '''SAMPLE PARAGRAPH'''
encoding = tokenizer.encode_plus(text=question, text_pair=paragraph, add_special_tokens=True)
inputs = encoding['input_ids'] #Token embeddings
sentence_embedding = encoding['token_type_ids'] #Segment embeddings
tokens = tokenizer.convert_ids_to_tokens(inputs) #input tokens
start_scores, end_scores = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))
start_index = torch.argmax(start_scores)
end_index = torch.argmax(end_scores)
answer = ' '.join(tokens[start_index:end_index+1])
The problem occurs on line 13 of this code, where I try to get the index of the largest element in start_scores: the error says it is not a tensor. When I print the variable, it shows the string "start_logits". Does anyone know a solution to this?
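The error comes from how the model output is unpacked: in recent transformers versions the model returns a dict-like output object, and tuple-unpacking a dict-like object in Python yields its keys (strings), not its values. A minimal stdlib sketch of that behaviour, with a plain dict standing in for the real output object:

```python
# A plain dict standing in for the model's dict-like output object.
outputs = {"start_logits": [0.1, 2.0], "end_logits": [0.2, 3.0]}

# Tuple-unpacking a mapping iterates over its *keys*, so both
# variables end up holding strings rather than the logit values.
start_scores, end_scores = outputs
print(start_scores, end_scores)  # start_logits end_logits
```

This is exactly why torch.argmax then complains it received a str instead of a Tensor.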
After consulting the BERT Documentation, we determined that the model output object contains several attributes, not just the start and end scores, so we applied the following changes to the code:
outputs = model(input_ids=torch.tensor([inputs]),token_type_ids=torch.tensor([sentence_embedding]))
start_index = torch.argmax(outputs.start_logits)
end_index = torch.argmax(outputs.end_logits)
answer = ' '.join(tokens[start_index:end_index+1])
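For reference, the span extraction at the end reduces to taking an argmax over each logit row and slicing the token list between those two indices. A pure-Python sketch of that indexing logic, using made-up toy logit values (not real model output):

```python
# Toy stand-ins for outputs.start_logits[0] and outputs.end_logits[0]
# (hypothetical values for illustration only).
start_logits = [0.1, 0.2, 0.3, 5.0, 0.1]
end_logits = [0.1, 0.2, 0.3, 4.0, 0.1]
tokens = ['[CLS]', 'who', '?', 'bert', '[SEP]']

def argmax(xs):
    # Index of the largest element (what torch.argmax returns as a 0-dim tensor).
    return max(range(len(xs)), key=xs.__getitem__)

start_index = argmax(start_logits)  # 3
end_index = argmax(end_logits)      # 3
answer = ' '.join(tokens[start_index:end_index + 1])
print(answer)  # bert
```

The inclusive slice `tokens[start_index:end_index + 1]` is why the answer can span multiple tokens when the start and end indices differ.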
Always check the documentation first :D