Unable to load PyTorch model for evaluation

I saved a .pth model and I am trying to load it for inference with the following code:

import torch

model = GatherModel()  # GatherModel is imported from the project's model code (import not shown)
model.load_state_dict(torch.load('/content/CIGIN/weights/cigin.tar'))

I get the error shown below. Why am I getting this?

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-3bff0e426886> in <module>()
----> 1 model.load_state_dict(torch.load('/content/CIGIN/weights/cigin.tar'))

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1405         if len(error_msgs) > 0:
   1406             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1407                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1408         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1409 

RuntimeError: Error(s) in loading state_dict for GatherModel:
    Missing key(s) in state_dict: "lin0.weight", "lin0.bias", "set2set.lstm.weight_ih_l0", "set2set.lstm.weight_hh_l0", "set2set.lstm.bias_ih_l0", "set2set.lstm.bias_hh_l0", "message_layer.weight", "message_layer.bias", "conv.bias", "conv.edge_func.0.weight", "conv.edge_func.0.bias", "conv.edge_func.2.weight", "conv.edge_func.2.bias". 
    Unexpected key(s) in state_dict: "solute_pass.U_0.weight", "solute_pass.U_0.bias", "solute_pass.U_1.weight", "solute_pass.U_1.bias", "solute_pass.U_2.weight", "solute_pass.U_2.bias", "solute_pass.M_0.weight", "solute_pass.M_0.bias", "solute_pass.M_1.weight", "solute_pass.M_1.bias", "solute_pass.M_2.weight", "solute_pass.M_2.bias", "solvent_pass.U_0.weight", "solvent_pass.U_0.bias", "solvent_pass.U_1.weight", "solvent_pass.U_1.bias", "solvent_pass.U_2.weight", "solvent_pass.U_2.bias", "solvent_pass.M_0.weight", "solvent_pass.M_0.bias", "solvent_pass.M_1.weight", "solvent_pass.M_1.bias", "solvent_pass.M_2.weight", "solvent_pass.M_2.bias", "lstm_solute.weight_ih_l0", "lstm_solute.weight_hh_l0", "lstm_solute.bias_ih_l0", "lstm_solute.bias_hh_l0", "lstm_solvent.weight_ih_l0", "lstm_solvent.weight_hh_l0", "lstm_solvent.bias_ih_l0", "lstm_solvent.bias_hh_l0", "lstm_gather_solute.weight_ih_l0", "lstm_gather_solute.weight_hh_l0", "lstm_gather_solute.bias_ih_l0", "lstm_gather_solute.bias_hh_l0", "lstm_gather_solvent.weight_ih_l0", "lstm_gather_solvent.weight_hh_l0", "lstm_gather_solvent.bias_ih_l0", "lstm_gather_solvent.bias_hh_l0", "first_layer.weight", "first_layer.bias", "second_layer.weight", "second_layer.bias", "third_layer.weight", "third_layer.bias", "fourth_layer.weight", "fourth_layer.bias". 

I tried passing strict=False to load_state_dict, but I get this output instead:

_IncompatibleKeys(missing_keys=['lin0.weight', 'lin0.bias', 'set2set.lstm.weight_ih_l0', 'set2set.lstm.weight_hh_l0', 'set2set.lstm.bias_ih_l0', 'set2set.lstm.bias_hh_l0', 'message_layer.weight', 'message_layer.bias', 'conv.bias', 'conv.edge_func.0.weight', 'conv.edge_func.0.bias', 'conv.edge_func.2.weight', 'conv.edge_func.2.bias'], unexpected_keys=['solute_pass.U_0.weight', 'solute_pass.U_0.bias', 'solute_pass.U_1.weight', 'solute_pass.U_1.bias', 'solute_pass.U_2.weight', 'solute_pass.U_2.bias', 'solute_pass.M_0.weight', 'solute_pass.M_0.bias', 'solute_pass.M_1.weight', 'solute_pass.M_1.bias', 'solute_pass.M_2.weight', 'solute_pass.M_2.bias', 'solvent_pass.U_0.weight', 'solvent_pass.U_0.bias', 'solvent_pass.U_1.weight', 'solvent_pass.U_1.bias', 'solvent_pass.U_2.weight', 'solvent_pass.U_2.bias', 'solvent_pass.M_0.weight', 'solvent_pass.M_0.bias', 'solvent_pass.M_1.weight', 'solvent_pass.M_1.bias', 'solvent_pass.M_2.weight', 'solvent_pass.M_2.bias', 'lstm_solute.weight_ih_l0', 'lstm_solute.weight_hh_l0', 'lstm_solute.bias_ih_l0', 'lstm_solute.bias_hh_l0', 'lstm_solvent.weight_ih_l0', 'lstm_solvent.weight_hh_l0', 'lstm_solvent.bias_ih_l0', 'lstm_solvent.bias_hh_l0', 'lstm_gather_solute.weight_ih_l0', 'lstm_gather_solute.weight_hh_l0', 'lstm_gather_solute.bias_ih_l0', 'lstm_gather_solute.bias_hh_l0', 'lstm_gather_solvent.weight_ih_l0', 'lstm_gather_solvent.weight_hh_l0', 'lstm_gather_solvent.bias_ih_l0', 'lstm_gather_solvent.bias_hh_l0', 'first_layer.weight', 'first_layer.bias', 'second_layer.weight', 'second_layer.bias', 'third_layer.weight', 'third_layer.bias', 'fourth_layer.weight', 'fourth_layer.bias'])

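The error is basically saying that the architecture you are using defines weights that are not in the state_dict, and that the state_dict contains weights that the architecture does not define. Are you sure the architecture defined by GatherModel() is the same one that originally created the state_dict? This error says it is not. The unexpected keys (solute_pass.*, solvent_pass.*, lstm_solute.*, lstm_gather_solute.*, first_layer.*, ...) suggest the checkpoint was saved from a different model, one with separate solute and solvent branches. strict=False does not fix this: it only skips the mismatched keys instead of raising, so every parameter listed under missing_keys keeps its random initialization.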
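To check which architecture a checkpoint belongs to, it can help to compare the key names grouped by their top-level module. The helper below, `compare_state_dict_keys`, is a hypothetical pure-Python sketch (not part of PyTorch); it assumes you pass it `model.state_dict().keys()` and the keys of the object returned by `torch.load(path, map_location="cpu")`. The demo uses a few key names copied from the error message above.

```python
from collections import defaultdict


def compare_state_dict_keys(model_keys, checkpoint_keys):
    """Hypothetical helper: diff the keys a model expects against a checkpoint.

    Returns (missing, unexpected, shared):
      missing    -> {top-level module: [keys the model has but the checkpoint lacks]}
      unexpected -> {top-level module: [keys the checkpoint has but the model lacks]}
      shared     -> sorted list of keys present in both
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)

    def group_by_prefix(keys):
        # Group "solute_pass.U_0.weight" under "solute_pass", etc.
        groups = defaultdict(list)
        for key in sorted(keys):
            groups[key.split(".", 1)[0]].append(key)
        return dict(groups)

    missing = group_by_prefix(model_keys - checkpoint_keys)
    unexpected = group_by_prefix(checkpoint_keys - model_keys)
    shared = sorted(model_keys & checkpoint_keys)
    return missing, unexpected, shared


if __name__ == "__main__":
    # A few names from the error above; in practice use model.state_dict().keys()
    # and torch.load(path, map_location="cpu").keys().
    model_keys = ["lin0.weight", "lin0.bias", "set2set.lstm.weight_ih_l0", "conv.bias"]
    checkpoint_keys = ["solute_pass.U_0.weight", "lstm_solute.weight_ih_l0", "first_layer.weight"]
    missing, unexpected, shared = compare_state_dict_keys(model_keys, checkpoint_keys)
    print("missing modules:", sorted(missing))        # modules the checkpoint does not cover
    print("unexpected modules:", sorted(unexpected))  # modules the model does not define
    print("shared keys:", len(shared))                # 0 means nothing would be loaded
```

If `shared` is empty, loading with strict=False copies no weights at all, and the unexpected module names point to the class that should be instantiated instead.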