unable to load pytorch model for evaluation
I saved a .pth model and am trying to load it for inference with the following code:
import torch

model = GatherModel()
model.load_state_dict(torch.load('/content/CIGIN/weights/cigin.tar'))
I get the error shown below. Why am I getting it?
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-8-3bff0e426886> in <module>()
----> 1 model.load_state_dict(torch.load('/content/CIGIN/weights/cigin.tar'))
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1405 if len(error_msgs) > 0:
1406 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1407 self.__class__.__name__, "\n\t".join(error_msgs)))
1408 return _IncompatibleKeys(missing_keys, unexpected_keys)
1409
RuntimeError: Error(s) in loading state_dict for GatherModel:
Missing key(s) in state_dict: "lin0.weight", "lin0.bias", "set2set.lstm.weight_ih_l0", "set2set.lstm.weight_hh_l0", "set2set.lstm.bias_ih_l0", "set2set.lstm.bias_hh_l0", "message_layer.weight", "message_layer.bias", "conv.bias", "conv.edge_func.0.weight", "conv.edge_func.0.bias", "conv.edge_func.2.weight", "conv.edge_func.2.bias".
Unexpected key(s) in state_dict: "solute_pass.U_0.weight", "solute_pass.U_0.bias", "solute_pass.U_1.weight", "solute_pass.U_1.bias", "solute_pass.U_2.weight", "solute_pass.U_2.bias", "solute_pass.M_0.weight", "solute_pass.M_0.bias", "solute_pass.M_1.weight", "solute_pass.M_1.bias", "solute_pass.M_2.weight", "solute_pass.M_2.bias", "solvent_pass.U_0.weight", "solvent_pass.U_0.bias", "solvent_pass.U_1.weight", "solvent_pass.U_1.bias", "solvent_pass.U_2.weight", "solvent_pass.U_2.bias", "solvent_pass.M_0.weight", "solvent_pass.M_0.bias", "solvent_pass.M_1.weight", "solvent_pass.M_1.bias", "solvent_pass.M_2.weight", "solvent_pass.M_2.bias", "lstm_solute.weight_ih_l0", "lstm_solute.weight_hh_l0", "lstm_solute.bias_ih_l0", "lstm_solute.bias_hh_l0", "lstm_solvent.weight_ih_l0", "lstm_solvent.weight_hh_l0", "lstm_solvent.bias_ih_l0", "lstm_solvent.bias_hh_l0", "lstm_gather_solute.weight_ih_l0", "lstm_gather_solute.weight_hh_l0", "lstm_gather_solute.bias_ih_l0", "lstm_gather_solute.bias_hh_l0", "lstm_gather_solvent.weight_ih_l0", "lstm_gather_solvent.weight_hh_l0", "lstm_gather_solvent.bias_ih_l0", "lstm_gather_solvent.bias_hh_l0", "first_layer.weight", "first_layer.bias", "second_layer.weight", "second_layer.bias", "third_layer.weight", "third_layer.bias", "fourth_layer.weight", "fourth_layer.bias".
I have tried passing strict=False to load_state_dict, but then I get this instead:
_IncompatibleKeys(missing_keys=['lin0.weight', 'lin0.bias', 'set2set.lstm.weight_ih_l0', 'set2set.lstm.weight_hh_l0', 'set2set.lstm.bias_ih_l0', 'set2set.lstm.bias_hh_l0', 'message_layer.weight', 'message_layer.bias', 'conv.bias', 'conv.edge_func.0.weight', 'conv.edge_func.0.bias', 'conv.edge_func.2.weight', 'conv.edge_func.2.bias'], unexpected_keys=['solute_pass.U_0.weight', 'solute_pass.U_0.bias', 'solute_pass.U_1.weight', 'solute_pass.U_1.bias', 'solute_pass.U_2.weight', 'solute_pass.U_2.bias', 'solute_pass.M_0.weight', 'solute_pass.M_0.bias', 'solute_pass.M_1.weight', 'solute_pass.M_1.bias', 'solute_pass.M_2.weight', 'solute_pass.M_2.bias', 'solvent_pass.U_0.weight', 'solvent_pass.U_0.bias', 'solvent_pass.U_1.weight', 'solvent_pass.U_1.bias', 'solvent_pass.U_2.weight', 'solvent_pass.U_2.bias', 'solvent_pass.M_0.weight', 'solvent_pass.M_0.bias', 'solvent_pass.M_1.weight', 'solvent_pass.M_1.bias', 'solvent_pass.M_2.weight', 'solvent_pass.M_2.bias', 'lstm_solute.weight_ih_l0', 'lstm_solute.weight_hh_l0', 'lstm_solute.bias_ih_l0', 'lstm_solute.bias_hh_l0', 'lstm_solvent.weight_ih_l0', 'lstm_solvent.weight_hh_l0', 'lstm_solvent.bias_ih_l0', 'lstm_solvent.bias_hh_l0', 'lstm_gather_solute.weight_ih_l0', 'lstm_gather_solute.weight_hh_l0', 'lstm_gather_solute.bias_ih_l0', 'lstm_gather_solute.bias_hh_l0', 'lstm_gather_solvent.weight_ih_l0', 'lstm_gather_solvent.weight_hh_l0', 'lstm_gather_solvent.bias_ih_l0', 'lstm_gather_solvent.bias_hh_l0', 'first_layer.weight', 'first_layer.bias', 'second_layer.weight', 'second_layer.bias', 'third_layer.weight', 'third_layer.bias', 'fourth_layer.weight', 'fourth_layer.bias'])
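The error is saying two things. First, the architecture you are instantiating defines weights that are not in the state_dict (the "missing" keys, under lin0, set2set, message_layer and conv). Second, the state_dict contains weights that this architecture does not define (the "unexpected" keys, under solute_pass, solvent_pass, lstm_solute, lstm_solvent, lstm_gather_solute, lstm_gather_solvent and first_layer through fourth_layer). Are you sure the architecture defined by GatherModel() is the same one that originally created the state_dict? This error indicates that it is not.

Passing strict=False does not fix the mismatch. It only stops load_state_dict from raising and returns the mismatched keys as _IncompatibleKeys, so every parameter listed under missing_keys keeps its initial, untrained value.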
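One way to check this yourself is to diff the two key sets directly. The sketch below uses two hypothetical helpers (diff_state_dict_keys and top_level_prefixes are names made up for illustration, not part of PyTorch). It reproduces the missing/unexpected split that load_state_dict reports and collapses each parameter name to its top-level submodule, which makes an architecture mismatch like this one easy to see.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, mirroring load_state_dict:
    missing    = expected by the model but absent from the checkpoint,
    unexpected = present in the checkpoint but not defined by the model."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


def top_level_prefixes(keys):
    """Collapse names like 'solute_pass.U_0.weight' to the top-level
    submodule name ('solute_pass'), deduplicated and sorted."""
    return sorted({key.split(".", 1)[0] for key in keys})


if __name__ == "__main__":
    # A few key names copied from the error message in the question.
    model_keys = ["lin0.weight", "lin0.bias", "conv.bias"]
    checkpoint_keys = [
        "solute_pass.U_0.weight",
        "lstm_solute.bias_ih_l0",
        "first_layer.weight",
    ]
    missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
    print("only in model:     ", top_level_prefixes(missing))
    print("only in checkpoint:", top_level_prefixes(unexpected))
```

With the real objects you would pass model.state_dict().keys() and the keys of the dict returned by torch.load(...). If the two sides share no top-level submodule names at all, as in this error, that suggests the checkpoint was saved from a different model class, and loading it would need an instance of that class rather than GatherModel.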