使用 Tensorflow 进行文本输入

Text input with Tensorflow

我正在玩 Tensorflow 并尝试构建 RNN 语言模型。 我正在为如何读取原始文本输入文件而苦恼。

Tensorflow guide 提到了一些方法,包括:

  1. tf.data.Dataset.from_tensor_slices() - 假设我的数据在内存中可用(np.array?)
  2. tf.data.TFRecordDataset(不知道怎么用)
  3. tf.data.TextLineDataset(和2有什么区别?API页面几乎一样)

混淆2和3,只能尝试方法1,但面临以下问题:

  1. 如果我的数据太大而无法放入内存怎么办?
  2. TF 需要固定长度的填充格式,我该怎么做? - 我: 确定固定长度值(例如 30), 将每一行读入一个列表,如果列表更长,则将其截断为 30 然后30, 填充'0'使每行至少30长, 将列表附加到 numpy array/matrix ?

我相信这些都是常见的问题,tensorflow 已经提供了很多内置函数!

如果您的数据是文本文件(csv、tsv 或只是行的集合),如果您需要一些详细信息,最好的方法是使用 tf.data.TextLineDataset; tf.data.TFRecordDataset has a similar API, but it's for TFRecord binary format (checkout this nice post 处理它。

通过数据集 API 处理文本行的一个很好的例子是 TensorFlow Wide & Deep Learning Tutorial (the code is here)。这是那里使用的输入函数:

def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Generate an input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have either run data_download.py or '
      'set both arguments --train_data and --test_data.' % data_file)

  def parse_csv(value):
    print('Parsing', data_file)
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('income_bracket')
    return features, tf.equal(labels, '>50K')

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  dataset = dataset.map(parse_csv, num_parallel_calls=5)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)

  iterator = dataset.make_one_shot_iterator()
  features, labels = iterator.get_next()
  return features, labels

以下是这段代码中发生的事情:

  • tf.data.TextLineDataset(data_file) 行创建一个 Dataset 对象,分配给 dataset。它是一个包装器,而不是内容容器,所以数据从不完全读入内存。

  • Dataset API 允许 pre-process 数据,例如用shufflemapbatch等方法。请注意,API 是函数式的,这意味着当您调用 Dataset 方法时不会处理任何数据,它们只是定义在会话实际开始并评估迭代器时将使用张量执行哪些转换(见下文) .

  • 最后,dataset.make_one_shot_iterator() returns 一个迭代器张量,可以从中读取值。你可以评估featureslabels,他们将得到转换后的数据批次的值。

  • 另请注意,如果您在 GPU 上训练模型,数据将直接流式传输到设备,无需在客户端(python 脚本本身)中途停止。

根据您的特定格式,您可能不需要解析 csv 列,只需逐行读取即可。


推荐阅读Importing Data指南。