TensorFlow Dataset structure
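I am trying to learn how to use TensorFlow's Dataset module by studying the official CIFAR-10 example at https://github.com/tensorflow/models/blob/master/official/resnet/cifar10_main.py
To build the dataset myself, I replaced the following code in the function 'input_fn':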
filenames = get_filenames(is_training, data_dir)
dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
with
dataset = creat_dataset()
where 'creat_dataset' is defined as:
import numpy as np
import tensorflow as tf

def creat_dataset():
    def unpickle(file):
        import cPickle  # Python 2 module, matching the python2.7 paths in the traceback
        with open(file, 'rb') as fo:
            dict = cPickle.load(fo)
        ll = dict['labels']
        # Return the pixel rows and the labels as a column vector
        return dict['data'], np.array(ll).reshape(len(ll), 1)

    dir = './cifar_10/data_batch_'
    data = None
    label = None
    for i in range(1, 6):
        if data is None:
            data, label = unpickle(dir + '1')
        else:
            data_, label_ = unpickle(dir + str(i))
            data = np.concatenate((data, data_), 0)
            label = np.concatenate((label, label_))
    # Prepend the label column so that each row is [label, pixel values]
    data = np.concatenate((label, data), 1)
    data = tf.constant(data, tf.uint8)
    dataset = tf.data.Dataset.from_tensor_slices(data)
    return dataset
But I get the following error message:
Traceback (most recent call last):
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/main.py", line 260, in <module>
tf.app.run(argv=[sys.argv[0]] + unparsed)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 126, in run
_sys.exit(main(argv))
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/main.py", line 244, in main
resnet.resnet_main(FLAGS, cifar10_model_fn, input_function)
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/resnet.py", line 766, in resnet_main
classifier.train(input_fn=input_fn_train, hooks=[logging_hook])
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 352, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 809, in _train_model
input_fn, model_fn_lib.ModeKeys.TRAIN))
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 668, in _get_features_and_labels_from_input_fn
result = self._call_input_fn(input_fn, mode)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 760, in _call_input_fn
return input_fn(**kwargs)
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/resnet.py", line 764, in input_fn_train
flags.multi_gpu)
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/main.py", line 162, in input_fn
examples_per_epoch=num_images, multi_gpu=multi_gpu)
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/resnet.py", line 104, in process_record_dataset
num_parallel_calls=num_parallel_calls)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 792, in map
return ParallelMapDataset(self, map_func, num_parallel_calls)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1628, in __init__
super(ParallelMapDataset, self).__init__(input_dataset, map_func)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1597, in __init__
self._map_func.add_to_graph(ops.get_default_graph())
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 486, in add_to_graph
self._create_definition_if_needed()
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 321, in _create_definition_if_needed
self._create_definition_if_needed_impl()
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 338, in _create_definition_if_needed_impl
outputs = self._func(*inputs)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1562, in tf_map_func
ret = map_func(nested_args)
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/resnet.py", line 103, in <lambda>
dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
File "/Users/ritsuko/PycharmProjects/master/tf_official_resnet/main.py", line 69, in parse_record
record_vector = tf.decode_raw(raw_record, tf.uint8)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_parsing_ops.py", line 195, in decode_raw
little_endian=little_endian, name=name)
File "/Users/ritsuko/Tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 533, in _apply_op_helper
(prefix, dtypes.as_dtype(input_arg.type).name))
TypeError: Input 'bytes' of 'DecodeRaw' Op has type uint8 that does not match expected type of string.
Can anyone tell me how to fix this error?
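The problem is solved by simply changing the expression record_vector = tf.decode_raw(raw_record, tf.uint8)
to record_vector = raw_record
The elements produced by tf.data.Dataset.from_tensor_slices here are already uint8 tensors, not the raw byte strings that tf.decode_raw expects, which is exactly what the TypeError reports.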
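As a rough illustration of what the rest of the parsing step has to do once the row is already numeric, here is a minimal sketch. It uses NumPy instead of TensorFlow so it runs on its own. The function name parse_decoded_record and the synthetic row are hypothetical, and the layout (one label value followed by 3072 pixel values stored channel by channel) is assumed from the question's code and the CIFAR-10 Python batches.

```python
import numpy as np

HEIGHT, WIDTH, CHANNELS = 32, 32, 3
RECORD_LEN = 1 + HEIGHT * WIDTH * CHANNELS  # 1 label value + 3072 pixel values


def parse_decoded_record(record_vector):
    """Split an already-decoded uint8 row into (label, image).

    No byte-string decoding happens here: the row is assumed to be numeric.
    """
    record_vector = np.asarray(record_vector)
    if record_vector.shape != (RECORD_LEN,):
        raise ValueError("expected a flat row of %d values" % RECORD_LEN)
    label = int(record_vector[0])
    # Pixels are stored channel-major: [channel, height, width].
    depth_major = record_vector[1:].reshape(CHANNELS, HEIGHT, WIDTH)
    # Reorder to [height, width, channel] and convert to float32.
    image = depth_major.transpose(1, 2, 0).astype(np.float32)
    return label, image


# Hypothetical row: label 7 followed by 3072 synthetic pixel values.
row = np.concatenate(([7], np.arange(3072) % 256)).astype(np.uint8)
label, image = parse_decoded_record(row)
print(label, image.shape)  # 7 (32, 32, 3)
```

In the TensorFlow version the same slicing, reshape and transpose would follow record_vector = raw_record; only the decoding call is dropped.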
I ran into the same error as you, and the answer above is good. To be clear, your code may look something like this:
_, serialized_example = reader.read(filename_queue)
img_features = tf.parse_single_example(
    serialized=serialized_example,
    features={
        'image': tf.FixedLenFeature([], tf.float32),
        'label': tf.FixedLenFeature([], tf.int64)
    })
# image = tf.decode_raw(img_features['image'], tf.uint8)
image = img_features['image']
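Now look at this line:
'image': tf.FixedLenFeature([], tf.float32),
Most tutorials you find online use this instead:
'image': tf.FixedLenFeature([], tf.string),
and with that declaration the following line runs without problems:
image = tf.decode_raw(img_features['image'], tf.uint8)
But when the feature in your TFRecord is declared with a numeric dtype rather than tf.string (tf.FixedLenFeature accepts tf.float32 and tf.int64 besides tf.string), decode_raw is not needed, and calling it can cause the error you got. Just use:
image = img_features['image']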
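The difference between a value stored as a byte string and one stored as numbers can be sketched without a TFRecord file. The snippet below is only an analogy in NumPy: np.frombuffer plays roughly the role that tf.decode_raw plays for a tf.string feature, and the helper needs_decoding and the sample pixels are hypothetical.

```python
import numpy as np

pixels = np.arange(12, dtype=np.uint8)

# Stored as one byte string (comparable to a tf.string feature):
# the bytes must be reinterpreted as numbers before use.
stored_as_bytes = pixels.tobytes()
from_bytes = np.frombuffer(stored_as_bytes, dtype=np.uint8)

# Stored as numbers (comparable to a tf.float32 feature):
# the values are usable as they are, so there is nothing to decode.
stored_as_floats = pixels.astype(np.float32)


def needs_decoding(value):
    """Only raw byte strings need a decoding step."""
    return isinstance(value, (bytes, bytearray))


for name, value in [("bytes", stored_as_bytes), ("floats", stored_as_floats)]:
    print(name, "needs decoding:", needs_decoding(value))
```

Checking how the feature was written, and declaring the same dtype when reading, avoids guessing whether a decode step belongs in the parser.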
By the way, if you use Jupyter Notebook, remember to run Restart & Clear Output on the kernel after changing the code, then run your program again step by step.