Pytorch nn.Module 的派生 Class 无法通过 Python 中的模块导入加载
Derived Class of Pytorch nn.Module Cannot be Loaded by Module Import in Python
使用 Python 3.6 和 Pytorch 1.3.1。我注意到当整个模块被导入另一个模块时,一些保存的 nn.Modules 无法加载。举个例子,这是一个最小工作示例的模板。
#!/usr/bin/env python3
#encoding:utf-8
# file 'dnn_predict.py'
from torch import nn
class NN(nn.Module):##NN network
# Initialisation and other class methods
networks=[torch.load(f=os.path.join(resource_directory, 'nn-classify-cpu_{fold}.pkl'.format(fold=fold))) for fold in range(5)]
...
if __name__=='__main__':
# Some testing snippets
pass
当我直接在 shell 中 运行 时,整个文件工作正常。但是,当我想使用 class 并使用此代码将神经网络加载到另一个文件中时,它会失败。
#!/usr/bin/env python3
#encoding:utf-8
from dnn_predict import *
错误显示为 AttributeError: Can't get attribute 'NN' on <module '__main__'>
在 Pytorch 中加载保存的变量或导入模块是否与其他常见的 Python 库不同?一些帮助或指向根本原因的指针将非常感激。
当您使用 torch.save(model, PATH)
保存模型时,整个对象将使用 pickle
进行序列化,这不会保存 class 本身,而是保存包含 class,因此在加载模型时,需要完全相同的目录和文件结构才能找到正确的 class。当运行一个Python脚本时,那个文件的模块是__main__
,因此如果你想加载那个模块,你的NN
class必须定义在你的剧本 运行.
这非常不灵活,所以推荐的做法是不保存整个模型,而是只保存状态字典,它只保存模型的参数。
# Save the state dictionary of the model
torch.save(model.state_dict(), PATH)
之后,可以加载状态字典并将其应用于您的模型。
from dnn_predict import NN
# Create the model (will have randomly initialised parameters)
model = NN()
# Load the previously saved state dictionary
state_dict = torch.load(PATH)
# Apply the state dictionary to the model
model.load_state_dict(state_dict)
状态字典和 saving/loading 模型的更多详细信息:PyTorch - Saving and Loading Models
使用 Python 3.6 和 Pytorch 1.3.1。我注意到当整个模块被导入另一个模块时,一些保存的 nn.Modules 无法加载。举个例子,这是一个最小工作示例的模板。
#!/usr/bin/env python3
#encoding:utf-8
# file 'dnn_predict.py'
from torch import nn
class NN(nn.Module):##NN network
# Initialisation and other class methods
networks=[torch.load(f=os.path.join(resource_directory, 'nn-classify-cpu_{fold}.pkl'.format(fold=fold))) for fold in range(5)]
...
if __name__=='__main__':
# Some testing snippets
pass
当我直接在 shell 中 运行 时,整个文件工作正常。但是,当我想使用 class 并使用此代码将神经网络加载到另一个文件中时,它会失败。
#!/usr/bin/env python3
#encoding:utf-8
from dnn_predict import *
错误显示为 AttributeError: Can't get attribute 'NN' on <module '__main__'>
在 Pytorch 中加载保存的变量或导入模块是否与其他常见的 Python 库不同?一些帮助或指向根本原因的指针将非常感激。
当您使用 torch.save(model, PATH)
保存模型时,整个对象将使用 pickle
进行序列化,这不会保存 class 本身,而是保存包含 class,因此在加载模型时,需要完全相同的目录和文件结构才能找到正确的 class。当运行一个Python脚本时,那个文件的模块是__main__
,因此如果你想加载那个模块,你的NN
class必须定义在你的剧本 运行.
这非常不灵活,所以推荐的做法是不保存整个模型,而是只保存状态字典,它只保存模型的参数。
# Save the state dictionary of the model
torch.save(model.state_dict(), PATH)
之后,可以加载状态字典并将其应用于您的模型。
from dnn_predict import NN
# Create the model (will have randomly initialised parameters)
model = NN()
# Load the previously saved state dictionary
state_dict = torch.load(PATH)
# Apply the state dictionary to the model
model.load_state_dict(state_dict)
状态字典和 saving/loading 模型的更多详细信息:PyTorch - Saving and Loading Models