Keras 扩充不适用于 tf.data.Dataset 地图
Keras augmentation does not work with tf.data.Dataset map
我试图让预处理函数与数据集地图一起工作,但我收到以下错误(底部的完整堆栈跟踪):
ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable (e.g., `tf.Variable(lambda : tf.truncated_normal([10, 40]))`) when building functions. Please file a feature request if this restriction inconveniences you.
下面是重现问题的完整代码段。我的问题是,为什么在一个用例(仅限裁剪)中它有效,而在使用 RandomFlip 时却无效?如何解决?
import functools
import numpy as np
import tensorflow as tf
def data_gen():
for i in range(10):
x = np.random.random(size=(80, 80, 3)) * 255 # rgb image
x = x.astype('uint8')
y = np.random.random(size=(40, 40, 1)) * 255 # downsized mono image
y = y.astype('uint8')
yield x, y
def preprocess(image, label, cropped_image_size, cropped_label_size, skip_augmentations=False):
x = image
y = label
x_size = cropped_image_size
y_size = cropped_label_size
if not skip_augmentations:
x = tf.keras.layers.RandomFlip(mode="horizontal")(x)
y = tf.keras.layers.RandomFlip(mode="horizontal")(y)
x = tf.keras.layers.RandomRotation(factor=1.0, fill_mode='constant')(x)
y = tf.keras.layers.RandomRotation(factor=1.0, fill_mode='constant')(y)
x = tf.keras.layers.CenterCrop(x_size, x_size)(x)
y = tf.keras.layers.CenterCrop(y_size, y_size)(y)
return x, y
print(tf.__version__) # 2.6.0
dataset = tf.data.Dataset.from_generator(data_gen, output_signature=(
tf.TensorSpec(shape=(80, 80, 3), dtype='uint8'),
tf.TensorSpec(shape=(40, 40, 1), dtype='uint8')
))
crop_only_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=True)
train_preprocess_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=False)
# This works
crop_dataset = dataset.map(crop_only_fn)
# This fails: ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable
train_dataset = dataset.map(train_preprocess_fn)
全栈跟踪:
Traceback (most recent call last):
File "./issue_dataaug.py", line 50, in <module>
train_dataset = dataset.map(train_preprocess_fn)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1861, in map
return MapDataset(self, map_func, preserve_cardinality=True)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4985, in __init__
use_legacy_function=use_legacy_function)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4218, in __init__
self._function = fn_factory()
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3151, in get_concrete_function
*args, **kwargs)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3116, in _get_concrete_function_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3463, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3308, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 1007, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4195, in wrapped_fn
ret = wrapper_helper(*args)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4125, in wrapper_helper
ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 695, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
./issue_dataaug.py:25 preprocess *
x = tf.keras.layers.RandomFlip(mode="horizontal")(x)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/keras/layers/preprocessing/image_preprocessing.py:414 __init__ **
self._rng = make_generator(self.seed)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/keras/layers/preprocessing/image_preprocessing.py:1375 make_generator
return tf.random.Generator.from_non_deterministic_state()
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/stateful_random_ops.py:396 from_non_deterministic_state
return cls(state=state, alg=alg)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/stateful_random_ops.py:476 __init__
trainable=False)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/stateful_random_ops.py:489 _create_variable
return variables.Variable(*args, **kwargs)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:268 __call__
return cls._variable_v2_call(*args, **kwargs)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 _variable_v2_call
shape=shape)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:243 <lambda>
previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py:2675 default_variable_creator_v2
shape=shape)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:270 __call__
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1613 __init__
distribute_strategy=distribute_strategy)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1695 _init_from_args
raise ValueError("Tensor-typed variable initializers must either be "
ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable (e.g., `tf.Variable(lambda : tf.truncated_normal([10, 40]))`) when building functions. Please file a feature request if this restriction inconveniences you.
我不太确定这是否与您的问题直接相关,但是在 TF
2.7 上您的代码根本不起作用,因为所有 Keras
增强层都期望 float
值而不是 uint8
。所以,也许尝试像这样投射你的数据:
import functools
import numpy as np
import tensorflow as tf
def data_gen():
for i in range(10):
x = np.random.random(size=(80, 80, 3)) * 255 # rgb image
x = x.astype('uint8')
y = np.random.random(size=(40, 40, 1)) * 255 # downsized mono image
y = y.astype('uint8')
yield x, y
def preprocess(image, label, cropped_image_size, cropped_label_size, skip_augmentations=False):
x = tf.cast(image, dtype=tf.float32)
y = tf.cast(label, dtype=tf.float32)
x_size = cropped_image_size
y_size = cropped_label_size
if not skip_augmentations:
x = tf.keras.layers.RandomFlip(mode="horizontal")(x)
y = tf.keras.layers.RandomFlip(mode="horizontal")(y)
x = tf.keras.layers.RandomRotation(factor=1.0, fill_mode='constant')(x)
y = tf.keras.layers.RandomRotation(factor=1.0, fill_mode='constant')(y)
x = tf.keras.layers.CenterCrop(x_size, x_size)(x)
y = tf.keras.layers.CenterCrop(y_size, y_size)(y)
return tf.cast(x, dtype=tf.uint8), tf.cast(y, dtype=tf.uint8)
print(tf.__version__) # 2.6.0
dataset = tf.data.Dataset.from_generator(data_gen, output_signature=(
tf.TensorSpec(shape=(80, 80, 3), dtype=tf.uint8),
tf.TensorSpec(shape=(40, 40, 1), dtype=tf.uint8)
))
crop_only_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=True)
train_preprocess_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=False)
# This works
crop_dataset = dataset.map(crop_only_fn)
# This fails: ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable
train_dataset = dataset.map(train_preprocess_fn)
附带说明一下,Keras
增强层通常用作您计划训练的模型的一部分。您也可以使用 tf.image
函数,例如 tf.image.central_crop
、tf.image.random_flip_left_right
甚至 tfa.image.rotate
.
更新 1: 您收到了评论中提到的错误,因为如 here 所述,图层 tf.keras.layers.RandomFlip
和 tf.keras.layers.RandomRotation
仅在训练期间有效。所以尝试使用其他方法:
import functools
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
def preprocess(image, label, cropped_image_size, cropped_label_size, skip_augmentations=False):
x = tf.cast(image, dtype=tf.float32)
y = tf.cast(label, dtype=tf.float32)
x_size = cropped_image_size
y_size = cropped_label_size
if not skip_augmentations:
x = tf.image.random_flip_left_right(x)
y = tf.image.random_flip_left_right(y)
x = tfa.image.rotate(x, 90, fill_mode='constant')
y = tfa.image.rotate(y, 90, fill_mode='constant')
x = tf.keras.layers.CenterCrop(x_size, x_size)(x)
y = tf.keras.layers.CenterCrop(y_size, y_size)(y)
return tf.cast(x, dtype=tf.uint8), tf.cast(y, dtype=tf.uint8)
dataset = tf.data.Dataset.from_generator(data_gen, output_signature=(
tf.TensorSpec(shape=(80, 80, 3), dtype=tf.uint8),
tf.TensorSpec(shape=(40, 40, 1), dtype=tf.uint8)
))
crop_only_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=True)
train_preprocess_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=False)
crop_dataset = dataset.map(crop_only_fn)
train_dataset = dataset.map(train_preprocess_fn)
image, _ = next(iter(train_dataset.take(1)))
plt.imshow(image.numpy())
我排除了 tf.keras.preprocessing.image.random_rotation
,因为它现在似乎不适用于张量。
正如我评论的那样,我发现您提到的错误无法重现。但是,它只需要在 __init___
方法中初始化增强层。
ValueError: Tensor-typed variable initializers must either be wrapped
in an init_scope or callable (e.g., tf.Variable(lambda : tf.truncated_normal([10, 40]))
) when building functions. Please file
a feature request if this restriction inconveniences you.
这是完整的工作代码。
def data_gen():
for i in range(10):
x = np.random.random(size=(80, 80, 3)) * 255 # rgb image
x = x.astype('uint8')
y = np.random.random(size=(40, 40, 1)) * 255 # downsized mono image
y = y.astype('uint8')
yield x, y
class Augment(tf.keras.layers.Layer):
def __init__(self, seed=42):
super().__init__()
self.flip_a = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
self.flip_b = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
self.rot_a = tf.keras.layers.RandomRotation(factor=1.0,
fill_mode='constant', seed=seed)
self.rot_b = tf.keras.layers.RandomRotation(factor=1.0,
fill_mode='constant', seed=seed)
def call(self, inputs, labels):
x = self.flip_a(inputs)
x = self.rot_a(x)
y = self.flip_b(labels)
y = self.rot_b(y)
return x, y
def preprocess(image, label, cropped_image_size, cropped_label_size):
x = image
y = label
x_size = cropped_image_size
y_size = cropped_label_size
x = tf.cast(x, dtype=tf.float32)
y = tf.cast(y, dtype=tf.float32)
x = tf.keras.layers.CenterCrop(x_size, x_size)(x)
y = tf.keras.layers.CenterCrop(y_size, y_size)(y)
x = tf.cast(x, dtype=tf.uint8)
y = tf.cast(y, dtype=tf.uint8)
return x, y
数据
dataset = tf.data.Dataset.from_generator(data_gen, output_signature=(
tf.TensorSpec(shape=(80, 80, 3), dtype='uint8'),
tf.TensorSpec(shape=(40, 40, 1), dtype='uint8')
))
测试 1
crop_only_fn = functools.partial(preprocess,
cropped_image_size=50,
cropped_label_size=25)
# This works
crop_dataset = dataset.map(crop_only_fn)
x, y = next(iter(crop_dataset))
x.shape, y.shape
(TensorShape([50, 50, 3]), TensorShape([25, 25, 1]))
测试 2
train_preprocess_fn = functools.partial(preprocess,
cropped_image_size=50,
cropped_label_size=25)
train_dataset = dataset.map(train_preprocess_fn)
train_dataset = train_dataset.map(Augment()) # < calling now.
x, y = next(iter(train_dataset))
x.shape, y.shape
(TensorShape([50, 50, 3]), TensorShape([25, 25, 1]))
我试图让预处理函数与数据集地图一起工作,但我收到以下错误(底部的完整堆栈跟踪):
ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable (e.g., `tf.Variable(lambda : tf.truncated_normal([10, 40]))`) when building functions. Please file a feature request if this restriction inconveniences you.
下面是重现问题的完整代码段。我的问题是,为什么在一个用例(仅限裁剪)中它有效,而在使用 RandomFlip 时却无效?如何解决?
import functools
import numpy as np
import tensorflow as tf
def data_gen():
for i in range(10):
x = np.random.random(size=(80, 80, 3)) * 255 # rgb image
x = x.astype('uint8')
y = np.random.random(size=(40, 40, 1)) * 255 # downsized mono image
y = y.astype('uint8')
yield x, y
def preprocess(image, label, cropped_image_size, cropped_label_size, skip_augmentations=False):
x = image
y = label
x_size = cropped_image_size
y_size = cropped_label_size
if not skip_augmentations:
x = tf.keras.layers.RandomFlip(mode="horizontal")(x)
y = tf.keras.layers.RandomFlip(mode="horizontal")(y)
x = tf.keras.layers.RandomRotation(factor=1.0, fill_mode='constant')(x)
y = tf.keras.layers.RandomRotation(factor=1.0, fill_mode='constant')(y)
x = tf.keras.layers.CenterCrop(x_size, x_size)(x)
y = tf.keras.layers.CenterCrop(y_size, y_size)(y)
return x, y
print(tf.__version__) # 2.6.0
dataset = tf.data.Dataset.from_generator(data_gen, output_signature=(
tf.TensorSpec(shape=(80, 80, 3), dtype='uint8'),
tf.TensorSpec(shape=(40, 40, 1), dtype='uint8')
))
crop_only_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=True)
train_preprocess_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=False)
# This works
crop_dataset = dataset.map(crop_only_fn)
# This fails: ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable
train_dataset = dataset.map(train_preprocess_fn)
全栈跟踪:
Traceback (most recent call last):
File "./issue_dataaug.py", line 50, in <module>
train_dataset = dataset.map(train_preprocess_fn)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1861, in map
return MapDataset(self, map_func, preserve_cardinality=True)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4985, in __init__
use_legacy_function=use_legacy_function)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4218, in __init__
self._function = fn_factory()
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3151, in get_concrete_function
*args, **kwargs)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3116, in _get_concrete_function_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3463, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3308, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 1007, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4195, in wrapped_fn
ret = wrapper_helper(*args)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4125, in wrapper_helper
ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 695, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
./issue_dataaug.py:25 preprocess *
x = tf.keras.layers.RandomFlip(mode="horizontal")(x)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/keras/layers/preprocessing/image_preprocessing.py:414 __init__ **
self._rng = make_generator(self.seed)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/keras/layers/preprocessing/image_preprocessing.py:1375 make_generator
return tf.random.Generator.from_non_deterministic_state()
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/stateful_random_ops.py:396 from_non_deterministic_state
return cls(state=state, alg=alg)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/stateful_random_ops.py:476 __init__
trainable=False)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/stateful_random_ops.py:489 _create_variable
return variables.Variable(*args, **kwargs)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:268 __call__
return cls._variable_v2_call(*args, **kwargs)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 _variable_v2_call
shape=shape)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:243 <lambda>
previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py:2675 default_variable_creator_v2
shape=shape)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:270 __call__
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1613 __init__
distribute_strategy=distribute_strategy)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1695 _init_from_args
raise ValueError("Tensor-typed variable initializers must either be "
ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable (e.g., `tf.Variable(lambda : tf.truncated_normal([10, 40]))`) when building functions. Please file a feature request if this restriction inconveniences you.
我不太确定这是否与您的问题直接相关,但是在 TF
2.7 上您的代码根本不起作用,因为所有 Keras
增强层都期望 float
值而不是 uint8
。所以,也许尝试像这样投射你的数据:
import functools
import numpy as np
import tensorflow as tf
def data_gen():
for i in range(10):
x = np.random.random(size=(80, 80, 3)) * 255 # rgb image
x = x.astype('uint8')
y = np.random.random(size=(40, 40, 1)) * 255 # downsized mono image
y = y.astype('uint8')
yield x, y
def preprocess(image, label, cropped_image_size, cropped_label_size, skip_augmentations=False):
x = tf.cast(image, dtype=tf.float32)
y = tf.cast(label, dtype=tf.float32)
x_size = cropped_image_size
y_size = cropped_label_size
if not skip_augmentations:
x = tf.keras.layers.RandomFlip(mode="horizontal")(x)
y = tf.keras.layers.RandomFlip(mode="horizontal")(y)
x = tf.keras.layers.RandomRotation(factor=1.0, fill_mode='constant')(x)
y = tf.keras.layers.RandomRotation(factor=1.0, fill_mode='constant')(y)
x = tf.keras.layers.CenterCrop(x_size, x_size)(x)
y = tf.keras.layers.CenterCrop(y_size, y_size)(y)
return tf.cast(x, dtype=tf.uint8), tf.cast(y, dtype=tf.uint8)
print(tf.__version__) # 2.6.0
dataset = tf.data.Dataset.from_generator(data_gen, output_signature=(
tf.TensorSpec(shape=(80, 80, 3), dtype=tf.uint8),
tf.TensorSpec(shape=(40, 40, 1), dtype=tf.uint8)
))
crop_only_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=True)
train_preprocess_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=False)
# This works
crop_dataset = dataset.map(crop_only_fn)
# This fails: ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable
train_dataset = dataset.map(train_preprocess_fn)
附带说明一下,Keras
增强层通常用作您计划训练的模型的一部分。您也可以使用 tf.image
函数,例如 tf.image.central_crop
、tf.image.random_flip_left_right
甚至 tfa.image.rotate
.
更新 1: 您收到了评论中提到的错误,因为如 here 所述,图层 tf.keras.layers.RandomFlip
和 tf.keras.layers.RandomRotation
仅在训练期间有效。所以尝试使用其他方法:
import functools
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
def preprocess(image, label, cropped_image_size, cropped_label_size, skip_augmentations=False):
x = tf.cast(image, dtype=tf.float32)
y = tf.cast(label, dtype=tf.float32)
x_size = cropped_image_size
y_size = cropped_label_size
if not skip_augmentations:
x = tf.image.random_flip_left_right(x)
y = tf.image.random_flip_left_right(y)
x = tfa.image.rotate(x, 90, fill_mode='constant')
y = tfa.image.rotate(y, 90, fill_mode='constant')
x = tf.keras.layers.CenterCrop(x_size, x_size)(x)
y = tf.keras.layers.CenterCrop(y_size, y_size)(y)
return tf.cast(x, dtype=tf.uint8), tf.cast(y, dtype=tf.uint8)
dataset = tf.data.Dataset.from_generator(data_gen, output_signature=(
tf.TensorSpec(shape=(80, 80, 3), dtype=tf.uint8),
tf.TensorSpec(shape=(40, 40, 1), dtype=tf.uint8)
))
crop_only_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=True)
train_preprocess_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=False)
crop_dataset = dataset.map(crop_only_fn)
train_dataset = dataset.map(train_preprocess_fn)
image, _ = next(iter(train_dataset.take(1)))
plt.imshow(image.numpy())
我排除了 tf.keras.preprocessing.image.random_rotation
,因为它现在似乎不适用于张量。
正如我评论的那样,我发现您提到的错误无法重现。但是,它只需要在 __init___
方法中初始化增强层。
ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable (e.g.,
tf.Variable(lambda : tf.truncated_normal([10, 40]))
) when building functions. Please file a feature request if this restriction inconveniences you.
这是完整的工作代码。
def data_gen():
for i in range(10):
x = np.random.random(size=(80, 80, 3)) * 255 # rgb image
x = x.astype('uint8')
y = np.random.random(size=(40, 40, 1)) * 255 # downsized mono image
y = y.astype('uint8')
yield x, y
class Augment(tf.keras.layers.Layer):
def __init__(self, seed=42):
super().__init__()
self.flip_a = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
self.flip_b = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
self.rot_a = tf.keras.layers.RandomRotation(factor=1.0,
fill_mode='constant', seed=seed)
self.rot_b = tf.keras.layers.RandomRotation(factor=1.0,
fill_mode='constant', seed=seed)
def call(self, inputs, labels):
x = self.flip_a(inputs)
x = self.rot_a(x)
y = self.flip_b(labels)
y = self.rot_b(y)
return x, y
def preprocess(image, label, cropped_image_size, cropped_label_size):
x = image
y = label
x_size = cropped_image_size
y_size = cropped_label_size
x = tf.cast(x, dtype=tf.float32)
y = tf.cast(y, dtype=tf.float32)
x = tf.keras.layers.CenterCrop(x_size, x_size)(x)
y = tf.keras.layers.CenterCrop(y_size, y_size)(y)
x = tf.cast(x, dtype=tf.uint8)
y = tf.cast(y, dtype=tf.uint8)
return x, y
数据
dataset = tf.data.Dataset.from_generator(data_gen, output_signature=(
tf.TensorSpec(shape=(80, 80, 3), dtype='uint8'),
tf.TensorSpec(shape=(40, 40, 1), dtype='uint8')
))
测试 1
crop_only_fn = functools.partial(preprocess,
cropped_image_size=50,
cropped_label_size=25)
# This works
crop_dataset = dataset.map(crop_only_fn)
x, y = next(iter(crop_dataset))
x.shape, y.shape
(TensorShape([50, 50, 3]), TensorShape([25, 25, 1]))
测试 2
train_preprocess_fn = functools.partial(preprocess,
cropped_image_size=50,
cropped_label_size=25)
train_dataset = dataset.map(train_preprocess_fn)
train_dataset = train_dataset.map(Augment()) # < calling now.
x, y = next(iter(train_dataset))
x.shape, y.shape
(TensorShape([50, 50, 3]), TensorShape([25, 25, 1]))