PyTorch 数据集 class 的 Subclass 找不到数据集文件

Subclass of PyTorch dataset class cannot find dataset files

我正在尝试创建 PyTorch MNIST 数据集 class 的子 class,我称之为 CustomMNISTDataset,如下所示:

import torchvision.datasets as datasets

class CustomMNISTDataset(datasets.MNIST):

    def __init__(self, root='/home/psando'):
        super().__init__(root=root,
                         download=False)

但是当我执行时:

dataset = CustomMNISTDataset()

它失败并出现错误:“RuntimeError:未找到数据集。您可以使用 download=True 来下载它”。

但是,当我 运行 在同一文件中添加以下内容时:

dataset = datasets.MNIST(root='/home/psando', download=False)
print(len(dataset))

成功并按预期打印“60000”。

因为 CustomMNISTDataset subclasses datasets.MNIST 为什么行为不同? 我已经验证了路径 '/home/psando' 包含带有原始和已处理子目录的 MNIST 目录(否则,显式调用 datasets.MNIST() 的构造函数将失败)。当前的行为意味着在 CustomMNISTDataset 中对 super().__init__() 的调用不是在调用 datasets.MNIST 的构造函数,这很奇怪!

其他详细信息:我正在使用 Python 3.6.8 和 torch==1.6.0 以及 torchvision==0.7.0。如有任何帮助,我们将不胜感激!

这需要深入研究源代码,但您的问题是 this 函数。数据集的路径取决于 class 的名称,因此当您子 class MNIST 时,根文件夹更改为 /home/psando/CustomMNISTDataset

因此,如果您将 /home/psando/MNIST 重命名为 /home/psando/CustomMNISTDataset,它会起作用。