在 Tensorflow 的数据集中使用 flat_map API
Using flat_map in Tensorflow's Dataset API
我正在使用数据集API,读取数据如下:
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))
我现在想使用 flat_map
来过滤掉一些样本,同时在训练时动态复制一些其他样本(这是导致我的模型的输入函数)。
flat_map
的 API 需要 return 一个 Dataset
对象,但我不知道如何创建它。这是我想要实现的伪代码实现:
def flat_map_impl(tf_example):
# Pseudo-code:
# if tf_example["a"] == 1:
# return []
# else:
# return [tf_example, tf_example]
dataset.flat_map(flat_map_impl)
如何在 flat_map
函数中实现它?
注意:我想可以通过 py_func
实现这一点,但我宁愿避免这种情况。
也许创建 tf.data.Dataset
when returning from a Dataset.flat_map()
is to use Dataset.from_tensors()
or Dataset.from_tensor_slices()
. In this case, because tf_example
is a dictionary, it is probably easiest to use a combination of Dataset.from_tensors()
and Dataset.repeat(count)
, where a conditional expression 最常见的方法是计算 count
:
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))
def flat_map_impl(tf_example):
count = tf.cond(tf.equal(tf_example["a"], 1)),
lambda: tf.constant(0, dtype=tf.int64),
lambda: tf.constant(2, dtype=tf.int64))
return tf.data.Dataset.from_tensors(tf_example).repeat(count)
dataset = dataset.flat_map(flat_map_impl)
我正在使用数据集API,读取数据如下:
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))
我现在想使用 flat_map
来过滤掉一些样本,同时在训练时动态复制一些其他样本(这是导致我的模型的输入函数)。
flat_map
的 API 需要 return 一个 Dataset
对象,但我不知道如何创建它。这是我想要实现的伪代码实现:
def flat_map_impl(tf_example):
# Pseudo-code:
# if tf_example["a"] == 1:
# return []
# else:
# return [tf_example, tf_example]
dataset.flat_map(flat_map_impl)
如何在 flat_map
函数中实现它?
注意:我想可以通过 py_func
实现这一点,但我宁愿避免这种情况。
也许创建 tf.data.Dataset
when returning from a Dataset.flat_map()
is to use Dataset.from_tensors()
or Dataset.from_tensor_slices()
. In this case, because tf_example
is a dictionary, it is probably easiest to use a combination of Dataset.from_tensors()
and Dataset.repeat(count)
, where a conditional expression 最常见的方法是计算 count
:
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))
def flat_map_impl(tf_example):
count = tf.cond(tf.equal(tf_example["a"], 1)),
lambda: tf.constant(0, dtype=tf.int64),
lambda: tf.constant(2, dtype=tf.int64))
return tf.data.Dataset.from_tensors(tf_example).repeat(count)
dataset = dataset.flat_map(flat_map_impl)