从 Tensorflow 中的一个 TFRecord 示例中读取多个特征向量

Reading multiple feature vectors from one TFRecord example in Tensorflow

我知道如何将每个示例的一个特征存储在 tfrecord 文件中,然后使用如下方式读取它:

import tensorflow as tf
import numpy as np
import os


# This is used to parse an example from tfrecords
def parse(serialized_example):
  features = tf.parse_single_example(
    serialized_example,
    features ={
      "label": tf.FixedLenFeature([], tf.string, default_value=""),
      "feat": tf.FixedLenFeature([], tf.string, default_value="")
    })

  feat = tf.decode_raw(features['feat'], tf.float64)
  label = tf.decode_raw(features['label'], tf.int64)

  return feat, label


################# Generate data

cwd = os.getcwd()
numdata = 10
with tf.python_io.TFRecordWriter(os.path.join(cwd, 'data.tfrecords')) as writer:
    for i in range(numdata):
        feat = np.random.randn(2)
        label = np.array(np.random.randint(0,9))

        featb  = feat.tobytes()
        labelb = label.tobytes()
        import pudb.b
        example = tf.train.Example(features=tf.train.Features(
            feature={
            'feat': tf.train.Feature(bytes_list=tf.train.BytesList(value=[featb])),
            'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[labelb])),}))
        writer.write(example.SerializeToString())

        print('wrote f {}, l {}'.format(feat, label))

print('Done writing! Start reading and printing data')

################# Read data

filename = ['data.tfrecords']
dataset = tf.data.TFRecordDataset(filename).map(parse)
dataset = dataset.batch(100)
iterator = dataset.make_initializable_iterator()
feat, label = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    try:
        while True:
            example = sess.run((feat,label))
            print example
    except tf.errors.OutOfRangeError:
        pass

如果每个示例中都有多个特征向量+标签,我该怎么办。例如,在上面的代码中,如果将 feat 存储为二维数组。我仍然想做和以前一样的事情,即用每个标签一个特征来训练 DNN,但是 tfrecords 文件中的每个示例都有多个特征和多个标签。这应该很简单,但我在使用 tfrecords 解压 tensorflow 中的多个功能时遇到了问题。

首先,请注意 np.ndarray.tobytes() 将多维数组展平为列表,即

feat = np.random.randn(N, 2)
reshaped = np.reshape(feat, (N*2,))
feat.tobytes() == reshaped.tobytes()   ## True

因此,如果您有一个以 TFRecord 格式保存为字节的 N*2 数组,则必须在解析后对其进行整形。

如果这样做,您可以取消对 tf.data.Dataset 的元素进行批处理,这样每次迭代都会为您提供一个特征和一个标签。您的代码应如下所示:

# This is used to parse an example from tfrecords
def parse(serialized_example):
  features = tf.parse_single_example(
    serialized_example,
    features ={
      "label": tf.FixedLenFeature([], tf.string, default_value=""),
      "feat": tf.FixedLenFeature([], tf.string, default_value="")
    })

  feat = tf.decode_raw(features['feat'], tf.float64)    # array of shape (N*2, )
  feat = tf.reshape(feat, (N, 2))                       # array of shape (N, 2)
  label = tf.decode_raw(features['label'], tf.int64)    # array of shape (N, )

  return feat, label


################# Generate data

cwd = os.getcwd()
numdata = 10
with tf.python_io.TFRecordWriter(os.path.join(cwd, 'data.tfrecords')) as writer:
    for i in range(numdata):
        feat = np.random.randn(N, 2)
        label = np.array(np.random.randint(0,9, N))

        featb  = feat.tobytes()
        labelb = label.tobytes()
        example = tf.train.Example(features=tf.train.Features(
            feature={
            'feat': tf.train.Feature(bytes_list=tf.train.BytesList(value=[featb])),
            'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[labelb])),}))
        writer.write(example.SerializeToString())

        print('wrote f {}, l {}'.format(feat, label))

print('Done writing! Start reading and printing data')

################# Read data

filename = ['data.tfrecords']
dataset = tf.data.TFRecordDataset(filename).map(parse).apply(tf.contrib.data.unbatch())
... etc