手写文字(图片)预训练模型预测-Pytorch
Prediction for pretrained model on handwritten text(images)-Pytorch
我在使用预训练模型进行预测时遇到问题,该模型包含用于手写文本识别的编码器和解码器。
我所做的是:
checkpoint = torch.load("Model/SPAN/SPAN-PT-RA_rimes.pt",map_location=torch.device('cpu'))
encoder_state_dict = checkpoint['encoder_state_dict']
decoder_state_dict = checkpoint['decoder_state_dict']
img = torch.LongTensor(img).unsqueeze(1).to(torch.device('cpu'))
global_pred = decoder_state_dict(encoder_state_dict(img))
这会生成此错误:
TypeError: 'collections.OrderedDict' object is not callable
非常感谢您的帮助! ^_^
encoder_state_dict
和 decoder_state_dict
不是火炬模型,而是张量的集合(字典),其中包含您加载的检查点的预训练参数。
将输入(例如您转换后的输入图像)提供给这样的张量集合没有意义。事实上,您应该使用这些 stat_dicts(即预训练张量的集合)将它们加载到映射到网络的模型对象的参数中。参见 torch.nn.Module
class。
我在使用预训练模型进行预测时遇到问题,该模型包含用于手写文本识别的编码器和解码器。 我所做的是:
checkpoint = torch.load("Model/SPAN/SPAN-PT-RA_rimes.pt",map_location=torch.device('cpu'))
encoder_state_dict = checkpoint['encoder_state_dict']
decoder_state_dict = checkpoint['decoder_state_dict']
img = torch.LongTensor(img).unsqueeze(1).to(torch.device('cpu'))
global_pred = decoder_state_dict(encoder_state_dict(img))
这会生成此错误:
TypeError: 'collections.OrderedDict' object is not callable
非常感谢您的帮助! ^_^
encoder_state_dict
和 decoder_state_dict
不是火炬模型,而是张量的集合(字典),其中包含您加载的检查点的预训练参数。
将输入(例如您转换后的输入图像)提供给这样的张量集合没有意义。事实上,您应该使用这些 stat_dicts(即预训练张量的集合)将它们加载到映射到网络的模型对象的参数中。参见 torch.nn.Module
class。