手写文字(图片)预训练模型预测-Pytorch

Prediction for pretrained model on handwritten text(images)-Pytorch

我在使用预训练模型进行预测时遇到问题,该模型包含用于手写文本识别的编码器和解码器。 我所做的是:

checkpoint = torch.load("Model/SPAN/SPAN-PT-RA_rimes.pt",map_location=torch.device('cpu'))
encoder_state_dict = checkpoint['encoder_state_dict']
decoder_state_dict = checkpoint['decoder_state_dict']

img = torch.LongTensor(img).unsqueeze(1).to(torch.device('cpu'))
global_pred = decoder_state_dict(encoder_state_dict(img))

这会生成此错误:

TypeError: 'collections.OrderedDict' object is not callable

非常感谢您的帮助! ^_^

encoder_state_dict decoder_state_dict 不是火炬模型,而是张量的集合(字典),其中包含您加载的检查点的预训练参数。

将输入(例如您转换后的输入图像)提供给这样的张量集合没有意义。事实上,您应该使用这些 stat_dicts(即预训练张量的集合)将它们加载到映射到网络的模型对象的参数中。参见 torch.nn.Module class。