Pytorch:如何从数据加载器中获取前 N 项

Pytorch: How to get the first N item from dataloader

我的列表中有 3000 张图片,但我只想要其中的前 N ​​张,例如 1000 张,用于训练。 我想知道如何通过更改循环代码来实现此目的:

for (image, label) in enumerate(train_loader):

for (image, label) in list(enumerate(train_loader))[:1000]:

虽然这不是划分训练和验证数据的好方法。 首先,dataloader class 支持延迟加载(示例直到需要时才加载到内存中),而作为列表进行转换将需要将所有数据加载到内存中,可能会触发 out-of-memory 错误.其次,如果 dataloader 有洗牌,这可能并不总是 return 相同的 1000 个元素。一般来说,dataloader class 不支持索引,因此不适合选择我们数据集的特定子集。转换为列表可以解决此问题,但会牺牲 dataloader class.

的有用属性

最佳做法是对训练和验证分区使用单独的 data.dataset 对象,或者至少对数据集中的数据进行分区,而不是依赖于在前 1000 个示例后停止训练。然后,为训练分区和验证分区创建一个单独的数据加载器。