Torch 没有保存我冻结和优化的模型
Torch is not saving my freezed and optimized model
当我启动我的脚本时,它 运行 一切正常,直到它遇到 traced_model.save(args.save_path) 语句,之后脚本才停止 运行ning。
有人可以帮我解决这个问题吗?
import argparse
import torch
from model import SpeechRecognition
from collections import OrderedDict
def trace(model):
model.eval()
x = torch.rand(1, 81, 300)
hidden = model._init_hidden(1)
traced = torch.jit.trace(model, (x, hidden))
return traced
def main(args):
print("loading model from", args.model_checkpoint)
checkpoint = torch.load(args.model_checkpoint, map_location=torch.device('cpu'))
h_params = SpeechRecognition.hyper_parameters
model = SpeechRecognition(**h_params)
model_state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in model_state_dict.items():
name = k.replace("model.", "") # remove `model.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
print("tracing model...")
traced_model = trace(model)
print("saving to", args.save_path)
traced_model.save(args.save_path)
print("Done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="testing the wakeword engine")
parser.add_argument('--model_checkpoint', type=str, default='your/checkpoint_file', required=False,
help='Checkpoint of model to optimize')
parser.add_argument('--save_path', type=str, default='path/where/you/want/to/save/the/model', required=False,
help='path to save optmized model')
args = parser.parse_args()
main(args)
如果您启动脚本,您甚至可以看到它停止工作的地方,因为 print("Done!")
没有被执行。
这是我 运行 脚本时在终端中显示的内容:
loading model from C:/Users/supre/Documents/Python Programs/epoch=0-step=11999.ckpt
tracing model...
saving to C:/Users/supre/Documents/Python Programs
根据 PyTorch documentation,一个常见的 PyTorch 惯例是使用 .pt 或 .pth 文件扩展名保存模型。
要保存模型检查点或多个组件,将它们组织在字典中并使用torch.save()
序列化字典。
例如,
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
一个常见的 PyTorch 约定是使用 .tar 文件扩展名保存这些检查点。
希望这能回答您的问题。
当我启动我的脚本时,它 运行 一切正常,直到它遇到 traced_model.save(args.save_path) 语句,之后脚本才停止 运行ning。 有人可以帮我解决这个问题吗?
import argparse
import torch
from model import SpeechRecognition
from collections import OrderedDict
def trace(model):
model.eval()
x = torch.rand(1, 81, 300)
hidden = model._init_hidden(1)
traced = torch.jit.trace(model, (x, hidden))
return traced
def main(args):
print("loading model from", args.model_checkpoint)
checkpoint = torch.load(args.model_checkpoint, map_location=torch.device('cpu'))
h_params = SpeechRecognition.hyper_parameters
model = SpeechRecognition(**h_params)
model_state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in model_state_dict.items():
name = k.replace("model.", "") # remove `model.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
print("tracing model...")
traced_model = trace(model)
print("saving to", args.save_path)
traced_model.save(args.save_path)
print("Done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="testing the wakeword engine")
parser.add_argument('--model_checkpoint', type=str, default='your/checkpoint_file', required=False,
help='Checkpoint of model to optimize')
parser.add_argument('--save_path', type=str, default='path/where/you/want/to/save/the/model', required=False,
help='path to save optmized model')
args = parser.parse_args()
main(args)
如果您启动脚本,您甚至可以看到它停止工作的地方,因为 print("Done!")
没有被执行。
这是我 运行 脚本时在终端中显示的内容:
loading model from C:/Users/supre/Documents/Python Programs/epoch=0-step=11999.ckpt
tracing model...
saving to C:/Users/supre/Documents/Python Programs
根据 PyTorch documentation,一个常见的 PyTorch 惯例是使用 .pt 或 .pth 文件扩展名保存模型。
要保存模型检查点或多个组件,将它们组织在字典中并使用torch.save()
序列化字典。
例如,
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
一个常见的 PyTorch 约定是使用 .tar 文件扩展名保存这些检查点。
希望这能回答您的问题。