Batch size in DataLoader

I have two tensors:

x[train], y[train]

Their shapes are

(311, 3, 224, 224), (311) # 311 Has No Information

I want to use DataLoader to load them in batches. The code I wrote is:

import torch.utils.data as Data
from torch.utils.data import Dataset

class KD_Train(Dataset):

    def __init__(self,a,b):
        self.imgs = a
        self.index = b

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self,index):
        return self.imgs, self.index

kdt = KD_Train(x[train], y[train])

train_data_loader = Data.DataLoader(
    kdt,
    batch_size = 64,
    shuffle = True,
    num_workers = 0)

for step, (a, b) in enumerate(train_data_loader):
    print(a.shape)
    break

But it prints:

(64, 311, 3, 224, 224)

The DataLoader just adds a dimension in front instead of selecting a batch of samples. Does anyone know what I should do?

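Your dataset's __getitem__ method should return a single element. The DataLoader calls __getitem__ once per sampled index and stacks the results into a batch, so returning the whole tensor every time gives you 64 copies of all 311 images: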

def __getitem__(self, index):
    return self.imgs[index], self.index[index]
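
For reference, here is a minimal end-to-end sketch of the corrected dataset. It makes a few assumptions that are not in the question: the tensors are random stand-ins, the images are shrunk to 8x8 so it runs quickly, the labels are taken to be integers, and the attribute is renamed to `labels` for readability.

```python
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset


class KD_Train(Dataset):
    def __init__(self, imgs, labels):
        self.imgs = imgs
        self.labels = labels  # called self.index in the question

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        # Return ONE sample; the DataLoader stacks samples into a batch.
        return self.imgs[index], self.labels[index]


# Stand-in data (assumption): 311 samples, 8x8 images instead of 224x224.
x = torch.randn(311, 3, 8, 8)
y = torch.randint(0, 2, (311,))

loader = DataLoader(KD_Train(x, y), batch_size=64, shuffle=True, num_workers=0)

# Collect the shape of every batch the loader yields.
batch_shapes = [(a.shape, b.shape) for a, b in loader]
print(batch_shapes[0][0])   # batch dimension first, then the per-sample shape
print(batch_shapes[-1][0])  # the last batch holds the leftover samples

# The built-in TensorDataset indexes the tensors the same way.
tensor_loader = DataLoader(TensorDataset(x, y), batch_size=64)
first_imgs, first_labels = next(iter(tensor_loader))
```

If you only need to pair up tensors like this, `TensorDataset` saves you from writing the class at all. Pass `drop_last=True` to the DataLoader if you would rather skip the smaller final batch.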