压缩两个可迭代对象时如何获得下一次迭代

How to get the next iteration when zipping two iterables

当标签本身是图像时,我需要训练模型。我想对输入图像和输出图像应用相同的数据增强。在 之后,我压缩了两个生成器:

# create augmentation generators for the input images and label images
image_datagen = ImageDataGenerator(rotation_range=45, width_shift_range=0.2, height_shift_range=0.2, brightness_range=(0.5,1.5),
                                               fill_mode="reflect", horizontal_flip=True ,zoom_range=0.3,preprocessing_function = self.apply_kernels)
desity_datagen = ImageDataGenerator(rotation_range=45, width_shift_range=0.2, height_shift_range=0.2, brightness_range=(0.5,1.5),
                                               fill_mode="reflect", horizontal_flip=True ,zoom_range=0.3,preprocessing_function = self.apply_kernels)
image_generator = image_datagen.flow_from_directory(self.train_data_dir, batch_size=batch_size, class_mode=None, seed=seed)
density_generator = desity_datagen.flow_from_directory(self.train_label_dir, batch_size=batch_size, class_mode=None, seed=seed)
# Combine the image and label generator
self.train_generator = zip(image_generator, density_generator)

我在生成器 calss 中构建了它,我使用它进行初始化:

gen = data_generator(path_to_images, path_to_labels, batch_size);

我没有包括整个 class 因为我不确定是否需要它。如果需要,我将编辑并添加它。我正在尝试从两个生成器调用下一批,看看它是否有效:

image,label = gen.train_generator.next()
print(labels.shape)

然后我得到

AttributeError: 'zip' object has no attribute 'next'

我明白为什么我会得到它,虽然我不知道如何得到一个批次。

使用 list(gen.train_generator) 内存太大。

这会做你想做的事:

gen.train_generator.__next__()

尽管@OlvinR​​oght 的 comment/answer 更干净。

你应该使用built-in函数,next-

value = next(gen.train_generator)
如果没有下一个值,

next 将抛出 StopIteration 异常。确保包裹在 try-except 中以防止错误冒泡 -

try:
  value = next(gen.train_generator)
except StopIteration:
  value = None

因为这种模式很常见,next 接受在生成器耗尽时使用的辅助参数 -

value = next(gen.train_generator, None) # default to None if generator is exhausted