Keras data augmentation pipeline for image segmentation dataset (image and mask with same manipulation)

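I am building a preprocessing and data augmentation pipeline for my image segmentation dataset. Keras has a powerful API for this, but I ran into the problem of reproducing the same augmentation on the image and on the segmentation mask (a second image). Both images must go through exactly the same manipulations. Is this not supported yet?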

https://www.tensorflow.org/tutorials/images/data_augmentation

Example / pseudocode

data_augmentation = tf.keras.Sequential([
    layers.experimental.preprocessing.RandomFlip(mode="horizontal_and_vertical", seed=SEED_VAL),
    layers.experimental.preprocessing.RandomRotation(factor=0.4, fill_mode="constant", fill_value=0, seed=SEED_VAL),
    layers.experimental.preprocessing.RandomZoom(height_factor=(-0.0,-0.2), fill_mode='constant', fill_value=0, seed=SEED_VAL),
])

(train_ds, test_ds), info = tfds.load('somedataset', split=['train[:80%]', 'train[80%:]'], with_info=True)

This code does not work, but it illustrates how my dream API would behave:

train_ds = train_ds.map(lambda datapoint: data_augmentation((datapoint['image'], datapoint['segmentation_mask']), training=True))

Alternative

The other approach is to write custom loading and manipulation/randomization methods, as suggested in the image segmentation tutorial (https://www.tensorflow.org/tutorials/images/segmentation).

Any tips on state-of-the-art data augmentation for this kind of dataset are much appreciated :)

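You can try external libraries for additional image augmentation. These links may help with augmenting images together with their segmentation masks: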

albumentations

https://github.com/albumentations-team/albumentations

https://albumentations.ai/docs/getting_started/mask_augmentation/

imgaug

https://github.com/aleju/imgaug

https://nbviewer.jupyter.org/github/aleju/imgaug-doc/blob/master/notebooks/B05%20-%20Augment%20Segmentation%20Maps.ipynb

Here is my own implementation, in case someone else wants to use built-in functions (tf.image API) as of December 2020 :)

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

# IMG_SIZE and IMAGE_CHANNELS are constants defined elsewhere in the script

@tf.function
def load_image(datapoint, augment=True):
    
    # resize image and mask (nearest neighbour keeps the mask labels discrete)
    input_image = tf.image.resize(datapoint['image'], (IMG_SIZE, IMG_SIZE))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (IMG_SIZE, IMG_SIZE), method='nearest')
    
    # zoom is applied before rescaling, so zoomed and unzoomed samples
    # end up with the same value range and channel count
    if augment:
        # zoom in a bit
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.central_crop(input_image, 0.75)
            input_mask = tf.image.central_crop(input_mask, 0.75)
            # resize back to the target size
            input_image = tf.image.resize(input_image, (IMG_SIZE, IMG_SIZE))
            input_mask = tf.image.resize(input_mask, (IMG_SIZE, IMG_SIZE), method='nearest')
    
    # rescale the image
    if IMAGE_CHANNELS == 1:
        input_image = tf.image.rgb_to_grayscale(input_image)
    input_image = tf.cast(input_image, tf.float32) / 255.0
    
    # remaining augmentation
    if augment:
        
        # random brightness adjustment illumination
        input_image = tf.image.random_brightness(input_image, 0.3)
        # random contrast adjustment
        input_image = tf.image.random_contrast(input_image, 0.2, 0.5)
        
        # flipping random horizontal or vertical
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_left_right(input_image)
            input_mask = tf.image.flip_left_right(input_mask)
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_up_down(input_image)
            input_mask = tf.image.flip_up_down(input_mask)

        # rotation in 15° steps (pi/12 rad per step, 0° to 165°)
        rot_factor = tf.cast(tf.random.uniform(shape=[], maxval=12, dtype=tf.int32), tf.float32)
        angle = np.pi/12*rot_factor
        input_image = tfa.image.rotate(input_image, angle)
        input_mask = tfa.image.rotate(input_mask, angle)

    return input_image, input_mask
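
A variation on the same idea, for readers who would rather not draw the random decisions by hand: the stateless `tf.image` ops take an explicit seed, so passing the same seed for the image and the mask should give both the same flip. This is a minimal sketch under that assumption; `augment_pair`, the synthetic data and the way the seeds are built are illustrative names and choices, not part of the code above.

```python
import tensorflow as tf

def augment_pair(image, mask, seed):
    """Apply identical random flips to image and mask via a shared seed.

    `seed` is a shape-[2] integer tensor, as required by the stateless ops.
    """
    # Same seed -> same flip decision for image and mask.
    image = tf.image.stateless_random_flip_left_right(image, seed)
    mask = tf.image.stateless_random_flip_left_right(mask, seed)

    # Use a different seed for the vertical flip.
    seed_ud = seed + 1
    image = tf.image.stateless_random_flip_up_down(image, seed_ud)
    mask = tf.image.stateless_random_flip_up_down(mask, seed_ud)

    # Photometric change: image only, the mask labels stay untouched.
    image = tf.image.stateless_random_brightness(image, max_delta=0.2, seed=seed)
    return image, mask


# Synthetic demo data (hypothetical shapes): 8 RGB images with 1-channel masks.
images = tf.random.uniform((8, 32, 32, 3))
masks = tf.cast(images[..., :1] > 0.5, tf.uint8)

# Pair every sample with its own seed so each one is augmented differently.
seeds = tf.data.Dataset.range(8).map(lambda i: tf.stack([i, i + 1]))
samples = tf.data.Dataset.from_tensor_slices((images, masks))
ds = tf.data.Dataset.zip((samples, seeds))
ds = ds.map(lambda pair, seed: augment_pair(pair[0], pair[1], seed))
```

Because the seed is an ordinary tensor travelling with each sample, the result is reproducible and does not depend on hidden random state.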

Fixing a common seed applies the same augmentation to the image and the mask.

def Augment(tar_shape=(512,512), seed=37):
    img = tf.keras.Input(shape=(None,None,3))
    msk = tf.keras.Input(shape=(None,None,1))

    i = tf.keras.layers.RandomFlip(seed=seed)(img)
    m = tf.keras.layers.RandomFlip(seed=seed)(msk)
    i = tf.keras.layers.RandomTranslation((-0.75, 0.75),(-0.75, 0.75),seed=seed)(i)
    m = tf.keras.layers.RandomTranslation((-0.75, 0.75),(-0.75, 0.75),seed=seed)(m)
    i = tf.keras.layers.RandomRotation((-0.35, 0.35),seed=seed)(i)
    m = tf.keras.layers.RandomRotation((-0.35, 0.35),seed=seed)(m)
    i = tf.keras.layers.RandomZoom((-0.1, 0.05),(-0.1, 0.05),seed=seed)(i)
    m = tf.keras.layers.RandomZoom((-0.1, 0.05),(-0.1, 0.05),seed=seed)(m)
    i = tf.keras.layers.RandomCrop(tar_shape[0],tar_shape[1],seed=seed)(i)
    m = tf.keras.layers.RandomCrop(tar_shape[0],tar_shape[1],seed=seed)(m)
    
    return tf.keras.Model(inputs=(img,msk), outputs=(i,m))
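
# build the model once; a separate name avoids shadowing the factory function
augment_model = Augment()

AUTOTUNE = tf.data.AUTOTUNE
ds_train = ds_train.map(lambda img, msk: augment_model((img, msk)), num_parallel_calls=AUTOTUNE)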
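
Whether two separately seeded layers really stay in sync is worth verifying on your own installation before training. One simple check, sketched below with NumPy only, is to use a copy of the first image channel as the mask: if the geometry matches, that channel still equals the mask after augmentation. `is_aligned` and the two toy pipelines are hypothetical helpers for illustration.

```python
import numpy as np

def is_aligned(augment_fn, size=32, trials=5, atol=1e-5):
    """Return True if augment_fn moves image and mask identically.

    augment_fn is a hypothetical callable (image, mask) -> (image, mask)
    working on arrays of shape (H, W, 3) and (H, W, 1).
    """
    rng = np.random.default_rng(0)
    for _ in range(trials):
        image = rng.random((size, size, 3), dtype=np.float32)
        mask = image[..., :1].copy()  # the mask is a copy of the first channel
        out_image, out_mask = augment_fn(image, mask)
        out_image, out_mask = np.asarray(out_image), np.asarray(out_mask)
        # If the geometry matches, the first image channel still equals the mask.
        if not np.allclose(out_image[..., :1], out_mask, atol=atol):
            return False
    return True


# Toy pipelines to show both outcomes.
def flip_both(image, mask):
    return image[:, ::-1], mask[:, ::-1]

def flip_image_only(image, mask):
    return image[:, ::-1], mask


print(is_aligned(flip_both))        # True
print(is_aligned(flip_image_only))  # False
```

To test the model returned by Augment(), wrap it in a small function that adds a batch dimension, calls the model and removes the batch dimension again. The check only makes sense for geometric layers that use the same interpolation on both inputs.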

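Important:

  • The functions above may change the dtype of the image and the mask from int32/uint8 to float32.
  • The output mask may also contain values other than 0/1 (e.g. 0.9987, ...). This is caused by interpolation. To avoid it, change the interpolation from bilinear to nearest.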
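
The interpolation effect is easy to reproduce in isolation. The sketch below uses `tf.image.resize` on a made-up binary mask rather than the Keras layers, purely to show the difference between the two modes; RandomRotation, RandomZoom and RandomTranslation accept an `interpolation` argument that can be set to "nearest" for the mask branch.

```python
import numpy as np
import tensorflow as tf

# Binary toy mask: a 2x2 block of ones inside a 4x4 grid, shape (H, W, 1).
mask = np.zeros((4, 4, 1), dtype=np.float32)
mask[1:3, 1:3, 0] = 1.0

# Upsample the same mask with both interpolation modes.
bilinear = tf.image.resize(mask, (8, 8), method="bilinear").numpy()
nearest = tf.image.resize(mask, (8, 8), method="nearest").numpy()

print(np.unique(bilinear))  # includes fractional values between 0 and 1
print(np.unique(nearest))   # only the original labels 0 and 1
```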

I solved this by using concat to combine the image and the mask into a single image and then applying the augmentation layers to it.

def augment_using_layers(images, mask, size=None):
    
    if size is None:
        h_s = mask.shape[0]
        w_s = mask.shape[1]
    else:
        h_s = size[0]
        w_s = size[1]
    
    def aug(height=h_s, width=w_s):

        flip = tf.keras.layers.RandomFlip(mode="horizontal")
        
        rota = tf.keras.layers.RandomRotation(0.2, fill_mode='constant')
        
        zoom = tf.keras.layers.RandomZoom(
                            height_factor=(-0.05, -0.15),
                            width_factor=(-0.05, -0.15)
                            )
        
        trans = tf.keras.layers.RandomTranslation(height_factor=(-0.1, 0.1),
                                            width_factor=(-0.1, 0.1), 
                                            fill_mode='constant')
        
        crop = tf.keras.layers.RandomCrop(h_s, w_s)
        
        layers = [flip, zoom, crop, trans, rota]
        aug_model = tf.keras.Sequential(layers)

        return aug_model
    
    aug = aug()
    
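    # replicate the 2-D mask into 3 channels and cast it to float32 so it can
    # be concatenated with the image (assumed to be a float32 RGB image)
    mask = tf.stack([mask, mask, mask], -1)
    mask = tf.cast(mask, 'float32')

    # channels 0-2: image, channels 3-5: identical copies of the mask
    images_mask = tf.concat([images, mask], -1)  
    images_mask = aug(images_mask)  
    
    image = images_mask[:,:,0:3]
    # any of the channels 3-5 holds the augmented mask
    mask = images_mask[:,:,4]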
    
    return image, tf.cast(mask, 'uint8')

You can then map your dataset:

# create dataset
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.map(lambda x: load_dataset(x, (400, 400)))

# aug. dataset
dataset_aug = dataset.map(lambda x, y: augment_using_layers(x, y, (400, 400)))
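
For comparison, the same concat trick can be sketched with plain `tf.image` ops instead of Keras layers. This is a swapped-in variant, not the code above: it only uses a flip and a crop, which move pixels without interpolating, so the mask labels survive the round trip through float32. `augment_concat` and the sample sizes are assumptions made for the example.

```python
import tensorflow as tf

def augment_concat(image, mask, crop_size):
    """Flip and crop a float image (H, W, 3) and an integer mask (H, W) together."""
    # Append the mask as a 4th channel so both share every spatial op.
    mask_f = tf.cast(mask[..., tf.newaxis], image.dtype)
    stacked = tf.concat([image, mask_f], axis=-1)

    stacked = tf.image.random_flip_left_right(stacked)
    stacked = tf.image.random_crop(stacked, size=[crop_size[0], crop_size[1], 4])

    # Split back; flips and crops only move pixels, so the cast is lossless.
    return stacked[..., :3], tf.cast(stacked[..., 3], mask.dtype)


# Synthetic demo sample (hypothetical sizes).
image = tf.random.uniform((64, 64, 3))
mask = tf.cast(image[..., 0] > 0.5, tf.uint8)
aug_image, aug_mask = augment_concat(image, mask, (48, 48))
print(aug_image.shape, aug_mask.shape)  # (48, 48, 3) (48, 48)
```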
