How to properly serve an object detection model from Tensorflow Object Detection API?
I am using the Tensorflow Object Detection API (github.com/tensorflow/models/tree/master/object_detection) for an object detection task, and I am now having trouble serving the trained detection model with Tensorflow Serving (tensorflow.github.io/serving/).
1. The first issue I ran into is exporting the model to a servable file.
The Object Detection API includes an export script, so I was able to convert the ckpt files into a pb file plus a variables folder. However, the exported output has nothing inside the 'variables' folder. I thought this was a bug and reported it on Github, but it seems they intentionally convert all variables into constants, so that no variables are left. Details can be found HERE.
The flags I used when exporting the saved model are as follows:
CUDA_VISIBLE_DEVICES=0 python export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path configs/rfcn_resnet50_car_Jul_20.config \
    --checkpoint_path resnet_ckpt/model.ckpt-17586 \
    --inference_graph_path serving_model/1 \
    --export_as_saved_model True
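For reference, TensorFlow Serving expects the model_base_path to contain one numbered version directory per exported model, so a layout roughly like the following (the 1 matches the serving_model/1 used above; the variables folder may legitimately be empty when everything is frozen into constants):

serving_model/
└── 1/
    ├── saved_model.pb
    └── variables/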
When I switch --export_as_saved_model to False, the exported model runs perfectly fine for inference in Python.
However, I still cannot serve the model.
When I tried to run:
~/serving$ bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=gan --model_base_path=<my_model_path>
I got:
2017-07-27 16:11:53.222439: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:155] Restoring SavedModel bundle.
2017-07-27 16:11:53.222497: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:165] The specified SavedModel has no variables; no checkpoints were restored.
2017-07-27 16:11:53.222502: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running LegacyInitOp on SavedModel bundle.
2017-07-27 16:11:53.229463: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:284] Loading SavedModel: success. Took 281805 microseconds.
2017-07-27 16:11:53.229508: I tensorflow_serving/core/loader_harness.cc:86] Successfully loaded servable version {name: gan version: 1}
2017-07-27 16:11:53.244716: I tensorflow_serving/model_servers/main.cc:290] Running ModelServer at 0.0.0.0:9000 ...
I think the model is not loaded correctly, because it reports "The specified SavedModel has no variables; no checkpoints were restored."
But since we converted all the variables into constants, that actually seems reasonable. I am not sure about this part.
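One way to sanity-check what was actually exported (assuming your TensorFlow installation ships the saved_model_cli tool) is to dump the SavedModel's signatures and their input/output tensors:

saved_model_cli show --dir serving_model/1 --all

If the signature map or the tensor names look wrong here, the problem is in the export step rather than in the client.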
2. I cannot call the server with a client and run detection on a sample image.
The client script is as follows:
from __future__ import print_function
from __future__ import absolute_import

# Communication to TensorFlow server via gRPC
from grpc.beta import implementations
import tensorflow as tf
import numpy as np
from PIL import Image

# TensorFlow serving stuff to send messages
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2

# Command line arguments
tf.app.flags.DEFINE_string('server', 'localhost:9000',
                           'PredictionService host:port')
tf.app.flags.DEFINE_string('image', '', 'path to image in JPEG format')
FLAGS = tf.app.flags.FLAGS


def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)


def main(_):
    host, port = FLAGS.server.split(':')
    channel = implementations.insecure_channel(host, int(port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
    # Send request
    request = predict_pb2.PredictRequest()
    image = Image.open(FLAGS.image)
    image_np = load_image_into_numpy_array(image)
    image_np_expanded = np.expand_dims(image_np, axis=0)
    # Call GAN model to make prediction on the image
    request.model_spec.name = 'gan'
    request.model_spec.signature_name = 'predict_images'
    request.inputs['inputs'].CopyFrom(
        tf.contrib.util.make_tensor_proto(image_np_expanded))
    result = stub.Predict(request, 60.0)  # 60 secs timeout
    print(result)


if __name__ == '__main__':
    tf.app.run()
To match request.model_spec.signature_name = 'predict_images', I modified the exporter.py script in the Object Detection API (github.com/tensorflow/models/blob/master/object_detection/exporter.py), starting at its line 289, from:
signature_def_map={
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
        detection_signature,
},
to:
signature_def_map={
    'predict_images': detection_signature,
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
        detection_signature,
},
because I did not know how to call the default signature key.
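For reference, the default key can be addressed from the client without modifying the exporter: signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY is simply the string 'serving_default'. A minimal sketch:

from tensorflow.python.saved_model import signature_constants

# DEFAULT_SERVING_SIGNATURE_DEF_KEY is the string 'serving_default'
request.model_spec.signature_name = \
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY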
When I run the following command:
bazel-bin/tensorflow_serving/example/client --server=localhost:9000 --image=<my_image_file>
I get the following error message:
Traceback (most recent call last):
File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/tf_serving/tensorflow_serving/example/client.py", line 54, in <module>
tf.app.run()
File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/tf_serving/tensorflow_serving/example/client.py", line 49, in main
result = stub.Predict(request, 60.0) # 60 secs timeout
File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 324, in __call__
self._request_serializer, self._response_deserializer)
File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 210, in _blocking_unary_unary
raise _abortion_error(rpc_error_call)
grpc.framework.interfaces.face.face.AbortionError: AbortionError(code=StatusCode.NOT_FOUND, details="FeedInputs: unable to find feed output ToFloat:0")
I am not sure what is going on here.
Initially I thought my client script might be incorrect, but then I found that the AbortionError comes from github.com/tensorflow/tensorflow/blob/f488419cd6d9256b25ba25cbe736097dfeee79f9/tensorflow/core/graph/subgraph.cc, which suggests the error occurs while the graph is being built. So it may be caused by the first issue I ran into.
I am new to this, so I am really confused. I may have gone wrong from the very beginning. Is there a way to properly export and serve a detection model? Any advice would be a great help!
The current exporter code does not populate the signature field correctly, so serving with the model server will not work. Apologies for that. A new version with better support for exporting models is coming soon; it includes some important fixes and improvements needed for serving, especially serving on Cloud ML Engine. If you want to try an early version of it, see the github issue.
As for the "The specified SavedModel has no variables; no checkpoints were restored." message, this is expected, for the exact reason you stated: all the variables have been converted into constants in the graph. As for the "FeedInputs: unable to find feed output ToFloat:0" error, make sure you are using TF 1.2 when building the model server.
I was struggling with the exact same problem, trying to host a pretrained SSDMobileNet-COCO checkpoint from the Tensorflow Object Detection API Zoo.
It turned out I was using an older commit of tensorflow/models, which happened to be the default submodule of serving. I simply pulled the most recent commit with:

cd serving/tf_models
git pull origin master
git checkout master

After that, I built the model server again:

bazel build //tensorflow_serving/model_servers:tensorflow_model_server

The error disappeared and I was able to get accurate predictions.
For the error

grpc.framework.interfaces.face.face.AbortionError:
AbortionError(code=StatusCode.NOT_FOUND, details="FeedInputs: unable to find feed output ToFloat:0")

simply upgrade tf_models to the latest version and re-export the model.
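If you re-export after upgrading, note that the exporter's flags have changed in recent tf_models commits; if I recall the updated script correctly, the command looks roughly like this (paths reused from the question for illustration), and it writes a saved_model/ subdirectory under --output_directory that you can then copy into a numbered version folder for serving:

python object_detection/export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path configs/rfcn_resnet50_car_Jul_20.config \
    --trained_checkpoint_prefix resnet_ckpt/model.ckpt-17586 \
    --output_directory exported_model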
Your approach is fine, and it is OK to have that warning. The issue is that the input needs to be cast to uint8, which is what the model expects. Here is the code snippet that worked for me:
request = predict_pb2.PredictRequest()
request.model_spec.name = 'gan'
request.model_spec.signature_name = \
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
image = Image.open('any.jpg')
image_np = load_image_into_numpy_array(image)
image_np_expanded = np.expand_dims(image_np, axis=0)
request.inputs['inputs'].CopyFrom(
    tf.contrib.util.make_tensor_proto(image_np_expanded,
                                      shape=image_np_expanded.shape,
                                      dtype='uint8'))
The important part for you is shape=image_np_expanded.shape, dtype='uint8', and make sure to pull the latest serving updates.
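Once the request succeeds, here is a sketch of reading the detections back out of the response. The output keys below (detection_boxes, detection_scores, detection_classes, num_detections) are the ones the Object Detection exporter emits by default, but verify them against your own model's signature:

# Convert the TensorProtos in the response back into numpy arrays.
boxes = tf.contrib.util.make_ndarray(result.outputs['detection_boxes'])
scores = tf.contrib.util.make_ndarray(result.outputs['detection_scores'])
classes = tf.contrib.util.make_ndarray(result.outputs['detection_classes'])
num = int(tf.contrib.util.make_ndarray(result.outputs['num_detections'])[0])

# Print only reasonably confident detections (0.5 is an arbitrary threshold).
for i in range(num):
    if scores[0][i] > 0.5:
        print('class %d, score %.3f, box %s'
              % (int(classes[0][i]), scores[0][i], boxes[0][i]))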