How do I make custom PyTorch datasets structured like the torchvision datasets?

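I'm new to PyTorch, and I'm trying to reuse a Fashion MNIST CNN (from deeplizard) to categorize my time-series data. I'm finding it hard to understand the structure of datasets: I followed this official tutorial as best I could, but what I ended up with is too simple. I think that's because I don't understand OOP very well. The dataset I made works fine in my CNN for training, but I ran into trouble when I then tried to analyze the results using their code.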

So I created a dataset from two PyTorch tensors called features [4050, 1, 150, 6] and targets [4050]:

train_dataset = TensorDataset(features,targets) # create your dataset
train_dataloader = DataLoader(train_dataset, batch_size=50, shuffle=False) # create your dataloader
print(train_dataset.__dict__.keys()) # list the attributes

Inspecting the attributes gives me this printout:

dict_keys(['tensors'])

But in the Fashion MNIST tutorial, they access the data like this:

train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=1000, shuffle=True)
print(train_set.__dict__.keys()) # list the attributes

Inspecting the attributes there gives this printout:

dict_keys(['root', 'transform', 'target_transform', 'transforms', 'train', 'data', 'targets'])

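My dataset works fine for training, but when I get to the later analytics part of the tutorial, they want me to access parts of the dataset, and I get an error: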

# Analytics
prediction_loader = torch.utils.data.DataLoader(train_dataset, batch_size=50)
train_preds = get_all_preds(network, prediction_loader)
preds_correct = train_preds.argmax(dim=1).eq(train_dataset.targets).sum().item()

print('total correct:', preds_correct)
print('accuracy:', preds_correct / len(train_set))

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-73-daa87335a92a> in <module>
      4 prediction_loader = torch.utils.data.DataLoader(train_dataset, batch_size=50)
      5 train_preds = get_all_preds(network, prediction_loader)
----> 6 preds_correct = train_preds.argmax(dim=1).eq(train_dataset.targets).sum().item()
      7 
      8 print('total correct:', preds_correct)

AttributeError: 'TensorDataset' object has no attribute 'targets'

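Can anyone tell me what's going on here? Do I need to change the way I build the dataset, or can I rewrite the analytics code to access the right part of the dataset?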

The TensorDataset equivalent of .targets is train_dataset.tensors[1].
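
As a rough sketch of how the analytics line could be adapted, the snippet below swaps .targets for .tensors[1]. The tensors are small stand-ins rather than the real [4050, 1, 150, 6] data, and train_preds is a made-up substitute for whatever get_all_preds returns (assumed here to be one row of class scores per sample, in dataset order).

```python
import torch
from torch.utils.data import TensorDataset

# Small stand-in tensors; the question's real shapes are [4050, 1, 150, 6] and [4050].
features = torch.randn(8, 1, 150, 6)
targets = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1])
train_dataset = TensorDataset(features, targets)

# .tensors holds the constructor arguments in order: index 0 is features, index 1 is targets.
all_targets = train_dataset.tensors[1]

# Hypothetical stand-in for get_all_preds(network, prediction_loader):
# a score matrix whose argmax equals the true label for every sample.
train_preds = torch.nn.functional.one_hot(targets, num_classes=3).float()

preds_correct = train_preds.argmax(dim=1).eq(all_targets).sum().item()
print('total correct:', preds_correct)
print('accuracy:', preds_correct / len(train_dataset))
```

The comparison only lines up if the predictions come back in dataset order, so the prediction loader should not shuffle (the DataLoader default).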

The implementation of TensorDataset is very simple:

class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.
    Each sample will be retrieved by indexing tensors along the first dimension.
    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
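
If you would rather keep the tutorial's analytics code unchanged, another option is a small Dataset subclass that stores the tensors under the same attribute names FashionMNIST uses. The following is only a minimal sketch: TimeSeriesDataset and its transform argument are hypothetical names chosen for illustration, not part of torch or torchvision.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class TimeSeriesDataset(Dataset):
    """Hypothetical dataset exposing .data and .targets, like FashionMNIST."""

    def __init__(self, data, targets, transform=None):
        # Both tensors must have the same number of samples.
        assert data.size(0) == targets.size(0)
        self.data = data
        self.targets = targets
        self.transform = transform

    def __getitem__(self, index):
        sample = self.data[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.targets[index]

    def __len__(self):
        return self.data.size(0)


# Small stand-in tensors; the question's real shapes are [4050, 1, 150, 6] and [4050].
features = torch.randn(8, 1, 150, 6)
targets = torch.randint(0, 3, (8,))

train_dataset = TimeSeriesDataset(features, targets)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=False)

print(sorted(train_dataset.__dict__.keys()))  # list the attributes
batch_features, batch_targets = next(iter(train_loader))
print(batch_features.shape, batch_targets.shape)
```

With a class like this, train_dataset.targets exists, so an expression such as train_preds.argmax(dim=1).eq(train_dataset.targets) should work as written in the tutorial.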