tf.contrib.data.Dataset seems not to support SparseTensor
I generated a PASCAL VOC 2007 tfrecords file using the code in the TensorFlow Object Detection API, and I read from it with the tf.contrib.data.Dataset API. I first tried an approach without the tf.contrib.data.Dataset API, and that code runs without any error, but after switching to the tf.contrib.data.Dataset API it no longer works.
The code without tf.contrib.data.Dataset:
import tensorflow as tf

if __name__ == '__main__':
    slim_example_decoder = tf.contrib.slim.tfexample_decoder
    features = {"image/height": tf.FixedLenFeature((), tf.int64, default_value=1),
                "image/width": tf.FixedLenFeature((), tf.int64, default_value=1),
                "image/filename": tf.FixedLenFeature((), tf.string, default_value=""),
                "image/source_id": tf.FixedLenFeature((), tf.string, default_value=""),
                "image/key/sha256": tf.FixedLenFeature((), tf.string, default_value=""),
                "image/encoded": tf.FixedLenFeature((), tf.string, default_value=""),
                "image/format": tf.FixedLenFeature((), tf.string, default_value="jpeg"),
                "image/object/bbox/xmin": tf.VarLenFeature(tf.float32),
                "image/object/bbox/xmax": tf.VarLenFeature(tf.float32),
                "image/object/bbox/ymin": tf.VarLenFeature(tf.float32),
                "image/object/bbox/ymax": tf.VarLenFeature(tf.float32),
                "image/object/class/text": tf.VarLenFeature(tf.string),
                "image/object/class/label": tf.VarLenFeature(tf.int64),
                "image/object/difficult": tf.VarLenFeature(tf.int64),
                "image/object/truncated": tf.VarLenFeature(tf.int64),
                "image/object/view": tf.VarLenFeature(tf.int64)}
    items_to_handlers = {
        'image': slim_example_decoder.Image(
            image_key='image/encoded', format_key='image/format', channels=3),
        'source_id': (
            slim_example_decoder.Tensor('image/source_id')),
        'key': (
            slim_example_decoder.Tensor('image/key/sha256')),
        'filename': (
            slim_example_decoder.Tensor('image/filename')),
        # Object boxes and classes.
        'groundtruth_boxes': (
            slim_example_decoder.BoundingBox(
                ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')),
        'groundtruth_classes': (
            slim_example_decoder.Tensor('image/object/class/label')),
        'groundtruth_difficult': (
            slim_example_decoder.Tensor('image/object/difficult')),
        'image/object/truncated': (
            slim_example_decoder.Tensor('image/object/truncated')),
        'image/object/view': (
            slim_example_decoder.Tensor('image/object/view')),
    }
    decoder = slim_example_decoder.TFExampleDecoder(features, items_to_handlers)
    keys = decoder.list_items()
    for example in tf.python_io.tf_record_iterator(
            "/home/aurora/workspaces/data/tfrecords_data/oxford_pet/pet_train.record"):
        serialized_example = tf.reshape(example, shape=[])
        tensors = decoder.decode(serialized_example, items=keys)
        tensor_dict = dict(zip(keys, tensors))
        tensor_dict['image'].set_shape([None, None, 3])
        print(tensor_dict)
The output of the above code is:
{'image': <tf.Tensor 'case/If_1/Merge:0' shape=(?, ?, 3) dtype=uint8>,
'filename': <tf.Tensor 'Reshape_2:0' shape=() dtype=string>,
'groundtruth_boxes': <tf.Tensor 'transpose:0' shape=(?, 4) dtype=float32>,
'key': <tf.Tensor 'Reshape_5:0' shape=() dtype=string>,
'image/object/truncated': <tf.Tensor 'SparseToDense:0' shape=(?,) dtype=int64>,
'groundtruth_classes': <tf.Tensor 'SparseToDense_2:0' shape=(?,) dtype=int64>,
'image/object/view': <tf.Tensor 'SparseToDense_1:0' shape=(?,) dtype=int64>,
'source_id': <tf.Tensor 'Reshape_6:0' shape=() dtype=string>,
'groundtruth_difficult': <tf.Tensor 'SparseToDense_3:0' shape=(?,) dtype=int64>}
...
The code with tf.contrib.data.Dataset:
import tensorflow as tf
from tensorflow.contrib.data import Iterator

slim_example_decoder = tf.contrib.slim.tfexample_decoder

flags = tf.app.flags
flags.DEFINE_string('data_dir',
                    '/home/aurora/workspaces/data/tfrecords_data/voc_dataset/trainval.tfrecords',
                    'tfrecords file output path')
flags.DEFINE_integer('batch_size', 32, 'training batch size')
flags.DEFINE_integer('capacity', 10000, 'training batch size')
FLAGS = flags.FLAGS

features = {"image/height": tf.FixedLenFeature((), tf.int64, default_value=1),
            "image/width": tf.FixedLenFeature((), tf.int64, default_value=1),
            "image/filename": tf.FixedLenFeature((), tf.string, default_value=""),
            "image/source_id": tf.FixedLenFeature((), tf.string, default_value=""),
            "image/key/sha256": tf.FixedLenFeature((), tf.string, default_value=""),
            "image/encoded": tf.FixedLenFeature((), tf.string, default_value=""),
            "image/format": tf.FixedLenFeature((), tf.string, default_value="jpeg"),
            "image/object/bbox/xmin": tf.VarLenFeature(tf.float32),
            "image/object/bbox/xmax": tf.VarLenFeature(tf.float32),
            "image/object/bbox/ymin": tf.VarLenFeature(tf.float32),
            "image/object/bbox/ymax": tf.VarLenFeature(tf.float32),
            "image/object/class/text": tf.VarLenFeature(tf.string),
            "image/object/class/label": tf.VarLenFeature(tf.int64),
            "image/object/difficult": tf.VarLenFeature(tf.int64),
            "image/object/truncated": tf.VarLenFeature(tf.int64),
            "image/object/view": tf.VarLenFeature(tf.int64)}
items_to_handlers = {
    'image': slim_example_decoder.Image(
        image_key='image/encoded', format_key='image/format', channels=3),
    'source_id': (
        slim_example_decoder.Tensor('image/source_id')),
    'key': (
        slim_example_decoder.Tensor('image/key/sha256')),
    'filename': (
        slim_example_decoder.Tensor('image/filename')),
    # Object boxes and classes.
    'groundtruth_boxes': (
        slim_example_decoder.BoundingBox(
            ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')),
    'groundtruth_classes': (
        slim_example_decoder.Tensor('image/object/class/label')),
    'groundtruth_difficult': (
        slim_example_decoder.Tensor('image/object/difficult')),
    'image/object/truncated': (
        slim_example_decoder.Tensor('image/object/truncated')),
    'image/object/view': (
        slim_example_decoder.Tensor('image/object/view')),
}
decoder = slim_example_decoder.TFExampleDecoder(features, items_to_handlers)
keys = decoder.list_items()

def _parse_function_train(example):
    serialized_example = tf.reshape(example, shape=[])
    tensors = decoder.decode(serialized_example, items=keys)
    tensor_dict = dict(zip(keys, tensors))
    tensor_dict['image'].set_shape([None, None, 3])
    print(tensor_dict)
    return tensor_dict

if __name__ == '__main__':
    train_dataset = tf.contrib.data.TFRecordDataset(FLAGS.data_dir)
    train_dataset = train_dataset.map(_parse_function_train)
    train_dataset = train_dataset.repeat(1)
    train_dataset = train_dataset.batch(FLAGS.batch_size)
    train_dataset = train_dataset.shuffle(buffer_size=FLAGS.capacity)
    iterator = Iterator.from_structure(train_dataset.output_types,
                                       train_dataset.output_shapes)
    next_element = iterator.get_next()
    training_init_op = iterator.make_initializer(train_dataset)
    sess = tf.Session()
    sess.run(training_init_op)
    counter = 0
    while True:
        try:
            sess.run(next_element)
            counter += 1
        except tf.errors.OutOfRangeError:
            print('End of training data in step %d' % counter)
            break
Running the above code reports the following error:
2017-10-09 23:41:43.488439: W tensorflow/core/framework/op_kernel.cc:1192] Invalid argument: Name: <unknown>, Key: image/object/view, Index: 0. Data types don't match. Expected type: int64
2017-10-09 23:41:43.488554: W tensorflow/core/framework/op_kernel.cc:1192] Invalid argument: Name: <unknown>, Key: image/object/view, Index: 0. Data types don't match. Expected type: int64
[[Node: ParseSingleExample/ParseExample/ParseExample = ParseExample[Ndense=7, Nsparse=9, Tdense=[DT_STRING, DT_STRING, DT_STRING, DT_INT64, DT_STRING, DT_STRING, DT_INT64], dense_shapes=[[], [], [], [], [], [], []], sparse_types=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64, DT_STRING, DT_INT64, DT_INT64, DT_INT64]](ParseSingleExample/ExpandDims, ParseSingleExample/ParseExample/ParseExample/names, ParseSingleExample/ParseExample/ParseExample/sparse_keys_0, ParseSingleExample/ParseExample/ParseExample/sparse_keys_1, ParseSingleExample/ParseExample/ParseExample/sparse_keys_2, ParseSingleExample/ParseExample/ParseExample/sparse_keys_3, ParseSingleExample/ParseExample/ParseExample/sparse_keys_4, ParseSingleExample/ParseExample/ParseExample/sparse_keys_5, ParseSingleExample/ParseExample/ParseExample/sparse_keys_6, ParseSingleExample/ParseExample/ParseExample/sparse_keys_7, ParseSingleExample/ParseExample/ParseExample/sparse_keys_8, ParseSingleExample/ParseExample/ParseExample/dense_keys_0, ParseSingleExample/ParseExample/ParseExample/dense_keys_1, ParseSingleExample/ParseExample/ParseExample/dense_keys_2, ParseSingleExample/ParseExample/ParseExample/dense_keys_3, ParseSingleExample/ParseExample/ParseExample/dense_keys_4, ParseSingleExample/ParseExample/ParseExample/dense_keys_5, ParseSingleExample/ParseExample/ParseExample/dense_keys_6, ParseSingleExample/ParseExample/Reshape, ParseSingleExample/ParseExample/Reshape_1, ParseSingleExample/ParseExample/Reshape_2, ParseSingleExample/ParseExample/Reshape_3, ParseSingleExample/ParseExample/Reshape_4, ParseSingleExample/ParseExample/Reshape_5, ParseSingleExample/ParseExample/Reshape_6)]]
Traceback (most recent call last):
File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1327, in _do_call
return fn(*args)
File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1306, in _run_fn
status, run_metadata)
File "/usr/software/anaconda3/lib/python3.5/contextlib.py", line 66, in __exit__
next(self.gen)
File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Name: <unknown>, Key: image/object/view, Index: 0. Data types don't match. Expected type: int64
[[Node: ParseSingleExample/ParseExample/ParseExample = ParseExample[Ndense=7, Nsparse=9, Tdense=[DT_STRING, DT_STRING, DT_STRING, DT_INT64, DT_STRING, DT_STRING, DT_INT64], dense_shapes=[[], [], [], [], [], [], []], sparse_types=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64, DT_STRING, DT_INT64, DT_INT64, DT_INT64]](ParseSingleExample/ExpandDims, ParseSingleExample/ParseExample/ParseExample/names, ParseSingleExample/ParseExample/ParseExample/sparse_keys_0, ParseSingleExample/ParseExample/ParseExample/sparse_keys_1, ParseSingleExample/ParseExample/ParseExample/sparse_keys_2, ParseSingleExample/ParseExample/ParseExample/sparse_keys_3, ParseSingleExample/ParseExample/ParseExample/sparse_keys_4, ParseSingleExample/ParseExample/ParseExample/sparse_keys_5, ParseSingleExample/ParseExample/ParseExample/sparse_keys_6, ParseSingleExample/ParseExample/ParseExample/sparse_keys_7, ParseSingleExample/ParseExample/ParseExample/sparse_keys_8, ParseSingleExample/ParseExample/ParseExample/dense_keys_0, ParseSingleExample/ParseExample/ParseExample/dense_keys_1, ParseSingleExample/ParseExample/ParseExample/dense_keys_2, ParseSingleExample/ParseExample/ParseExample/dense_keys_3, ParseSingleExample/ParseExample/ParseExample/dense_keys_4, ParseSingleExample/ParseExample/ParseExample/dense_keys_5, ParseSingleExample/ParseExample/ParseExample/dense_keys_6, ParseSingleExample/ParseExample/Reshape, ParseSingleExample/ParseExample/Reshape_1, ParseSingleExample/ParseExample/Reshape_2, ParseSingleExample/ParseExample/Reshape_3, ParseSingleExample/ParseExample/Reshape_4, ParseSingleExample/ParseExample/Reshape_5, ParseSingleExample/ParseExample/Reshape_6)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?,?,4], [?,?], [?,?], [?,?,?,3], [?,?], [?,?], [?], [?]], output_types=[DT_STRING, DT_FLOAT, DT_INT64, DT_INT64, DT_UINT8, DT_INT64, DT_INT64, DT_STRING, DT_STRING], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/aurora/workspaces/PycharmProjects/object_detection_models/datasets/voc_dataset/voc_tfrecords_decode_test.py", line 83, in <module>
sess.run(next_element)
File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 895, in run
run_metadata_ptr)
File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1124, in _run
feed_dict_tensor, options, run_metadata)
File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1321, in _do_run
options, run_metadata)
File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1340, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Name: <unknown>, Key: image/object/view, Index: 0. Data types don't match. Expected type: int64
[[Node: ParseSingleExample/ParseExample/ParseExample = ParseExample[Ndense=7, Nsparse=9, Tdense=[DT_STRING, DT_STRING, DT_STRING, DT_INT64, DT_STRING, DT_STRING, DT_INT64], dense_shapes=[[], [], [], [], [], [], []], sparse_types=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64, DT_STRING, DT_INT64, DT_INT64, DT_INT64]](ParseSingleExample/ExpandDims, ParseSingleExample/ParseExample/ParseExample/names, ParseSingleExample/ParseExample/ParseExample/sparse_keys_0, ParseSingleExample/ParseExample/ParseExample/sparse_keys_1, ParseSingleExample/ParseExample/ParseExample/sparse_keys_2, ParseSingleExample/ParseExample/ParseExample/sparse_keys_3, ParseSingleExample/ParseExample/ParseExample/sparse_keys_4, ParseSingleExample/ParseExample/ParseExample/sparse_keys_5, ParseSingleExample/ParseExample/ParseExample/sparse_keys_6, ParseSingleExample/ParseExample/ParseExample/sparse_keys_7, ParseSingleExample/ParseExample/ParseExample/sparse_keys_8, ParseSingleExample/ParseExample/ParseExample/dense_keys_0, ParseSingleExample/ParseExample/ParseExample/dense_keys_1, ParseSingleExample/ParseExample/ParseExample/dense_keys_2, ParseSingleExample/ParseExample/ParseExample/dense_keys_3, ParseSingleExample/ParseExample/ParseExample/dense_keys_4, ParseSingleExample/ParseExample/ParseExample/dense_keys_5, ParseSingleExample/ParseExample/ParseExample/dense_keys_6, ParseSingleExample/ParseExample/Reshape, ParseSingleExample/ParseExample/Reshape_1, ParseSingleExample/ParseExample/Reshape_2, ParseSingleExample/ParseExample/Reshape_3, ParseSingleExample/ParseExample/Reshape_4, ParseSingleExample/ParseExample/Reshape_5, ParseSingleExample/ParseExample/Reshape_6)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?,?,4], [?,?], [?,?], [?,?,?,3], [?,?], [?,?], [?], [?]], output_types=[DT_STRING, DT_FLOAT, DT_INT64, DT_INT64, DT_UINT8, DT_INT64, DT_INT64, DT_STRING, DT_STRING], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]
The code that generates the tfrecords file can be found in create_pascal_tf_record.py.
EDIT (2018/01/25): Support for tf.SparseTensor was added to tf.data in TensorFlow 1.5, so the code in the question should work as-is in TensorFlow 1.5 and later.
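For reference, a minimal sketch of what this enables, assuming TensorFlow 1.5+ (the toy dataset below is illustrative, not from the question): a map() function may now return tf.SparseTensor values directly.

import tensorflow as tf

# Toy dataset whose map() returns a tf.SparseTensor per element (TF >= 1.5).
dataset = tf.data.Dataset.range(3).map(
    lambda x: tf.SparseTensor(indices=[[0]], values=[x], dense_shape=[10]))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()  # a tf.SparseTensor

with tf.Session() as sess:
    print(sess.run(next_element))  # a tf.SparseTensorValue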
Up to TF 1.4, the tf.contrib.data API did not support tf.SparseTensor objects in dataset elements. There are a couple of workarounds:
(Harder, but probably faster.) If a tf.SparseTensor object st represents a variable-length list of features, you can return st.values instead of st from the map() function. Note that you would probably then need to pad the results using Dataset.padded_batch() instead of Dataset.batch(), as sketched below.
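A minimal sketch of this first workaround, assuming a single hypothetical variable-length int64 feature named "labels" (the feature name, file path, and batch size are illustrative, not from the question):

import tensorflow as tf

def _parse(example):
    parsed = tf.parse_single_example(
        example, {"labels": tf.VarLenFeature(tf.int64)})
    return parsed["labels"].values  # a 1-D tf.Tensor of varying length

dataset = tf.contrib.data.TFRecordDataset("train.tfrecords")
dataset = dataset.map(_parse)
# Elements now vary in length, so pad them to a common shape when batching.
dataset = dataset.padded_batch(32, padded_shapes=tf.TensorShape([None]))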
(Easier, but probably slower.) In the _parse_function_train() function, iterate over tensor_dict and produce a new version in which any tf.SparseTensor objects have been converted to tf.Tensor objects using tf.serialize_sparse():
# NOTE: You could probably infer these from `keys`.
sparse_keys = set()

def _parse_function_train(example):
    serialized_example = tf.reshape(example, shape=[])
    tensors = decoder.decode(serialized_example, items=keys)
    tensor_dict = dict(zip(keys, tensors))
    tensor_dict['image'].set_shape([None, None, 3])
    rewritten_tensor_dict = {}
    for key, value in tensor_dict.items():
        if isinstance(value, tf.SparseTensor):
            rewritten_tensor_dict[key] = tf.serialize_sparse(value)
            sparse_keys.add(key)
        else:
            rewritten_tensor_dict[key] = value
    return rewritten_tensor_dict
Then, after you get the next_element dictionary from iterator.get_next(), you can reverse this conversion using tf.deserialize_many_sparse():
next_element = iterator.get_next()
for key in sparse_keys:
    # Deserialize the serialized tensor itself (not the key string); the dtype
    # must match the dtype of the original SparseTensor (int64 here, adjust
    # per feature).
    next_element[key] = tf.deserialize_many_sparse(next_element[key],
                                                   dtype=tf.int64)
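If a downstream op needs a dense tensor rather than the deserialized tf.SparseTensor, it can be densified afterwards; a one-line sketch, assuming the int64 'groundtruth_classes' feature padded with zeros:

dense_classes = tf.sparse_tensor_to_dense(
    next_element['groundtruth_classes'], default_value=0)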
Besides tf.SparseTensor, there is also tf.scatter_nd, which seems to achieve the same result, except that you may need to recover the indices later. The two snippets below encode the same data, once as a tf.SparseTensor and once as a dense tf.Tensor built with tf.scatter_nd:
indices = tf.concat([xy_indices, z_indices], axis=-1)
values = tf.concat([bboxes, objectness, one_hot], axis=1)

sparse_tensor = tf.SparseTensor(indices=indices,
                                values=values,
                                dense_shape=[output_size, output_size,
                                             len(anchors), 4 + 1 + num_classes])

dense_tensor = tf.scatter_nd(indices=indices,
                             updates=values,
                             shape=[output_size, output_size,
                                    len(anchors), 4 + 1 + num_classes])
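If the indices are needed later, one hedged way to recover them from the dense result is to look up the non-empty cells, assuming a cell is empty exactly when its objectness entry (channel 4 here) is zero, which is not guaranteed in general:

# Indices of cells whose objectness channel is non-zero; shape [N, 3],
# matching the rank of the original `indices`.
recovered_indices = tf.where(tf.not_equal(dense_tensor[..., 4], 0))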