Dataset.map() 中的 Tensorflow 解析和重塑浮点列表
Tensorflow parsing and reshaping float list in Dataset.map()
我正在尝试将 3D 浮动列表写入 TFrecord,所以我成功地通过先将其展平来写入它,我解析了它,但在重塑它时引发了错误。
错误:ValueError: Shapes () and (8,) are not compatible
这就是我编写 TFrecord 文件的方式
def _floats_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value.flatten()))
def write(output_path, data_rgb, data_depth, data_decalib):
with tf.python_io.TFRecordWriter(output_path) as writer:
feature = {'data_rgb': _floats_feature(data_rgb),
'data_depth': _floats_feature(data_depth),
'data_decalib': _floats_feature(data_decalib)}
sample = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(sample.SerializeToString())
这就是我读取 TFrecord 文件的方式
def get_batches(date, drives, batch_size=1):
"""
Create a generator that returns batches of tuples
rgb, depth and calibration
:param date: date of the drive
:param drives: array of the drive_numbers within the drive date
:return: batch generator
"""
filenames = get_paths_drives(date, drives)
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(input_parser) # Parse the record into tensors.
dataset = dataset.repeat() # Repeat the input indefinitely.
dataset = dataset.batch(batch_size)
return dataset
config = configparser.ConfigParser()
config.read(path_helpers.get_config_file_path())
IMAGE_WIDTH = int(config['DATA_INFORMATION']['IMAGE_WIDTH'])
IMAGE_HEIGHT = int(config['DATA_INFORMATION']['IMAGE_HEIGHT'])
INPUT_RGB_SHAPE = [IMAGE_HEIGHT, IMAGE_WIDTH, 3]
INPUT_DEPTH_SHAPE = [IMAGE_HEIGHT, IMAGE_WIDTH, 1]
LABEL_CALIB_SHAPE = [8]
def input_parser(example_proto):
features = {'data_rgb': tf.FixedLenFeature([], tf.float32),
'data_depth': tf.FixedLenFeature([], tf.float32),
'data_decalib': tf.FixedLenFeature([], tf.float32)}
parsed_features = tf.parse_single_example(example_proto, features)
data_rgb = parsed_features['data_rgb']
data_rgb.set_shape(np.prod(INPUT_RGB_SHAPE))
img_rgb = tf.reshape(data_rgb, INPUT_RGB_SHAPE)
data_depth = parsed_features['data_depth']
data_depth.set_shape(np.prod(INPUT_DEPTH_SHAPE))
img_depth = tf.reshape(data_depth, INPUT_DEPTH_SHAPE)
data_decalib = parsed_features['data_decalib']
data_decalib.set_shape(LABEL_CALIB_SHAPE)
return img_rgb, img_depth, data_decalib
原来我需要按如下方式更改我的输入解析器:
def input_parser(example_proto):
features = {'data_rgb': tf.FixedLenFeature(shape=[np.prod(INPUT_RGB_SHAPE)], dtype=tf.float32),
'data_depth': tf.FixedLenFeature(shape=[np.prod(INPUT_DEPTH_SHAPE)], dtype=tf.float32),
'data_decalib': tf.FixedLenFeature(shape=LABEL_CALIB_SHAPE, dtype=tf.float32)}
parsed_features = tf.parse_single_example(example_proto, features)
如 tf.FixedLenFeature(现在 tf.io.FixedLenFeature)的文档所述。第一个参数是 shape
,我将其设置为 []
,因此出现错误 ValueError: Shapes () and (8,) are not compatible
。将其设置为实际值即可。
我正在尝试将 3D 浮动列表写入 TFrecord,所以我成功地通过先将其展平来写入它,我解析了它,但在重塑它时引发了错误。
错误:ValueError: Shapes () and (8,) are not compatible
这就是我编写 TFrecord 文件的方式
def _floats_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value.flatten()))
def write(output_path, data_rgb, data_depth, data_decalib):
with tf.python_io.TFRecordWriter(output_path) as writer:
feature = {'data_rgb': _floats_feature(data_rgb),
'data_depth': _floats_feature(data_depth),
'data_decalib': _floats_feature(data_decalib)}
sample = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(sample.SerializeToString())
这就是我读取 TFrecord 文件的方式
def get_batches(date, drives, batch_size=1):
"""
Create a generator that returns batches of tuples
rgb, depth and calibration
:param date: date of the drive
:param drives: array of the drive_numbers within the drive date
:return: batch generator
"""
filenames = get_paths_drives(date, drives)
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(input_parser) # Parse the record into tensors.
dataset = dataset.repeat() # Repeat the input indefinitely.
dataset = dataset.batch(batch_size)
return dataset
config = configparser.ConfigParser()
config.read(path_helpers.get_config_file_path())
IMAGE_WIDTH = int(config['DATA_INFORMATION']['IMAGE_WIDTH'])
IMAGE_HEIGHT = int(config['DATA_INFORMATION']['IMAGE_HEIGHT'])
INPUT_RGB_SHAPE = [IMAGE_HEIGHT, IMAGE_WIDTH, 3]
INPUT_DEPTH_SHAPE = [IMAGE_HEIGHT, IMAGE_WIDTH, 1]
LABEL_CALIB_SHAPE = [8]
def input_parser(example_proto):
features = {'data_rgb': tf.FixedLenFeature([], tf.float32),
'data_depth': tf.FixedLenFeature([], tf.float32),
'data_decalib': tf.FixedLenFeature([], tf.float32)}
parsed_features = tf.parse_single_example(example_proto, features)
data_rgb = parsed_features['data_rgb']
data_rgb.set_shape(np.prod(INPUT_RGB_SHAPE))
img_rgb = tf.reshape(data_rgb, INPUT_RGB_SHAPE)
data_depth = parsed_features['data_depth']
data_depth.set_shape(np.prod(INPUT_DEPTH_SHAPE))
img_depth = tf.reshape(data_depth, INPUT_DEPTH_SHAPE)
data_decalib = parsed_features['data_decalib']
data_decalib.set_shape(LABEL_CALIB_SHAPE)
return img_rgb, img_depth, data_decalib
原来我需要按如下方式更改我的输入解析器:
def input_parser(example_proto):
features = {'data_rgb': tf.FixedLenFeature(shape=[np.prod(INPUT_RGB_SHAPE)], dtype=tf.float32),
'data_depth': tf.FixedLenFeature(shape=[np.prod(INPUT_DEPTH_SHAPE)], dtype=tf.float32),
'data_decalib': tf.FixedLenFeature(shape=LABEL_CALIB_SHAPE, dtype=tf.float32)}
parsed_features = tf.parse_single_example(example_proto, features)
如 tf.FixedLenFeature(现在 tf.io.FixedLenFeature)的文档所述。第一个参数是 shape
,我将其设置为 []
,因此出现错误 ValueError: Shapes () and (8,) are not compatible
。将其设置为实际值即可。