如何读写二维数组的tfrecord文件
How to read and write tfrecord files of 2d array
我想制作一个大小为 (n, 3) 的二维数组 tfrecord file
,然后读取它。
我写的代码tfrecord file
是
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
example = tf.train.Example(
features=tf.train.Features(
feature={
'arry_x':_float_feature(array[:,0]),
'arry_y':_float_feature(array[:,1]),
'arry_z':_float_feature(array[:,2])}
)
)
with tf.compat.v1.python_io.TFRecordWriter(file_name) as writer:
writer.write(example.SerializeToString())
我尝试用 TFRecordReader
读取文件
def get_tfrecord_feature():
return{
'arry_x': tf.compat.v1.io.FixedLenFeature([], tf.float32),
'arry_y': tf.compat.v1.io.FixedLenFeature([], tf.float32),
'arry_z': tf.compat.v1.io.FixedLenFeature([], tf.float32)
}
filenames = [file_name, file_name2, ...]
file_name_queue = tf.train.string_input_producer(filenames)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(file_name_queue)
data = tf.compat.v1.io.parse_single_example(serialized_example, features=get_tfrecord_feature())
x = data['arry_x']
y = data['arry_y']
z = data['arry_z']
x, y, z = tf.train.batch([x, y, z], batch_size=1)
我用tf.Session检查了代码
with tf.compat.v1.Session() as sess:
print(sess.run(x))
代码运行没有错误,但会话没有打印任何值。
我认为阅读 tfrecord file
的方式是错误的。
谁能帮帮我?
我认为您应该在解析 tf 记录时将列表长度(在您的情况下为 array.shape[0] 添加到功能定义中,如下所示。
def get_tfrecord_feature():
return{
'arry_x': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
'arry_y': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
'arry_z': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32)
}
如果 FixedLenFeature 只有一个元素,您可以将形状保留为 []。
https://tensorflow.org/versions/r1.15/api_docs/python/tf/io/FixedLenFeature
感谢donglinjy的指点,我在这里修正了我的代码
def get_tfrecord_feature():
return{
'arry_x': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
'arry_y': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
'arry_z': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32)
}
在这里。
with tf.compat.v1.Session() as sess:
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(coord=coord)
print(sess.run(x))
现在有效。
我想制作一个大小为 (n, 3) 的二维数组 tfrecord file
,然后读取它。
我写的代码tfrecord file
是
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
example = tf.train.Example(
features=tf.train.Features(
feature={
'arry_x':_float_feature(array[:,0]),
'arry_y':_float_feature(array[:,1]),
'arry_z':_float_feature(array[:,2])}
)
)
with tf.compat.v1.python_io.TFRecordWriter(file_name) as writer:
writer.write(example.SerializeToString())
我尝试用 TFRecordReader
def get_tfrecord_feature():
return{
'arry_x': tf.compat.v1.io.FixedLenFeature([], tf.float32),
'arry_y': tf.compat.v1.io.FixedLenFeature([], tf.float32),
'arry_z': tf.compat.v1.io.FixedLenFeature([], tf.float32)
}
filenames = [file_name, file_name2, ...]
file_name_queue = tf.train.string_input_producer(filenames)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(file_name_queue)
data = tf.compat.v1.io.parse_single_example(serialized_example, features=get_tfrecord_feature())
x = data['arry_x']
y = data['arry_y']
z = data['arry_z']
x, y, z = tf.train.batch([x, y, z], batch_size=1)
我用tf.Session检查了代码
with tf.compat.v1.Session() as sess:
print(sess.run(x))
代码运行没有错误,但会话没有打印任何值。
我认为阅读 tfrecord file
的方式是错误的。
谁能帮帮我?
我认为您应该在解析 tf 记录时将列表长度(在您的情况下为 array.shape[0] 添加到功能定义中,如下所示。
def get_tfrecord_feature():
return{
'arry_x': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
'arry_y': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
'arry_z': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32)
}
如果 FixedLenFeature 只有一个元素,您可以将形状保留为 []。 https://tensorflow.org/versions/r1.15/api_docs/python/tf/io/FixedLenFeature
感谢donglinjy的指点,我在这里修正了我的代码
def get_tfrecord_feature():
return{
'arry_x': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
'arry_y': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
'arry_z': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32)
}
在这里。
with tf.compat.v1.Session() as sess:
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(coord=coord)
print(sess.run(x))
现在有效。