"TypeError: 'Tensor' object is not iterable" error with tensorflow Estimator
"TypeError: 'Tensor' object is not iterable" error with tensorflow Estimator
我有一个程序生成的(无限)数据源,我正在尝试将其用作高级 TensorFlow Estimator 的输入,以训练一个基于图像的 3D 对象检测器。
我按照 TensorFlow Estimator Quickstart 中的方式设置了数据集,我的 `dataset_input_fn` 返回由特征和标签 `Tensor` 组成的元组,这与 `Estimator.train` 函数的规范以及该教程的示例一致,但在调用训练函数时出现错误:
TypeError: 'Tensor' object is not iterable.
我做错了什么?
def data_generator():
    """Yield an endless stream of ``(image, ground_truth)`` training pairs.

    Every iteration advances the procedurally generated data source by one
    step, then reads the object ground truth and the camera frame for that
    step.

    Yields:
        A tuple ``(cam_frame, ground_truth)``. Per the original comments,
        ``cam_frame`` is a (224, 224, 3) image and ``ground_truth`` is a list
        of 9 floats.
    """
    while True:
        # Produce the next data point before reading anything from it.
        source.step()
        labels = source.get_ground_truth()  # list of 9 floats
        frame = source.get_cam_frame()  # image (224, 224, 3)
        yield frame, labels
def dataset_input_fn(feature_key='image'):
    """Build the Estimator input pipeline from the procedural generator.

    Wraps ``data_generator`` in a ``tf.data.Dataset``, batches it 16 samples
    at a time and returns the next batch as ``(features, labels)``.

    Args:
        feature_key: Key under which the image batch is stored in the
            ``features`` dict. NOTE(review): with
            ``tf.keras.estimator.model_to_estimator`` this key must match one
            of the Keras model's input layer names
            (``keras_model.input_names``); confirm against the actual model.

    Returns:
        ``(features, labels)``, where ``features`` is the dict
        ``{feature_key: images}``. ``images`` is a uint8 batch of
        [224, 224, 3] frames and ``labels`` is a float32 batch of length-9
        vectors.
    """
    dataset = tf.data.Dataset.from_generator(
        data_generator,
        (tf.uint8, tf.float32),
        (tf.TensorShape([224, 224, 3]), tf.TensorShape([9])),
    )
    dataset = dataset.batch(16)
    iterator = dataset.make_one_shot_iterator()
    images, labels = iterator.get_next()
    # The Keras-backed model_fn iterates `features` by key
    # (`for key in estimator_io_dict`). Returning the bare image Tensor raises
    # "TypeError: 'Tensor' object is not iterable", so wrap it in a dict.
    return {feature_key: images}, labels
def main():
    """Train the Keras VGG16 model through the Estimator API.

    Converts the Keras model into an Estimator with
    ``tf.keras.estimator.model_to_estimator``; see
    https://www.tensorflow.org/programmers_guide/estimators#creating_estimators_from_keras_models
    Then trains it for 10 steps using ``dataset_input_fn`` as the input
    function.

    Calling ``est_vgg16.train()`` is what raises
    "TypeError: 'Tensor' object is not iterable" when ``dataset_input_fn``
    returns its features as a bare Tensor rather than a dict.
    """
    ....  # elided in the question; presumably builds `keras_vgg16`
    est_vgg16 = tf.keras.estimator.model_to_estimator(keras_model=keras_vgg16)
    est_vgg16.train(input_fn=dataset_input_fn, steps=10)
    ....  # elided in the question
这里是完整代码(注意:其中的命名与本题中的不同)。
这是堆栈跟踪:
Traceback (most recent call last):
File "./rock_detector.py", line 155, in <module>
main()
File "./rock_detector.py", line 117, in main
est_vgg16.train(input_fn=dataset_input_fn, steps=10)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 302, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 711, in _train_model
features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 694, in _call_model_fn
model_fn_results = self._model_fn(features=features, **kwargs)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 145, in model_fn
labels)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 92, in _clone_and_build_model
keras_model, features)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 58, in _create_ordered_io
for key in estimator_io_dict:
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 505, in __iter__
raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.
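让输入函数返回一个特征字典,而不是单个 `Tensor`。从堆栈跟踪可以看出,`_create_ordered_io` 会执行 `for key in estimator_io_dict`,即按键遍历特征,因此特征必须是字典,且键名需与 Keras 模型的输入层名称(`keras_model.input_names`)一致。例如: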
def dataset_input_fn(feature_key='image'):
    """Estimator input function that returns features as a dict.

    Same pipeline as in the question: ``data_generator`` wrapped in a
    ``tf.data.Dataset`` and batched by 16. The only change is that the image
    batch is returned inside a dict.

    Args:
        feature_key: Key for the image batch in the ``features`` dict.
            NOTE(review): for ``model_to_estimator`` this must match one of
            ``keras_model.input_names``; confirm against the actual model.

    Returns:
        ``({feature_key: images}, labels)``.
    """
    dataset = tf.data.Dataset.from_generator(
        data_generator,
        (tf.uint8, tf.float32),
        (tf.TensorShape([224, 224, 3]), tf.TensorShape([9])),
    )
    dataset = dataset.batch(16)
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    # The Keras model_fn iterates features by key, so they must be a dict.
    return {feature_key: features}, labels
我有一个程序生成的(无限)数据源,我正在尝试将其用作高级 TensorFlow Estimator 的输入,以训练一个基于图像的 3D 对象检测器。
我按照 TensorFlow Estimator Quickstart 中的方式设置了数据集,我的 `dataset_input_fn` 返回由特征和标签 `Tensor` 组成的元组,这与 `Estimator.train` 函数的规范以及该教程的示例一致,但在调用训练函数时出现错误:
TypeError: 'Tensor' object is not iterable.
我做错了什么?
def data_generator():
    """Yield an endless stream of ``(image, ground_truth)`` training pairs.

    Every iteration advances the procedurally generated data source by one
    step, then reads the object ground truth and the camera frame for that
    step.

    Yields:
        A tuple ``(cam_frame, ground_truth)``. Per the original comments,
        ``cam_frame`` is a (224, 224, 3) image and ``ground_truth`` is a list
        of 9 floats.
    """
    while True:
        # Produce the next data point before reading anything from it.
        source.step()
        labels = source.get_ground_truth()  # list of 9 floats
        frame = source.get_cam_frame()  # image (224, 224, 3)
        yield frame, labels
def dataset_input_fn(feature_key='image'):
    """Build the Estimator input pipeline from the procedural generator.

    Wraps ``data_generator`` in a ``tf.data.Dataset``, batches it 16 samples
    at a time and returns the next batch as ``(features, labels)``.

    Args:
        feature_key: Key under which the image batch is stored in the
            ``features`` dict. NOTE(review): with
            ``tf.keras.estimator.model_to_estimator`` this key must match one
            of the Keras model's input layer names
            (``keras_model.input_names``); confirm against the actual model.

    Returns:
        ``(features, labels)``, where ``features`` is the dict
        ``{feature_key: images}``. ``images`` is a uint8 batch of
        [224, 224, 3] frames and ``labels`` is a float32 batch of length-9
        vectors.
    """
    dataset = tf.data.Dataset.from_generator(
        data_generator,
        (tf.uint8, tf.float32),
        (tf.TensorShape([224, 224, 3]), tf.TensorShape([9])),
    )
    dataset = dataset.batch(16)
    iterator = dataset.make_one_shot_iterator()
    images, labels = iterator.get_next()
    # The Keras-backed model_fn iterates `features` by key
    # (`for key in estimator_io_dict`). Returning the bare image Tensor raises
    # "TypeError: 'Tensor' object is not iterable", so wrap it in a dict.
    return {feature_key: images}, labels
def main():
    """Train the Keras VGG16 model through the Estimator API.

    Converts the Keras model into an Estimator with
    ``tf.keras.estimator.model_to_estimator``; see
    https://www.tensorflow.org/programmers_guide/estimators#creating_estimators_from_keras_models
    Then trains it for 10 steps using ``dataset_input_fn`` as the input
    function.

    Calling ``est_vgg16.train()`` is what raises
    "TypeError: 'Tensor' object is not iterable" when ``dataset_input_fn``
    returns its features as a bare Tensor rather than a dict.
    """
    ....  # elided in the question; presumably builds `keras_vgg16`
    est_vgg16 = tf.keras.estimator.model_to_estimator(keras_model=keras_vgg16)
    est_vgg16.train(input_fn=dataset_input_fn, steps=10)
    ....  # elided in the question
这里是完整代码(注意:其中的命名与本题中的不同)。
这是堆栈跟踪:
Traceback (most recent call last):
File "./rock_detector.py", line 155, in <module>
main()
File "./rock_detector.py", line 117, in main
est_vgg16.train(input_fn=dataset_input_fn, steps=10)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 302, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 711, in _train_model
features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 694, in _call_model_fn
model_fn_results = self._model_fn(features=features, **kwargs)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 145, in model_fn
labels)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 92, in _clone_and_build_model
keras_model, features)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 58, in _create_ordered_io
for key in estimator_io_dict:
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 505, in __iter__
raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.
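让输入函数返回一个特征字典,而不是单个 `Tensor`。从堆栈跟踪可以看出,`_create_ordered_io` 会执行 `for key in estimator_io_dict`,即按键遍历特征,因此特征必须是字典,且键名需与 Keras 模型的输入层名称(`keras_model.input_names`)一致。例如: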
def dataset_input_fn(feature_key='image'):
    """Estimator input function that returns features as a dict.

    Same pipeline as in the question: ``data_generator`` wrapped in a
    ``tf.data.Dataset`` and batched by 16. The only change is that the image
    batch is returned inside a dict.

    Args:
        feature_key: Key for the image batch in the ``features`` dict.
            NOTE(review): for ``model_to_estimator`` this must match one of
            ``keras_model.input_names``; confirm against the actual model.

    Returns:
        ``({feature_key: images}, labels)``.
    """
    dataset = tf.data.Dataset.from_generator(
        data_generator,
        (tf.uint8, tf.float32),
        (tf.TensorShape([224, 224, 3]), tf.TensorShape([9])),
    )
    dataset = dataset.batch(16)
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    # The Keras model_fn iterates features by key, so they must be a dict.
    return {feature_key: features}, labels