没有 torchvision 的 pytorch 负载模型

pytorch load model without torchvision

是否可以在不依赖 torchvision 的情况下加载 pytorch 模型(从 .pth 文件,包含架构+state_dict)?

import os
import torch
assert os.path.exists(r'.\vgg.pth')
model = torch.load(r'.\vgg.pth')

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-4-e26863d95688> in <module>
      2 import torch
      3 assert os.path.exists(r'.\vgg.pth')
----> 4 model = torch.load(r'.\vgg.pth')

~\Anaconda3\envs\pytorch_save\lib\site-packages\torch\serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    590                     opened_file.seek(orig_position)
    591                     return torch.jit.load(opened_file)
--> 592                 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    593         return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    594 

~\Anaconda3\envs\pytorch_save\lib\site-packages\torch\serialization.py in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
    849     unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
    850     unpickler.persistent_load = persistent_load
--> 851     result = unpickler.load()
    852 
    853     torch._utils._validate_loaded_sparse_tensors()

ModuleNotFoundError: No module named 'torchvision'

我已经研究过 torch/serialization.py,但我看不出为什么它需要 torchvision。此文件中的导入如下:

import difflib
import os
import io
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
import copyreg
import pickle
import pathlib

是什么导致了我的问题

我问题中的vgg.pth文件生成如下:

import torchvision
vgg = models.vgg16(pretrained=True, init_weights=False)
torch.save(vgg, r'.\vgg.pth')

这样,文件 vgg.pth 不仅包含模型参数,还包含模型架构(参见 pytorch: save/load entire model)。但是,正如@Kishore 在评论中指出的那样,该架构似乎还需要 torchvision 作为依赖项。

我是怎么解决的

  • 在有 torchvision 的环境中,我将预训练的 VGG 模型加载到内存中并保存 state_dict
from torchvision.models.vgg import vgg16
import torch

model = vgg16(pretrained=True)
torch.save(model.state_dict, r'.\state_dict.pth')
  • 在没有 torchvision 的环境中,我通过检查 torchvision.models.vgg 代码重建了模型。
    然后我将这个 state_dict 文件加载到我模型的 state_dict 中。
    最后,我将这个模型(包括架构)保存到 .pth 文件中。
import torch

# a file where I pasted the torchvision.models.vgg code
# and commented out the torchvision dependencies I don't need
# in this case: 'from .._internally_replaced_utils import load_state_dict_from_url'
from torch_save import *

model = vgg16()
model.load_state_dict(torch.load(r'.\state_dict.pth'))
torch.save(model, r'.\entire_model.pth')

当我在无 torchvision 的环境中再次加载时,我没有收到任何错误。