Batch size in DataLoader
I have two tensors:
x[train], y[train]
with shapes
(311, 3, 224, 224), (311,)  # 311 is just the number of samples and has no special meaning
I want to load them in batches with a DataLoader. Here is my code:
```python
import torch.utils.data as Data
from torch.utils.data import Dataset

class KD_Train(Dataset):
    def __init__(self, a, b):
        self.imgs = a
        self.index = b

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        # Ignores `index`: returns the whole tensors on every call
        return self.imgs, self.index

kdt = KD_Train(x[train], y[train])
train_data_loader = Data.DataLoader(
    kdt,
    batch_size=64,
    shuffle=True,
    num_workers=0)

for step, (a, b) in enumerate(train_data_loader):
    print(a.shape)
    break
```
But it prints:
(64, 311, 3, 224, 224)
DataLoader just adds an extra dimension instead of selecting a batch of samples. Does anyone know how to fix this?
Your dataset's __getitem__ method should return a single element:

```python
def __getitem__(self, index):
    # Return only the image and label at position `index`
    return self.imgs[index], self.index[index]
```
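A minimal sketch of the corrected dataset end to end. DataLoader calls __getitem__ once per sampled index and its default collation stacks the 64 returned items, so each batch gets a new leading dimension of 64 on top of a single sample's shape. Since the question's x[train] and y[train] are not available, this uses hypothetical random stand-ins of 311 samples with small 3x8x8 "images" instead of 3x224x224 to keep it light. Note that 311 = 4 × 64 + 55, so the last batch holds 55 samples.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class KD_Train(Dataset):
    def __init__(self, a, b):
        self.imgs = a
        self.index = b

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        # One (image, label) pair; DataLoader stacks batch_size of these
        return self.imgs[index], self.index[index]

# Hypothetical stand-in data: 311 samples, 3x8x8 instead of 3x224x224
x_train = torch.randn(311, 3, 8, 8)
y_train = torch.randint(0, 10, (311,))

kdt = KD_Train(x_train, y_train)
loader = DataLoader(kdt, batch_size=64, shuffle=True, num_workers=0)

shapes = []
for step, (a, b) in enumerate(loader):
    # Record (image batch shape, label batch shape) for every batch
    shapes.append((tuple(a.shape), tuple(b.shape)))

print(shapes[0])                  # ((64, 3, 8, 8), (64,))
print([s[0][0] for s in shapes])  # [64, 64, 64, 64, 55]
```

With the real data, the first batch would have shape (64, 3, 224, 224) instead of (64, 311, 3, 224, 224).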
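If the dataset only wraps tensors that share the same first dimension, torch.utils.data.TensorDataset already performs this per-index lookup, so a custom class is optional. The sketch below again uses hypothetical small stand-in tensors and also passes drop_last=True, an optional DataLoader argument that discards the final incomplete batch (here 311 % 64 = 55 samples) when every batch must be exactly 64.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical stand-in data with the same first dimension (311 samples)
x_train = torch.randn(311, 3, 8, 8)
y_train = torch.randint(0, 10, (311,))

# dataset[i] returns (x_train[i], y_train[i])
dataset = TensorDataset(x_train, y_train)

# drop_last=True skips the trailing batch of 55 samples
loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

shapes = [tuple(a.shape) for a, b in loader]
print(len(shapes), shapes[0])  # 4 (64, 3, 8, 8)
```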