没有 torchvision 的 pytorch 负载模型
pytorch load model without torchvision
是否可以在不依赖 torchvision 的情况下加载 pytorch 模型(从 .pth
文件,包含架构+state_dict)?
import os
import torch
assert os.path.exists(r'.\vgg.pth')
model = torch.load(r'.\vgg.pth')
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
<ipython-input-4-e26863d95688> in <module>
2 import torch
3 assert os.path.exists(r'.\vgg.pth')
----> 4 model = torch.load(r'.\vgg.pth')
~\Anaconda3\envs\pytorch_save\lib\site-packages\torch\serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
590 opened_file.seek(orig_position)
591 return torch.jit.load(opened_file)
--> 592 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
593 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
594
~\Anaconda3\envs\pytorch_save\lib\site-packages\torch\serialization.py in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
849 unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
850 unpickler.persistent_load = persistent_load
--> 851 result = unpickler.load()
852
853 torch._utils._validate_loaded_sparse_tensors()
ModuleNotFoundError: No module named 'torchvision'
我已经研究过 torch/serialization.py
,但我看不出为什么它需要 torchvision。此文件中的导入如下:
import difflib
import os
import io
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
import copyreg
import pickle
import pathlib
是什么导致了我的问题
我问题中的vgg.pth
文件生成如下:
import torchvision
vgg = models.vgg16(pretrained=True, init_weights=False)
torch.save(vgg, r'.\vgg.pth')
这样,文件 vgg.pth
不仅包含模型参数,还包含模型架构(参见 pytorch: save/load entire model)。但是,正如@Kishore 在评论中指出的那样,该架构似乎还需要 torchvision 作为依赖项。
我是怎么解决的
- 在有 torchvision 的环境中,我将预训练的 VGG 模型加载到内存中并保存 state_dict
from torchvision.models.vgg import vgg16
import torch
model = vgg16(pretrained=True)
torch.save(model.state_dict, r'.\state_dict.pth')
- 在没有 torchvision 的环境中,我通过检查 torchvision.models.vgg 代码重建了模型。
然后我将这个 state_dict 文件加载到我模型的 state_dict 中。
最后,我将这个模型(包括架构)保存到 .pth
文件中。
import torch
# a file where I pasted the torchvision.models.vgg code
# and commented out the torchvision dependencies I don't need
# in this case: 'from .._internally_replaced_utils import load_state_dict_from_url'
from torch_save import *
model = vgg16()
model.load_state_dict(torch.load(r'.\state_dict.pth'))
torch.save(model, r'.\entire_model.pth')
当我在无 torchvision 的环境中再次加载时,我没有收到任何错误。
是否可以在不依赖 torchvision 的情况下加载 pytorch 模型(从 .pth
文件,包含架构+state_dict)?
import os
import torch
assert os.path.exists(r'.\vgg.pth')
model = torch.load(r'.\vgg.pth')
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
<ipython-input-4-e26863d95688> in <module>
2 import torch
3 assert os.path.exists(r'.\vgg.pth')
----> 4 model = torch.load(r'.\vgg.pth')
~\Anaconda3\envs\pytorch_save\lib\site-packages\torch\serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
590 opened_file.seek(orig_position)
591 return torch.jit.load(opened_file)
--> 592 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
593 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
594
~\Anaconda3\envs\pytorch_save\lib\site-packages\torch\serialization.py in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
849 unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
850 unpickler.persistent_load = persistent_load
--> 851 result = unpickler.load()
852
853 torch._utils._validate_loaded_sparse_tensors()
ModuleNotFoundError: No module named 'torchvision'
我已经研究过 torch/serialization.py
,但我看不出为什么它需要 torchvision。此文件中的导入如下:
import difflib
import os
import io
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
import copyreg
import pickle
import pathlib
是什么导致了我的问题
我问题中的vgg.pth
文件生成如下:
import torchvision
vgg = models.vgg16(pretrained=True, init_weights=False)
torch.save(vgg, r'.\vgg.pth')
这样,文件 vgg.pth
不仅包含模型参数,还包含模型架构(参见 pytorch: save/load entire model)。但是,正如@Kishore 在评论中指出的那样,该架构似乎还需要 torchvision 作为依赖项。
我是怎么解决的
- 在有 torchvision 的环境中,我将预训练的 VGG 模型加载到内存中并保存 state_dict
from torchvision.models.vgg import vgg16
import torch
model = vgg16(pretrained=True)
torch.save(model.state_dict, r'.\state_dict.pth')
- 在没有 torchvision 的环境中,我通过检查 torchvision.models.vgg 代码重建了模型。
然后我将这个 state_dict 文件加载到我模型的 state_dict 中。
最后,我将这个模型(包括架构)保存到.pth
文件中。
import torch
# a file where I pasted the torchvision.models.vgg code
# and commented out the torchvision dependencies I don't need
# in this case: 'from .._internally_replaced_utils import load_state_dict_from_url'
from torch_save import *
model = vgg16()
model.load_state_dict(torch.load(r'.\state_dict.pth'))
torch.save(model, r'.\entire_model.pth')
当我在无 torchvision 的环境中再次加载时,我没有收到任何错误。