Keras 实验性 RandomFlip 和 RandomRotation 不适用于地图

Keras experimental RandomFlip and RandomRotation do not work with map

这段代码产生了一个我不明白的错误。有人可以解释一下吗?

import tensorflow as tf

def augment(img):
    data_augmentation = tf.keras.Sequential([
              tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
              tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
             ])
    img = tf.expand_dims(img, 0)
    return data_augmentation(img)

# generate 10 images 8x8 RGB
data = np.random.randint(0,255,size=(10, 8, 8, 3))
dataset = tf.data.Dataset.from_tensor_slices(data)

# and augment... -> bug
dataset = dataset.map(augment)

# note that the follwing works
for im in dataset:
   augment(im)

和一个get

ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable (e.g., `tf.Variable(lambda : tf.truncated_normal([10, 40]))`) when building functions. Please file a feature request if this restriction inconveniences you.

我试过 Google Colab 并在我的电脑上安装了 Tensorflow 2.4.1。请注意,通过调整大小或重新缩放它可以工作(就像在这个例子中一样 https://www.tensorflow.org/tutorials/images/data_augmentation 但他们没有尝试使用 RandomRotate 即使他们在循环中使用它)。

我认为您混淆了 tf.keras.layers.experimental.preprocessing.* 的目的。它们将与您的模型结合使用。这样数据增强就可以通过它自己的模型得到简化。

换句话说,这些层是您模型的一部分,不是您的数据管道(例如,您正尝试将其与 dataset.map 一起使用).如果您想将这些层与 tf.data.Dataset 一起使用,这里有一个工作示例。

import tensorflow as tf
import numpy as np

def augment(img):
    data_augmentation = tf.keras.Sequential([
              tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
              tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
             ])    
    return data_augmentation(img)

# generate 10 images 8x8 RGB
data = np.random.randint(0,255,size=(10, 8, 8, 3))

dataset = tf.data.Dataset.from_tensor_slices(data).batch(5)

for d in dataset:
  aug_d = augment(d)

答案在这里...

import numpy as np
import tensorflow as tf

data_augmentation = tf.keras.Sequential([
              tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
              tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
             ])

# generate 10 images 8x8 RGB
data = np.random.randint(0,255,size=(10, 8, 8, 3))
dataset = tf.data.Dataset.from_tensor_slices(data).batch(5)

# and augment... -> bug
dataset = dataset.map(lambda x: data_augmentation(x))

奇怪,如果我们使用 lambda 函数它可以工作,如果我们定义一个只调用 data_augmentation 的函数它会失败...