直接来自 tf.train.SequenceExample 的数据集

Dataset directly from tf.train.SequenceExample

我正在 tensorflow 中使用类似 NER 的序列标记,并决定尝试 tf.data 看看我的模型是否可以提高 IO 性能。

目前我正在使用 TFRecordWriter 预处理和保存我的 training/validation 数据,这是一个 tf.train.SequenceExample() 序列化为字符串。然后我用 tf.data.TFRecordDataset、parse/shuffle/padded_batch 加载它并继续训练,效果很好。

问题是:

在这种情况下可以使用 tf.data.Dataset.from_generator()。例如,假设您的示例看起来像以下非常简单的数据,具有两个特征(其中第二个代表顺序数据):

examples = [("foo", [1, 2, 3, 4, 5]),
            ("bar", [6, 7]),
            ("baz", [8, 9, 10])]

您可以使用以下代码将其转换为 tf.data.Dataset

def example_generator():
  for string_feature, sequence_feature in examples:
    yield string_feature, sequence_feature

dataset = tf.data.Dataset.from_generator(
    example_generator,
    output_types=(tf.string, tf.int32),
    output_shapes=([], [None]),  # A scalar and a variable-length vector.  
)