TensorFlow:自定义数据增强

TensorFlow: Custom data augmentation

我正在尝试定义一个自定义的数据增强层。我的目标是以给定的概率调用现有的 tf.keras.layers.RandomZoom。

这是我做的:

class random_zoom_layer(tf.keras.layers.Layer):
    """Apply RandomZoom to the input with a given probability (failing version).

    NOTE(review): this is the version that produces the NotFoundError quoted
    below. A brand-new ``RandomZoom`` layer is constructed on every ``call``.
    The traceback points at ``stateful_uniform/RngReadAndSkip`` inside
    ``random_zoom_layer/cond``, which suggests the new layer's random-number
    state is created while ``call`` is being traced into a graph and is not
    available when that graph executes. The usual fix is to build the
    ``RandomZoom`` layer once in ``__init__`` and only *call* it here.
    """

    def __init__(self, probability=0.5, **kwargs):
        # probability: chance that a call applies the zoom.
        # **kwargs: forwarded to tf.keras.layers.Layer (e.g. name, dtype).
        super().__init__(**kwargs)
        self.probability = probability

    def call(self, x):
        # One scalar draw per call, so the whole batch is zoomed or not together.
        # In graph mode this Python `if` on a tensor is converted to a `cond`
        # (the traceback shows `random_zoom_layer/cond/then`).
        if tf.random.uniform([]) < self.probability:
            # Problem: a new RandomZoom layer is created on every call.
            return tf.keras.layers.RandomZoom(height_factor=(-0.1, 0.1), width_factor=(-0.1, 0.1), fill_mode='constant')(x)
        else:
            # Pass the input through unchanged.
            return x


# Augmentation pipeline: normalize, then apply the probabilistic zoom layer.
# `tf.keras.layers.Normalization` is the stable, non-experimental path to the
# same layer class, consistent with `tf.keras.layers.RandomZoom` used above.
# NOTE(review): Normalization is built without mean/variance and never
# `adapt()`-ed in this snippet; confirm it is adapted (or given statistics)
# before training.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.Normalization(),
    random_zoom_layer(probability=0.2),
])

但是在训练过程中,我收到这个错误:

tensorflow.python.framework.errors_impl.NotFoundError: 2 root error(s) found.
  (0) NOT_FOUND:  2 root error(s) found.
  (0) NOT_FOUND:  Resource localhost/_AnonymousVar10/class tensorflow::Var does not exist.
     [[{{node sequential_1/random_zoom_layer/cond/random_zoom/stateful_uniform/RngReadAndSkip}}]]
     [[sequential_1/random_zoom_layer/cond/then/_0/sequential_1/random_zoom_layer/cond/random_zoom/stateful_uniform_1/RngReadAndSkip/_15]]
  (1) NOT_FOUND:  Resource localhost/_AnonymousVar10/class tensorflow::Var does not exist.
     [[{{node sequential_1/random_zoom_layer/cond/random_zoom/stateful_uniform/RngReadAndSkip}}]]

非常感谢您的帮助!

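也许你可以尝试这样的写法:在 __init__ 中只创建一次 RandomZoom 层,然后在 call 中用 tf.cond 根据随机数决定是否应用它,而不是在每次调用 call 时都新建一个层。每次新建的层会在图追踪期间创建自己的随机数生成器状态,这很可能就是报错中"资源不存在"的原因: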

import tensorflow as tf


class random_zoom_layer(tf.keras.layers.Layer):
    """Apply ``tf.keras.layers.RandomZoom`` with probability ``probability``.

    The wrapped ``RandomZoom`` layer is built once in ``__init__``. The choice
    between zooming and passing the input through is made with ``tf.cond``,
    so it stays a graph op when Keras traces ``call``.

    A single scalar is drawn per call, so the whole batch is either zoomed
    or left untouched together (this is not a per-sample probability).

    Args:
        probability: Chance that a call applies the zoom (compared against
            a uniform draw in [0, 1)).
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments such as
            ``name`` or ``dtype``.
    """

    def __init__(self, probability=0.5, **kwargs):
        super().__init__(**kwargs)
        self.probability = probability
        # Created once so its internal random state is owned by this layer
        # rather than re-created on every call.
        self.layer = tf.keras.layers.RandomZoom(
            height_factor=(-0.1, 0.1),
            width_factor=(-0.1, 0.1),
            fill_mode='constant',
        )

    def call(self, x, training=None):
        """Return ``x`` zoomed with probability ``self.probability``, else ``x``.

        ``training`` is forwarded to the inner ``RandomZoom`` so that its own
        training/inference handling applies instead of being bypassed.
        """
        apply_zoom = tf.less(tf.random.uniform([]), self.probability)
        return tf.cond(
            apply_zoom,
            lambda: self.layer(x, training=training),
            lambda: x,
        )

    def get_config(self):
        """Include ``probability`` so the layer can be saved and re-created."""
        config = super().get_config()
        config.update({"probability": self.probability})
        return config


# Augmentation pipeline: normalize, then apply the probabilistic zoom layer.
# `tf.keras.layers.Normalization` is the stable, non-experimental path to the
# same layer class, consistent with `tf.keras.layers.RandomZoom` used above.
# NOTE(review): Normalization is built without mean/variance and never
# `adapt()`-ed in this snippet; confirm it is adapted (or given statistics)
# before training.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.Normalization(),
    random_zoom_layer(probability=0.2),
])

# Smoke test: run a random 4-channel batch through the pipeline.
print(data_augmentation(tf.random.normal((1, 32, 32, 4))))
import matplotlib.pyplot as plt

# Visual check: show one image next to its augmented version.
# Augment the *same* tensor that is displayed so the two panels are actually
# comparable, and draw them in separate axes so neither hides the other.
# With probability=0.2 the zoom is often not applied on a given call.
image = tf.random.normal((1, 32, 32, 4))
augmented = data_augmentation(image)


def _to_displayable(batch):
    """Drop the leading batch axis and min-max rescale values to [0, 1].

    ``imshow`` clips float RGB(A) data outside [0, 1], while
    ``tf.random.normal`` values are unbounded, so rescale for a faithful view.
    """
    img = tf.squeeze(batch, axis=0)
    lo = tf.reduce_min(img)
    hi = tf.reduce_max(img)
    # Guard against division by zero for a constant image.
    return ((img - lo) / tf.maximum(hi - lo, 1e-12)).numpy()


fig, (ax_in, ax_out) = plt.subplots(1, 2)
ax_in.imshow(_to_displayable(image))
ax_in.set_title("input")
ax_out.imshow(_to_displayable(augmented))
ax_out.set_title("augmented")
plt.show()