如何解决 DataLoader 中错误的形状?

How can I solve the wrong shape in DataLoader?

我有一个要用于 GAN 的文本数据集,它应该转向 onehotencode,这就是我为我的文件创建自定义数据集的方式

class Dataset2(torch.utils.data.Dataset):
    def __init__(self, list_, labels):
        'Initialization'
        self.labels = labels
        self.list_IDs = list_

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        mylist = self.list_IDs[index]

        # Load data and get label
        X = F.one_hot(mylist, num_classes=len(alphabet))
        y = self.labels[index]

        return X, y

它运行良好,每次我调用它时,它都运行良好,但问题是当我使用 DataLoader 并尝试使用它时,它的形状与刚从数据集中出来的形状不一样,这个是数据集出来的形状

x , _ = dataset[1]
x.shape

torch.Size([1274, 22])

这就是 dataloader 出来的形状

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

one = []
for epoch in range(epochs):
    for i, (real_data, _) in enumerate(dataloader):
        one.append(real_data)
one[3].shape

torch.Size([4, 1274, 22])

这个 4 是我数据中的样本数,但它不应该存在,我该如何解决这个问题?

您确认您的数据集中只有四个元素。您已使用 batch_size=64 大于 4 的数据加载器包装了数据集。这意味着 dataloader 将仅输出包含 4 个元素的单个批次。

反过来,这意味着您每个时期只附加一个元素,one[3].shape 是一个批次(数据加载器的唯一批次),形状为 (4, 1274, 22).