使用 python 函数和 tf.Dataset API 的数据扩充
Data augmentation using python function with tf.Dataset API
我正在寻找动态读取的图像并为我的图像分割问题应用数据增强。到目前为止,我所看到的最好的方法是 tf.Dataset
API 和 .map
函数。
但是,从我看到的示例来看,我认为我必须使我的所有函数适应 tensorflow 样式(使用 tf.cond
而不是 if
,等等)。问题是我需要应用一些非常复杂的函数。因此我正在考虑像这样使用 tf.py_func
:
import tensorflow as tf
img_path_list = [...] # List of paths to read
mask_path_list = [...] # List of paths to read
dataset = tf.data.Dataset.from_tensor_slices((img_path_list, mask_path_list))
def parse_function(img_path_list, mask_path_list):
'''load image and mask from paths'''
return img, mask
def data_augmentation(img, mask):
'''process data with complex logic'''
return aug_img, aug_mask
# py_func wrappers
def parse_function_wrapper(img_path_list, mask_path_list):
return tf.py_func(func=parse_function,
inp=(img_path_list, mask_path_list),
Tout=(tf.float32, tf.float32))
def data_augmentation_wrapper(img, mask):
return tf.py_func(func=data_augmentation,
inp=(img, mask),
Tout=(tf.float32, tf.float32))
# Maps py_funcs to dataset
dataset = dataset.map(parse_function_wrapper,
num_parallel_calls=4)
dataset = dataset.map(data_augmentation_wrapper,
num_parallel_calls=4)
dataset = dataset.batch(32)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()
但是,从 this answer 看来,使用 py_func
进行并行处理似乎不起作用。还有其他选择吗?
py_func 受到 python GIL 的限制,因此您不会在那里获得太多并行性。你最好的选择是在 tensorflow 中编写你的数据扩充(或者预先计算它并将它序列化到磁盘)。
如果你确实想在tensorflow中编写它,你可以尝试使用tf.contrib.autograph将简单的python ifs和for循环转换为tf.conds和tf.while_loops,这可能会大大简化您的代码。
我正在寻找动态读取的图像并为我的图像分割问题应用数据增强。到目前为止,我所看到的最好的方法是 tf.Dataset
API 和 .map
函数。
但是,从我看到的示例来看,我认为我必须使我的所有函数适应 tensorflow 样式(使用 tf.cond
而不是 if
,等等)。问题是我需要应用一些非常复杂的函数。因此我正在考虑像这样使用 tf.py_func
:
import tensorflow as tf
img_path_list = [...] # List of paths to read
mask_path_list = [...] # List of paths to read
dataset = tf.data.Dataset.from_tensor_slices((img_path_list, mask_path_list))
def parse_function(img_path_list, mask_path_list):
'''load image and mask from paths'''
return img, mask
def data_augmentation(img, mask):
'''process data with complex logic'''
return aug_img, aug_mask
# py_func wrappers
def parse_function_wrapper(img_path_list, mask_path_list):
return tf.py_func(func=parse_function,
inp=(img_path_list, mask_path_list),
Tout=(tf.float32, tf.float32))
def data_augmentation_wrapper(img, mask):
return tf.py_func(func=data_augmentation,
inp=(img, mask),
Tout=(tf.float32, tf.float32))
# Maps py_funcs to dataset
dataset = dataset.map(parse_function_wrapper,
num_parallel_calls=4)
dataset = dataset.map(data_augmentation_wrapper,
num_parallel_calls=4)
dataset = dataset.batch(32)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()
但是,从 this answer 看来,使用 py_func
进行并行处理似乎不起作用。还有其他选择吗?
py_func 受到 python GIL 的限制,因此您不会在那里获得太多并行性。你最好的选择是在 tensorflow 中编写你的数据扩充(或者预先计算它并将它序列化到磁盘)。
如果你确实想在tensorflow中编写它,你可以尝试使用tf.contrib.autograph将简单的python ifs和for循环转换为tf.conds和tf.while_loops,这可能会大大简化您的代码。