无法在 google colab 上使用 torch.load() 加载 .pth 文件(预训练神经网络)

Fail to load a .pth file (pre-trained neural network) using torch.load() on google colab

我的 google 驱动器链接到我的 google colab 笔记本。使用 pytorch 库 torch.load($PATH) 无法加载我的 google 驱动器中的这个 219 Mo 文件(预训练神经网络)(https://drive.google.com/drive/folders/1-9m4aVg8Hze0IsZRyxvm5gLybuRLJHv-)。但是,当我在我的计算机上本地执行时,它工作正常。我在 google collab 上得到的错误是:(设置:Python 3.6,pytorch 1.3.1):

state_dict = torch.load(model_path)['state_dict']
File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 303, in load
return _load(f, map_location, pickle_module)
File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 454, in _load
return legacy_load(f)
File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 380, in legacy_load
with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar,
File "/usr/lib/python3.6/tarfile.py", line 1589, in open
return func(name, filemode, fileobj, **kwargs)
File "/usr/lib/python3.6/tarfile.py", line 1619, in taropen
return cls(name, mode, fileobj, **kwargs)
File "/usr/lib/python3.6/tarfile.py", line 1482, in init
self.firstmember = self.next()
File "/usr/lib/python3.6/tarfile.py", line 2297, in next
tarinfo = self.tarinfo.fromtarfile(self)
File "/usr/lib/python3.6/tarfile.py", line 1092, in fromtarfile
buf = tarfile.fileobj.read(BLOCKSIZE)
OSError: [Errno 5] Input/output error```   




Any help would be much appreciated!

Drive 上的大文件会自动分析病毒,每次尝试下载大文件时都必须通过此扫描,因此很难下载 link。

你可以直接使用Drive API下载文件然后传给torch,在Python上应该不难实现,我已经做了一个例子下载您的文件并将其传递给 Torch。

import torch
import pickle
import os.path
import io
from googleapiclient.discovery import build
from google_auth_oauthlib.flow import InstalledAppFlow
from google.auth.transport.requests import Request
from googleapiclient.http import MediaIoBaseDownload
from __future__ import print_function

url = "https://drive.google.com/file/d/1RwpuwNPt_r0M5mQGEw18w-bCfKVwnZrs/view?usp=sharing"
# If modifying these scopes, delete the file token.pickle.
SCOPES = (
    'https://www.googleapis.com/auth/drive',
    )

def main():
    """Shows basic usage of the Sheets API.
    Prints values from a sample spreadsheet.
    """
    creds = None
    # The file token.pickle stores the user's access and refresh tokens, and is
    # created automatically when the authorization flow completes for the first
    # time.
    if os.path.exists('token.pickle'):
        with open('token.pickle', 'rb') as token:
            creds = pickle.load(token)
    # If there are no (valid) credentials available, let the user log in.
    if not creds or not creds.valid:
        if creds and creds.expired and creds.refresh_token:
            creds.refresh(Request())
        else:
            flow = InstalledAppFlow.from_client_secrets_file(
                'credentials.json', SCOPES)
            creds = flow.run_local_server(port=0)
        # Save the credentials for the next run
        with open('token.pickle', 'wb') as token:
            pickle.dump(creds, token)

    drive_service = build('drive', 'v2', credentials=creds)

    file_id = '1RwpuwNPt_r0M5mQGEw18w-bCfKVwnZrs'
    request = drive_service.files().get_media(fileId=file_id)
    # fh = io.BytesIO()
    fh = open('file', 'wb')
    downloader = MediaIoBaseDownload(fh, request)
    done = False
    while done is False:
      status, done = downloader.next_chunk()
      print("Download %d%%." % int(status.progress() * 100))
    fh.close()
    torch.load('file')

if __name__ == '__main__':
    main()

要运行它你首先要:

  • 为您的帐户启用驱动器API
  • 安装 Google 驱动器 API 库,

这不会超过 3 分钟,并在 Quickstart Guide for Google Drive API 上进行了正确解释,只需按照步骤 1 和 2 以及 运行 上面提供的示例代码进行操作即可。

它通过直接将文件上传到 google colab 而不是从 google 驱动器加载它来工作:

from google.colab import files
uploaded= files.upload()

我想这个解决方案类似于@Yuri 提出的解决方案