在 Tensorflow 的数据集中使用 flat_map API

Using flat_map in Tensorflow's Dataset API

我正在使用数据集API,读取数据如下:

dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))

我现在想使用 flat_map 来过滤掉一些样本,同时在训练时动态复制一些其他样本(这是导致我的模型的输入函数)。

flat_map 的 API 需要 return 一个 Dataset 对象,但我不知道如何创建它。这是我想要实现的伪代码实现:

def flat_map_impl(tf_example):
    # Pseudo-code:
    # if tf_example["a"] == 1:
    #     return []
    # else:
    #     return [tf_example, tf_example]

dataset.flat_map(flat_map_impl)

如何在 flat_map 函数中实现它?

注意:我想可以通过 py_func 实现这一点,但我宁愿避免这种情况。

也许创建 tf.data.Dataset when returning from a Dataset.flat_map() is to use Dataset.from_tensors() or Dataset.from_tensor_slices(). In this case, because tf_example is a dictionary, it is probably easiest to use a combination of Dataset.from_tensors() and Dataset.repeat(count), where a conditional expression 最常见的方法是计算 count:

dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))

def flat_map_impl(tf_example):
  count = tf.cond(tf.equal(tf_example["a"], 1)),
                  lambda: tf.constant(0, dtype=tf.int64),
                  lambda: tf.constant(2, dtype=tf.int64))

  return tf.data.Dataset.from_tensors(tf_example).repeat(count)

dataset = dataset.flat_map(flat_map_impl)