Fail to load a .pth file (pre-trained neural network) using torch.load() on Google Colab
My Google Drive is linked to my Google Colab notebook. Using the PyTorch call torch.load($PATH) fails to load this 219 MB file (a pre-trained neural network) stored on my Google Drive (https://drive.google.com/drive/folders/1-9m4aVg8Hze0IsZRyxvm5gLybuRLJHv-). However, it works fine when I run it locally on my computer. The error I get on Google Colab is (setup: Python 3.6, PyTorch 1.3.1):
    state_dict = torch.load(model_path)['state_dict']
  File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 303, in load
    return _load(f, map_location, pickle_module)
  File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 454, in _load
    return legacy_load(f)
  File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 380, in legacy_load
    with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar,
  File "/usr/lib/python3.6/tarfile.py", line 1589, in open
    return func(name, filemode, fileobj, **kwargs)
  File "/usr/lib/python3.6/tarfile.py", line 1619, in taropen
    return cls(name, mode, fileobj, **kwargs)
  File "/usr/lib/python3.6/tarfile.py", line 1482, in __init__
    self.firstmember = self.next()
  File "/usr/lib/python3.6/tarfile.py", line 2297, in next
    tarinfo = self.tarinfo.fromtarfile(self)
  File "/usr/lib/python3.6/tarfile.py", line 1092, in fromtarfile
    buf = tarfile.fileobj.read(BLOCKSIZE)
OSError: [Errno 5] Input/output error
Any help would be much appreciated!
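Large files on Drive are automatically scanned for viruses. Every attempt to download a large file has to pass this scan first, which makes it hard to reach the download link.
You can download the file directly with the Drive API and then pass it to torch. This should not be hard to implement in Python; I have already made an example that downloads your file and passes it to Torch.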
import os.path
import pickle

import torch
from google.auth.transport.requests import Request
from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload

# Share link of the file; the ID between "/d/" and "/view" is used below.
url = "https://drive.google.com/file/d/1RwpuwNPt_r0M5mQGEw18w-bCfKVwnZrs/view?usp=sharing"

# If modifying these scopes, delete the file token.pickle.
SCOPES = (
    'https://www.googleapis.com/auth/drive',
)


def main():
    """Download a file from Google Drive via the Drive API and load it with torch."""
    creds = None
    # The file token.pickle stores the user's access and refresh tokens, and is
    # created automatically when the authorization flow completes for the first
    # time.
    if os.path.exists('token.pickle'):
        with open('token.pickle', 'rb') as token:
            creds = pickle.load(token)
    # If there are no (valid) credentials available, let the user log in.
    if not creds or not creds.valid:
        if creds and creds.expired and creds.refresh_token:
            creds.refresh(Request())
        else:
            flow = InstalledAppFlow.from_client_secrets_file(
                'credentials.json', SCOPES)
            creds = flow.run_local_server(port=0)
        # Save the credentials for the next run
        with open('token.pickle', 'wb') as token:
            pickle.dump(creds, token)

    drive_service = build('drive', 'v2', credentials=creds)
    file_id = '1RwpuwNPt_r0M5mQGEw18w-bCfKVwnZrs'
    request = drive_service.files().get_media(fileId=file_id)

    # Stream the file to local disk chunk by chunk; the file is closed
    # (and flushed) before torch reads it back.
    with open('file', 'wb') as fh:
        downloader = MediaIoBaseDownload(fh, request)
        done = False
        while not done:
            status, done = downloader.next_chunk()
            print("Download %d%%." % int(status.progress() * 100))

    # Load the downloaded checkpoint from local disk.
    torch.load('file')


if __name__ == '__main__':
    main()
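To run it you first have to:
- Enable the Drive API for your account
- Install the Google Drive API libraries
This takes no more than 3 minutes and is properly explained in the Quickstart Guide for Google Drive API. Just follow steps 1 and 2, then run the sample code provided above.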
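The example above hardcodes file_id even though the share link is already available in url. As a minimal sketch, the ID could be pulled out of the link with a regular expression. The helper name extract_drive_file_id and the two link shapes it recognizes are assumptions for illustration, not part of the Drive API, and other link formats may exist.

```python
import re
from typing import Optional

# Link shapes assumed here (not exhaustive):
#   https://drive.google.com/file/d/<id>/view?usp=sharing
#   https://drive.google.com/open?id=<id>
_FILE_ID_PATTERNS = (
    re.compile(r"/file/d/([A-Za-z0-9_-]+)"),
    re.compile(r"[?&]id=([A-Za-z0-9_-]+)"),
)


def extract_drive_file_id(url: str) -> Optional[str]:
    """Return the file ID embedded in a Drive share link, or None if not found.

    Hypothetical helper: it only inspects the URL text and never contacts Drive.
    """
    for pattern in _FILE_ID_PATTERNS:
        match = pattern.search(url)
        if match:
            return match.group(1)
    return None


if __name__ == "__main__":
    url = "https://drive.google.com/file/d/1RwpuwNPt_r0M5mQGEw18w-bCfKVwnZrs/view?usp=sharing"
    print(extract_drive_file_id(url))
```

The returned string can then be passed as fileId to get_media in place of the hardcoded value.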
It worked by uploading the file directly to Google Colab instead of loading it from Google Drive:
from google.colab import files
uploaded = files.upload()
I guess this solution is similar to the one proposed by @Yuri.
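For completeness, here is a rough sketch of how the uploaded bytes might be handed to torch.load. It relies on files.upload() returning a dict that maps each uploaded filename to its bytes, and on torch.load accepting a file-like object. The helper name first_checkpoint and the ".pth" suffix filter are assumptions made for illustration.

```python
import io
from typing import Dict, Tuple


def first_checkpoint(uploaded: Dict[str, bytes],
                     suffix: str = ".pth") -> Tuple[str, io.BytesIO]:
    """Pick the first uploaded file ending in `suffix` and wrap it in a buffer.

    Hypothetical helper: `uploaded` is expected to look like the dict
    returned by google.colab.files.upload() (filename -> file contents).
    """
    for name, data in uploaded.items():
        if name.endswith(suffix):
            return name, io.BytesIO(data)
    raise ValueError("no uploaded file ends with %r" % suffix)


# Intended use inside a Colab notebook (not executed here):
#   from google.colab import files
#   import torch
#   uploaded = files.upload()
#   name, buffer = first_checkpoint(uploaded)
#   state_dict = torch.load(buffer, map_location="cpu")["state_dict"]
```

Reading from an in-memory buffer avoids going through the mounted Drive folder, which is where the Input/output error was raised in the traceback.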