Pytorch 默认数据加载器卡在大型图像分类训练集上
Pytorch default dataloader gets stuck for large image classification training set
我正在 Pytorch 中训练图像 class化模型,并使用它们 default data loader 加载我的训练数据。我有一个非常大的训练数据集,所以通常每个 class 有几千张样本图像。我过去训练过总共有大约 20 万张图像的模型,没有出现任何问题。但是我发现当总共有超过一百万张图片时,Pytorch 数据加载器会卡住。
我相信当我调用 datasets.ImageFolder(...)
时代码挂起。当我按 Ctrl-C 时,这始终是输出:
Traceback (most recent call last): │
File "main.py", line 412, in <module> │
main() │
File "main.py", line 122, in main │
run_training(args.group, args.num_classes) │
File "main.py", line 203, in run_training │
train_loader = create_dataloader(traindir, tfm.train_trans, shuffle=True) │
File "main.py", line 236, in create_dataloader │
dataset = datasets.ImageFolder(directory, trans) │
File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 209, in __init__ │
is_valid_file=is_valid_file) │
File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 94, in __init__ │
samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) │
File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 47, in make_dataset │
for root, _, fnames in sorted(os.walk(d)): │
File "/usr/lib/python3.5/os.py", line 380, in walk │
is_dir = entry.is_dir() │
Keyboard Interrupt
我认为某处可能存在死锁,但是根据 Ctrl-C 的堆栈输出,它看起来不像是在等待锁。所以后来我认为数据加载器很慢,因为我试图加载更多数据。我让它 运行 大约 2 天,它没有取得任何进展,在加载的最后 2 小时,我检查了 RAM 使用量保持不变。过去不到几个小时,我还能够加载包含超过 20 万张图像的训练数据集。我还尝试将我的 GCP 机器升级为具有 32 个内核、4 个 GPU 和超过 100GB 的 RAM,但是似乎在加载一定数量的内存后数据加载器就卡住了。
我很困惑数据加载器在遍历目录时是如何卡住的,我仍然不确定它是卡住了还是非常慢。有什么方法可以更改 Pytortch 数据加载器,使其能够处理超过 100 万张图像进行训练?任何调试建议也表示赞赏!
谢谢!
这不是 DataLoader
的问题,而是 torchvision.datasets.ImageFolder
及其工作方式的问题(以及为什么您拥有的数据越多,效果越差)。
它挂在这条线上,如您的错误所示:
for root, _, fnames in sorted(os.walk(d)):
可以找到来源here。
潜在的问题是它在巨型 list
中保留了每个 path
和相应的 label
,请参见下面的代码(为简洁起见删除了一些内容):
def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
images = []
dir = os.path.expanduser(dir)
# Iterate over all subfolders which were found previously
for target in sorted(class_to_idx.keys()):
d = os.path.join(dir, target) # Create path to this subfolder
# Assuming it is directory (which usually is the case)
for root, _, fnames in sorted(os.walk(d, followlinks=True)):
# Iterate over ALL files in this subdirectory
for fname in sorted(fnames):
path = os.path.join(root, fname)
# Assuming it is correctly recognized as image file
item = (path, class_to_idx[target])
# Add to path with all images
images.append(item)
return images
显然图像将包含 100 万个字符串(也很长)和 类 对应的 int
,这肯定很多并且取决于 RAM 和 CPU。
尽管您可以创建自己的数据集(前提是您事先更改了图像的名称),因此 dataset
不会占用内存。
设置数据结构
您的文件夹结构应如下所示:
root
class1
class2
class3
...
用多少类你have/need.
现在每个class
应该有以下数据:
class1
0.png
1.png
2.png
...
假设您可以继续创建数据集。
创建数据集
下面torch.utils.data.Dataset
使用PIL
打开图片,不过你可以用其他方式打开:
import os
import pathlib
import torch
from PIL import Image
class ImageDataset(torch.utils.data.Dataset):
def __init__(self, root: str, folder: str, klass: int, extension: str = "png"):
self._data = pathlib.Path(root) / folder
self.klass = klass
self.extension = extension
# Only calculate once how many files are in this folder
# Could be passed as argument if you precalculate it somehow
# e.g. ls | wc -l on Linux
self._length = sum(1 for entry in os.listdir(self._data))
def __len__(self):
# No need to recalculate this value every time
return self._length
def __getitem__(self, index):
# images always follow [0, n-1], so you access them directly
return Image.open(self._data / "{}.{}".format(str(index), self.extension))
现在您可以轻松创建您的数据集(假设文件夹结构如上:
root = "/path/to/root/with/images"
dataset = (
ImageDataset(root, "class0", 0)
+ ImageDataset(root, "class1", 1)
+ ImageDataset(root, "class2", 2)
)
您可以根据需要添加任意数量的 datasets
和指定的 类,循环执行或其他方式。
最后,照常使用torch.utils.data.DataLoader
,例如:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
我正在 Pytorch 中训练图像 class化模型,并使用它们 default data loader 加载我的训练数据。我有一个非常大的训练数据集,所以通常每个 class 有几千张样本图像。我过去训练过总共有大约 20 万张图像的模型,没有出现任何问题。但是我发现当总共有超过一百万张图片时,Pytorch 数据加载器会卡住。
我相信当我调用 datasets.ImageFolder(...)
时代码挂起。当我按 Ctrl-C 时,这始终是输出:
Traceback (most recent call last): │
File "main.py", line 412, in <module> │
main() │
File "main.py", line 122, in main │
run_training(args.group, args.num_classes) │
File "main.py", line 203, in run_training │
train_loader = create_dataloader(traindir, tfm.train_trans, shuffle=True) │
File "main.py", line 236, in create_dataloader │
dataset = datasets.ImageFolder(directory, trans) │
File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 209, in __init__ │
is_valid_file=is_valid_file) │
File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 94, in __init__ │
samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) │
File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 47, in make_dataset │
for root, _, fnames in sorted(os.walk(d)): │
File "/usr/lib/python3.5/os.py", line 380, in walk │
is_dir = entry.is_dir() │
Keyboard Interrupt
我认为某处可能存在死锁,但是根据 Ctrl-C 的堆栈输出,它看起来不像是在等待锁。所以后来我认为数据加载器很慢,因为我试图加载更多数据。我让它 运行 大约 2 天,它没有取得任何进展,在加载的最后 2 小时,我检查了 RAM 使用量保持不变。过去不到几个小时,我还能够加载包含超过 20 万张图像的训练数据集。我还尝试将我的 GCP 机器升级为具有 32 个内核、4 个 GPU 和超过 100GB 的 RAM,但是似乎在加载一定数量的内存后数据加载器就卡住了。
我很困惑数据加载器在遍历目录时是如何卡住的,我仍然不确定它是卡住了还是非常慢。有什么方法可以更改 Pytortch 数据加载器,使其能够处理超过 100 万张图像进行训练?任何调试建议也表示赞赏!
谢谢!
这不是 DataLoader
的问题,而是 torchvision.datasets.ImageFolder
及其工作方式的问题(以及为什么您拥有的数据越多,效果越差)。
它挂在这条线上,如您的错误所示:
for root, _, fnames in sorted(os.walk(d)):
可以找到来源here。
潜在的问题是它在巨型 list
中保留了每个 path
和相应的 label
,请参见下面的代码(为简洁起见删除了一些内容):
def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
images = []
dir = os.path.expanduser(dir)
# Iterate over all subfolders which were found previously
for target in sorted(class_to_idx.keys()):
d = os.path.join(dir, target) # Create path to this subfolder
# Assuming it is directory (which usually is the case)
for root, _, fnames in sorted(os.walk(d, followlinks=True)):
# Iterate over ALL files in this subdirectory
for fname in sorted(fnames):
path = os.path.join(root, fname)
# Assuming it is correctly recognized as image file
item = (path, class_to_idx[target])
# Add to path with all images
images.append(item)
return images
显然图像将包含 100 万个字符串(也很长)和 类 对应的 int
,这肯定很多并且取决于 RAM 和 CPU。
尽管您可以创建自己的数据集(前提是您事先更改了图像的名称),因此 dataset
不会占用内存。
设置数据结构
您的文件夹结构应如下所示:
root
class1
class2
class3
...
用多少类你have/need.
现在每个class
应该有以下数据:
class1
0.png
1.png
2.png
...
假设您可以继续创建数据集。
创建数据集
下面torch.utils.data.Dataset
使用PIL
打开图片,不过你可以用其他方式打开:
import os
import pathlib
import torch
from PIL import Image
class ImageDataset(torch.utils.data.Dataset):
def __init__(self, root: str, folder: str, klass: int, extension: str = "png"):
self._data = pathlib.Path(root) / folder
self.klass = klass
self.extension = extension
# Only calculate once how many files are in this folder
# Could be passed as argument if you precalculate it somehow
# e.g. ls | wc -l on Linux
self._length = sum(1 for entry in os.listdir(self._data))
def __len__(self):
# No need to recalculate this value every time
return self._length
def __getitem__(self, index):
# images always follow [0, n-1], so you access them directly
return Image.open(self._data / "{}.{}".format(str(index), self.extension))
现在您可以轻松创建您的数据集(假设文件夹结构如上:
root = "/path/to/root/with/images"
dataset = (
ImageDataset(root, "class0", 0)
+ ImageDataset(root, "class1", 1)
+ ImageDataset(root, "class2", 2)
)
您可以根据需要添加任意数量的 datasets
和指定的 类,循环执行或其他方式。
最后,照常使用torch.utils.data.DataLoader
,例如:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)