FastGAN - RuntimeError: Error(s) in loading state_dict for Generator

FastGAN - RuntimeError: Error(s) in loading state_dict for Generator

我在 Google Colab 上运行 FastGAN（https://github.com/odegeasslbc/FastGAN-pytorch），现在正尝试从训练时保存的 .pth 检查点恢复训练。但是，它总是抛出以下错误：

Traceback (most recent call last):
  File "train.py", line 202, in <module>
    train(args)
  File "train.py", line 117, in train
    netG.load_state_dict(ckpt['g'])
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1407, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict: "init.init.0.weight_orig", "init.init.0.weight", "init.init.0.weight_u", "init.init.0.weight_orig", "init.init.0.weight_u", "init.init.0.weight_v", "init.init.1.weight", "init.init.1.bias", "init.init.1.running_mean", "init.init.1.running_var", "feat_8.1.weight_orig", "feat_8.1.weight", "feat_8.1.weight_u", "feat_8.1.weight_orig", "feat_8.1.weight_u", "feat_8.1.weight_v", "feat_8.2.weight", "feat_8.3.weight", "feat_8.3.bias", "feat_8.3.running_mean", "feat_8.3.running_var", "feat_8.5.weight_orig", "feat_8.5.weight", "feat_8.5.weight_u", "feat_8.5.weight_orig", "feat_8.5.weight_u", "feat_8.5.weight_v", "feat_8.6.weight", "feat_8.7.weight", "feat_8.7.bias", "feat_8.7.running_mean", "feat_8.7.running_var", "feat_16.1.weight_orig", "feat_16.1.weight", "feat_16.1.weight_u", "feat_16.1.weight_orig", "feat_16.1.weight_u", "feat_16.1.weight_v", "feat_16.2.weight", "feat_16.2.bias", "feat_16.2.running_mean", "feat_16.2.running_var", "feat_32.1.weight_orig", "feat_32.1.weight", "feat_32.1.weight_u", "feat_32.1.weight_orig", "feat_32.1.weight_u", "feat_32.1.weight_v", "feat_32.2.weight", "feat_32.3.weight", "feat_32.3.bias", "feat_32.3.running_mean", "feat_32.3.running_var", "feat_32.5.weight_orig", "feat_32.5.weight", "feat_32.5.weight_u", "feat_32.5.weight_orig", "feat_32.5.weight_u", "feat_32.5.weight_v", "feat_32.6.weight", "feat_32.7.weight", "feat_32.7.bias", "feat_32.7.running_mean", "feat_32.7.running_var", "feat_64.1.weight_orig", "feat_64.1.weight", "feat_64.1.weight_u", "feat_64.1.weight_orig", "feat_64.1.weight_u", "feat_64.1.weight_v", "feat_64.2.weight", "feat_64.2.bias", "feat_64.2.running_mean", "feat_64.2.running_var", "feat_128.1.weight_orig", "feat_128.1.weight", "feat_128.1.weight_u", "feat_128.1.weight_orig", "feat_128.1.weight_u", "feat_128.1.weight_v", "feat_128.2.weight", "feat_128.3.weight", "feat_128.3.bias", "feat_128.3.running_mean", "feat_128.3.running_var", "feat_128.5.weight_orig", "feat_128.5.weight", 
"feat_128.5.weight_u", "feat_128.5.weight_orig", "feat_128.5.weight_u", "feat_128.5.weight_v", "feat_128.6.weight", "feat_128.7.weight", "feat_128.7.bias", "feat_128.7.running_mean", "feat_128.7.running_var", "feat_256.1.weight_orig", "feat_256.1.weight", "feat_256.1.weight_u", "feat_256.1.weight_orig", "feat_256.1.weight_u", "feat_256.1.weight_v", "feat_256.2.weight", "feat_256.2.bias", "feat_256.2.running_mean", "feat_256.2.running_var", "se_64.main.1.weight_orig", "se_64.main.1.weight", "se_64.main.1.weight_u", "se_64.main.1.weight_orig", "se_64.main.1.weight_u", "se_64.main.1.weight_v", "se_64.main.3.weight_orig", "se_64.main.3.weight", "se_64.main.3.weight_u", "se_64.main.3.weight_orig", "se_64.main.3.weight_u", "se_64.main.3.weight_v", "se_128.main.1.weight_orig", "se_128.main.1.weight", "se_128.main.1.weight_u", "se_128.main.1.weight_orig", "se_128.main.1.weight_u", "se_128.main.1.weight_v", "se_128.main.3.weight_orig", "se_128.main.3.weight", "se_128.main.3.weight_u", "se_128.main.3.weight_orig", "se_128.main.3.weight_u", "se_128.main.3.weight_v", "se_256.main.1.weight_orig", "se_256.main.1.weight", "se_256.main.1.weight_u", "se_256.main.1.weight_orig", "se_256.main.1.weight_u", "se_256.main.1.weight_v", "se_256.main.3.weight_orig", "se_256.main.3.weight", "se_256.main.3.weight_u", "se_256.main.3.weight_orig", "se_256.main.3.weight_u", "se_256.main.3.weight_v", "to_128.weight_orig", "to_128.weight", "to_128.weight_u", "to_128.weight_orig", "to_128.weight_u", "to_128.weight_v", "to_big.weight_orig", "to_big.weight", "to_big.weight_u", "to_big.weight_orig", "to_big.weight_u", "to_big.weight_v", "feat_512.1.weight_orig", "feat_512.1.weight", "feat_512.1.weight_u", "feat_512.1.weight_orig", "feat_512.1.weight_u", "feat_512.1.weight_v", "feat_512.2.weight", "feat_512.3.weight", "feat_512.3.bias", "feat_512.3.running_mean", "feat_512.3.running_var", "feat_512.5.weight_orig", "feat_512.5.weight", "feat_512.5.weight_u", "feat_512.5.weight_orig", 
"feat_512.5.weight_u", "feat_512.5.weight_v", "feat_512.6.weight", "feat_512.7.weight", "feat_512.7.bias", "feat_512.7.running_mean", "feat_512.7.running_var", "se_512.main.1.weight_orig", "se_512.main.1.weight", "se_512.main.1.weight_u", "se_512.main.1.weight_orig", "se_512.main.1.weight_u", "se_512.main.1.weight_v", "se_512.main.3.weight_orig", "se_512.main.3.weight", "se_512.main.3.weight_u", "se_512.main.3.weight_orig", "se_512.main.3.weight_u", "se_512.main.3.weight_v", "feat_1024.1.weight_orig", "feat_1024.1.weight", "feat_1024.1.weight_u", "feat_1024.1.weight_orig", "feat_1024.1.weight_u", "feat_1024.1.weight_v", "feat_1024.2.weight", "feat_1024.2.bias", "feat_1024.2.running_mean", "feat_1024.2.running_var". 
    Unexpected key(s) in state_dict: "module.init.init.0.weight_orig", "module.init.init.0.weight_u", "module.init.init.0.weight_v", "module.init.init.1.weight", "module.init.init.1.bias", "module.init.init.1.running_mean", "module.init.init.1.running_var", "module.init.init.1.num_batches_tracked", "module.feat_8.1.weight_orig", "module.feat_8.1.weight_u", "module.feat_8.1.weight_v", "module.feat_8.2.weight", "module.feat_8.3.weight", "module.feat_8.3.bias", "module.feat_8.3.running_mean", "module.feat_8.3.running_var", "module.feat_8.3.num_batches_tracked", "module.feat_8.5.weight_orig", "module.feat_8.5.weight_u", "module.feat_8.5.weight_v", "module.feat_8.6.weight", "module.feat_8.7.weight", "module.feat_8.7.bias", "module.feat_8.7.running_mean", "module.feat_8.7.running_var", "module.feat_8.7.num_batches_tracked", "module.feat_16.1.weight_orig", "module.feat_16.1.weight_u", "module.feat_16.1.weight_v", "module.feat_16.2.weight", "module.feat_16.2.bias", "module.feat_16.2.running_mean", "module.feat_16.2.running_var", "module.feat_16.2.num_batches_tracked", "module.feat_32.1.weight_orig", "module.feat_32.1.weight_u", "module.feat_32.1.weight_v", "module.feat_32.2.weight", "module.feat_32.3.weight", "module.feat_32.3.bias", "module.feat_32.3.running_mean", "module.feat_32.3.running_var", "module.feat_32.3.num_batches_tracked", "module.feat_32.5.weight_orig", "module.feat_32.5.weight_u", "module.feat_32.5.weight_v", "module.feat_32.6.weight", "module.feat_32.7.weight", "module.feat_32.7.bias", "module.feat_32.7.running_mean", "module.feat_32.7.running_var", "module.feat_32.7.num_batches_tracked", "module.feat_64.1.weight_orig", "module.feat_64.1.weight_u", "module.feat_64.1.weight_v", "module.feat_64.2.weight", "module.feat_64.2.bias", "module.feat_64.2.running_mean", "module.feat_64.2.running_var", "module.feat_64.2.num_batches_tracked", "module.feat_128.1.weight_orig", "module.feat_128.1.weight_u", "module.feat_128.1.weight_v", "module.feat_128.2.weight", 
"module.feat_128.3.weight", "module.feat_128.3.bias", "module.feat_128.3.running_mean", "module.feat_128.3.running_var", "module.feat_128.3.num_batches_tracked", "module.feat_128.5.weight_orig", "module.feat_128.5.weight_u", "module.feat_128.5.weight_v", "module.feat_128.6.weight", "module.feat_128.7.weight", "module.feat_128.7.bias", "module.feat_128.7.running_mean", "module.feat_128.7.running_var", "module.feat_128.7.num_batches_tracked", "module.feat_256.1.weight_orig", "module.feat_256.1.weight_u", "module.feat_256.1.weight_v", "module.feat_256.2.weight", "module.feat_256.2.bias", "module.feat_256.2.running_mean", "module.feat_256.2.running_var", "module.feat_256.2.num_batches_tracked", "module.se_64.main.1.weight_orig", "module.se_64.main.1.weight_u", "module.se_64.main.1.weight_v", "module.se_64.main.3.weight_orig", "module.se_64.main.3.weight_u", "module.se_64.main.3.weight_v", "module.se_128.main.1.weight_orig", "module.se_128.main.1.weight_u", "module.se_128.main.1.weight_v", "module.se_128.main.3.weight_orig", "module.se_128.main.3.weight_u", "module.se_128.main.3.weight_v", "module.se_256.main.1.weight_orig", "module.se_256.main.1.weight_u", "module.se_256.main.1.weight_v", "module.se_256.main.3.weight_orig", "module.se_256.main.3.weight_u", "module.se_256.main.3.weight_v", "module.to_128.weight_orig", "module.to_128.weight_u", "module.to_128.weight_v", "module.to_big.weight_orig", "module.to_big.weight_u", "module.to_big.weight_v", "module.feat_512.1.weight_orig", "module.feat_512.1.weight_u", "module.feat_512.1.weight_v", "module.feat_512.2.weight", "module.feat_512.3.weight", "module.feat_512.3.bias", "module.feat_512.3.running_mean", "module.feat_512.3.running_var", "module.feat_512.3.num_batches_tracked", "module.feat_512.5.weight_orig", "module.feat_512.5.weight_u", "module.feat_512.5.weight_v", "module.feat_512.6.weight", "module.feat_512.7.weight", "module.feat_512.7.bias", "module.feat_512.7.running_mean", "module.feat_512.7.running_var", 
"module.feat_512.7.num_batches_tracked", "module.se_512.main.1.weight_orig", "module.se_512.main.1.weight_u", "module.se_512.main.1.weight_v", "module.se_512.main.3.weight_orig", "module.se_512.main.3.weight_u", "module.se_512.main.3.weight_v", "module.feat_1024.1.weight_orig", "module.feat_1024.1.weight_u", "module.feat_1024.1.weight_v", "module.feat_1024.2.weight", "module.feat_1024.2.bias", "module.feat_1024.2.running_mean", "module.feat_1024.2.running_var", "module.feat_1024.2.num_batches_tracked". 

有人知道这是怎么回事吗？

非常感谢您的帮助!

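这通常是因为保存检查点时，模型被 nn.DataParallel（或 DistributedDataParallel）包装过。包装器把原模型保存为名为 module 的子模块，所以 state_dict() 里的每个键都会带上 'module.' 前缀；而加载时的 netG 是未包装的 Generator，它期望的键没有这个前缀。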

请注意，报错中列出的 Missing keys 与 Unexpected keys 基本一一对应，区别只在前缀：检查点字典中的所有键都以 'module.' 开头。

一个快速的解决方法是在加载前去掉键开头的 'module.' 前缀（只去掉开头的前缀，不要替换键中其他位置出现的同名子串）。例如，可以使用字典推导式来完成。如果检查点中还保存了其他在 DataParallel 下得到的 state_dict（例如判别器的），也需要做同样的处理：

# A checkpoint saved from an nn.DataParallel-wrapped model stores every key
# under the wrapper's "module." submodule name. Strip only that *leading*
# prefix: str.replace would also rewrite any "module." occurring later in a
# key. (str.removeprefix would be cleaner but needs Python 3.9+; the
# traceback above comes from a Python 3.7 environment.)
prefix = 'module.'
loaded_state = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in ckpt['g'].items()
}
netG.load_state_dict(loaded_state)