使用 Dataloader 时出现问题

Something wrong when I use the Dataloader

数据集pre_proceed中的短语

class data_test(Dataset):
def __init__(self,data_root,transform=None):
    data_image=glob.glob(data_root+'/*.jpg')

    self.data_image=data_image
    self.transform=transform

def __getitem__(self, index):
    data_image_path=self.data_image[index]

    image_data=cv2.imread(data_image_path,-1) # unchanged
    if self.transform:
        image_data=self.transform(image_data)

    return image_data

上面的操作很普通,但是当我加载数据集的时候,

`dataset=data_test(train_dataset,transforms)
data=DataLoader(dataset,batch_size=8,num_workers=0)
for idx,data in enumerate(data):
    print(data.shape)`

发生错误,

该错误实际上非常具体,引发的错误是 NotImplementedError。您应该在自定义数据集中实现 __len__ 函数。

在您的情况下,这很简单(假设 self.data_image 包含您所有的数据集实例)将此函数添加到 data_test class:

    def __len__(self):
        return len(self.data_image)