Why do my train_generator and val_generator produce the same images?

I held out a validation split as follows:

val_samples = 60
train_imgs = coco_imgs[:-val_samples]
train_masks = coco_masks[:-val_samples]
val_imgs = coco_imgs[-val_samples:]
val_masks = coco_masks[-val_samples:]
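The negative-index slicing above works for val_samples = 60, but it has a Python gotcha: with val_samples = 0, coco_imgs[:-0] is empty and coco_imgs[-0:] is the whole array. A minimal sketch of a split that uses an explicit cut index instead (split_holdout is a hypothetical helper name, not from the question or any library):

```python
import numpy as np


def split_holdout(arr, n_val):
    """Hypothetical helper: hold out the last n_val items for validation.

    An explicit cut index makes n_val == 0 behave sensibly; with
    arr[:-0] / arr[-0:] the train part would be empty and the val part
    would be the whole array.
    """
    if not 0 <= n_val <= len(arr):
        raise ValueError("n_val must be between 0 and len(arr)")
    cut = len(arr) - n_val
    return arr[:cut], arr[cut:]
```

Applying the same helper (same n_val) to coco_imgs and coco_masks keeps each image aligned with its mask.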

My train_imgs and val_imgs show different images:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(ncols=2, figsize=(10, 3), sharex=True, sharey=True)
ax[0].imshow(train_imgs[14])
ax[1].imshow(val_imgs[14])
plt.show()

Then I wrote the data generator class:

class DataGenerator(keras.utils.Sequence):
    def __init__(self, input_img, input_mask, image_size,
                 augmentation, batch_size):
        self.image_size = img_size
        self.augmentation = augmentation
        self.batch_size = batch_size
        self.input_img = train_imgs
        self.input_mask = train_masks

    def __len__(self):
        return len(self.input_img) // self.batch_size

    def __getitem__(self, index):
        data_index_min = int(index*self.batch_size)
        data_index_max = int(min((index+1)*self.batch_size, len(self.input_img)))

        indexes = self.input_img[data_index_min:data_index_max]
        this_batch_size = len(indexes)

        X = np.empty((this_batch_size, self.image_size, self.image_size, 3), dtype=np.float32)
        y = np.empty((this_batch_size, self.image_size, self.image_size, self.nb_y_features), dtype=np.uint8)

        for i, sample_index in enumerate(indexes):
            X_sample = self.input_img[index * self.batch_size + i]
            y_sample = self.input_mask[index * self.batch_size + i]
            if self.augmentation is True:
                aug = transform(image=X_sample, mask=y_sample)
                img_aug = aug['image']
                mask_aug = aug['mask']
                X[i, ...] = img_aug/255
                y[i, ...] = mask_aug.reshape(self.image_size, self.image_size, self.nb_y_features).astype(np.uint8)
            else:
                pass
        return X, y
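The batching arithmetic in __getitem__ can be checked on its own. A minimal sketch, with hypothetical helper names batch_bounds and num_batches (not part of Keras), that reproduces data_index_min/data_index_max and __len__. It shows that because __len__ uses floor division, a trailing partial batch is never requested, so the min(...) clamp never takes effect. (In the question, indexes actually holds image slices rather than indices, and the loop recomputes index * self.batch_size + i, which yields the same positions.)

```python
def batch_bounds(index, batch_size, n_samples):
    """Hypothetical helper: [start, end) sample range for batch `index`,
    mirroring data_index_min / data_index_max in __getitem__."""
    start = index * batch_size
    end = min((index + 1) * batch_size, n_samples)
    return start, end


def num_batches(n_samples, batch_size):
    # Same formula as __len__: floor division drops a trailing partial batch.
    return n_samples // batch_size
```

With a made-up sample count of 42 and batch_size = 5, num_batches returns 8, the last requested batch covers samples 35 to 39, and samples 40 and 41 are never used.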

Here are my train_generator and val_generator:

train_generator = DataGenerator(input_img = train_imgs, input_mask = train_masks, image_size = img_size,
                                augmentation=True, batch_size = 5)
val_generator = DataGenerator(input_img = val_imgs, input_mask = val_masks, image_size = img_size,
                                augmentation=True, batch_size = 5)

Both of them show the same images as train_imgs:

for i in range(3):
  X_sample_temp, y_sample_temp = train_generator[2]
  fig, ax = plt.subplots(ncols=2)
  ax[0].imshow(X_sample_temp[4])
  ax[1].imshow(y_sample_temp[4,:,:,0])
  plt.show()

for i in range(3):
  X_sample_temp, y_sample_temp = val_generator[2]
  fig, ax = plt.subplots(ncols=2)
  ax[0].imshow(X_sample_temp[4])
  ax[1].imshow(y_sample_temp[4,:,:,0])
  plt.show()

I expected val_generator to produce the same images as val_imgs, but I don't know how to fix it. Any suggestions are appreciated.

In __init__ you hard-coded

    self.input_img  = train_imgs
    self.input_mask = train_masks

so every generator uses the same train_imgs and train_masks, no matter what you pass in.
You should use the arguments input_img and input_mask instead:

    self.input_img  = input_img
    self.input_mask = input_mask
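A minimal sketch of the generator with this fix applied. Assumptions: it is a plain class with __len__/__getitem__ rather than a keras.utils.Sequence subclass, so it runs without Keras (keep the Sequence base class in real code); the augmentation flag and the global transform are merged into one optional transform argument, a callable taking image= and mask= and returning a dict with 'image' and 'mask' keys, like the transform in the question; nb_y_features defaults to 1 because the question never sets it. It also fills X and y when no transform is given, since the original else: pass branch returns the uninitialised contents of np.empty.

```python
import numpy as np


class DataGenerator:
    """Sketch of the corrected generator (plain class so it runs without Keras)."""

    def __init__(self, input_img, input_mask, image_size, batch_size,
                 transform=None, nb_y_features=1):
        # Store the arguments that were passed in, not module-level globals.
        self.input_img = input_img
        self.input_mask = input_mask
        self.image_size = image_size
        self.batch_size = batch_size
        self.transform = transform          # assumed: callable(image=..., mask=...) -> dict
        self.nb_y_features = nb_y_features  # assumed default: single-channel masks

    def __len__(self):
        # Number of full batches; a trailing partial batch is dropped.
        return len(self.input_img) // self.batch_size

    def __getitem__(self, index):
        start = index * self.batch_size
        size = self.image_size
        X = np.zeros((self.batch_size, size, size, 3), dtype=np.float32)
        y = np.zeros((self.batch_size, size, size, self.nb_y_features), dtype=np.uint8)
        for i in range(self.batch_size):
            img = self.input_img[start + i]
            mask = self.input_mask[start + i]
            if self.transform is not None:
                aug = self.transform(image=img, mask=mask)
                img, mask = aug['image'], aug['mask']
            # Fill the batch whether or not a transform was applied.
            X[i] = img / 255
            y[i] = mask.reshape(size, size, self.nb_y_features)
        return X, y
```

With this version, DataGenerator(val_imgs, val_masks, ...) reads from val_imgs, so its batches differ from the training ones.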

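Also, I'm not sure whether it's a typo, but self.image_size = img_size
should be self.image_size = image_size. The current line only works because a global img_size happens to exist (you pass image_size = img_size when creating the generators).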
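The img_size line raises no error because Python looks up a name that is not local to the method in the module's globals, so every generator silently gets the global value whatever image_size you pass. A minimal sketch of the effect, assuming a global img_size = 256 (a made-up value; Buggy and Fixed are illustrative class names):

```python
img_size = 256  # module-level global, as in the question's notebook


class Buggy:
    def __init__(self, image_size):
        # BUG: `img_size` is resolved as the global; the `image_size` argument is ignored.
        self.image_size = img_size


class Fixed:
    def __init__(self, image_size):
        self.image_size = image_size  # use the parameter that was passed in
```

Buggy(128).image_size is 256, while Fixed(128).image_size is 128. Linters such as pylint report ignored parameters as unused-argument, which would flag both this typo and the hard-coded train_imgs/train_masks.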