从 tfrecords 导入数据时,批处理后标签顺序错误
The order of labels are wrong after batching when importing data from tfrecords
我在从 tfrecords 文件导入数据时遇到问题。 tfrecords中的每个样本由一个长度为100的feautures向量和一个长度为13的one-hot label向量组成。我使用下面的代码从tfrecords导入数据,参考官方指南https://www.tensorflow.org/programmers_guide/datasets
def read_data(examples):
features = {"features": tf.FixedLenFeature([seq_len], tf.int64),
"label": tf.FixedLenFeature([category], tf.int64)}
parsed_features = tf.parse_single_example(examples, features)
return parsed_features['features'], parsed_features['label']
# get next batch of data and label
def next_batch(filename, batch_size):
data = tf.data.TFRecordDataset(filename)
data = data.map(read_data)
data = data.batch(batch_size)
iterator = data.make_one_shot_iterator()
next_data, next_label = iterator.get_next()
return next_data, next_label
with tf.Session() as sess:
filetrain = 'train.tfrecords'
next_data, next_label = next_batch(filetrain, num_example_train)
sess.run(tf.global_variables_initializer())
data = sess.run(next_data)
label = sess.run(next_label)
问题是批处理后标签顺序不对。如果我删除代码 'data = data.batch',一切正常。
我认为一个可能的原因是特征和标签是独立批处理的。所以我尝试在批处理后解析示例,但得到错误"Input serialized must be a scalar"。如果您知道如何处理这个问题,请帮助我,非常感谢!
我确定这是重复的,但我找不到其他问题,所以我会在这里回答。
您的问题是为数据和标签调用 sess.run()
两次。 每当您调用 sess.run
时,您的图表都会被评估(即,提取新的批次并 运行 通过图表直到 all 作为第一个参数传递给 run
的列表中张量的值是已知的)。
这样做,您的 data
和 label
指的是两个不同的批次(因此它们看起来是错误的)。
你需要让他们在同一个电话中:
data, label = sess.run([next_data, next_label])
我在从 tfrecords 文件导入数据时遇到问题。 tfrecords中的每个样本由一个长度为100的feautures向量和一个长度为13的one-hot label向量组成。我使用下面的代码从tfrecords导入数据,参考官方指南https://www.tensorflow.org/programmers_guide/datasets
def read_data(examples):
features = {"features": tf.FixedLenFeature([seq_len], tf.int64),
"label": tf.FixedLenFeature([category], tf.int64)}
parsed_features = tf.parse_single_example(examples, features)
return parsed_features['features'], parsed_features['label']
# get next batch of data and label
def next_batch(filename, batch_size):
data = tf.data.TFRecordDataset(filename)
data = data.map(read_data)
data = data.batch(batch_size)
iterator = data.make_one_shot_iterator()
next_data, next_label = iterator.get_next()
return next_data, next_label
with tf.Session() as sess:
filetrain = 'train.tfrecords'
next_data, next_label = next_batch(filetrain, num_example_train)
sess.run(tf.global_variables_initializer())
data = sess.run(next_data)
label = sess.run(next_label)
问题是批处理后标签顺序不对。如果我删除代码 'data = data.batch',一切正常。
我认为一个可能的原因是特征和标签是独立批处理的。所以我尝试在批处理后解析示例,但得到错误"Input serialized must be a scalar"。如果您知道如何处理这个问题,请帮助我,非常感谢!
我确定这是重复的,但我找不到其他问题,所以我会在这里回答。
您的问题是为数据和标签调用 sess.run()
两次。 每当您调用 sess.run
时,您的图表都会被评估(即,提取新的批次并 运行 通过图表直到 all 作为第一个参数传递给 run
的列表中张量的值是已知的)。
这样做,您的 data
和 label
指的是两个不同的批次(因此它们看起来是错误的)。
你需要让他们在同一个电话中:
data, label = sess.run([next_data, next_label])