TensorFlow Lite object detection demo works with the stock model, but fails with my model with an error that an implemented operation is not implemented
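Using the TPU training mode on Google Cloud, I trained an SSD MobileNet V1 FPN model to recognize two types of objects. The model trained without errors, and I was able to evaluate it in TensorBoard. However, after converting it to TensorFlow Lite and trying to run it for object detection in the demo app, it fails with an error about an unimplemented custom op, even though the TensorFlow documentation states that this op is implemented in TensorFlow Lite.
I trained the model on Google Cloud using TPUs and downloaded it from the storage bucket.
Next, I exported the model using the latest version of the Object Detection API (the paths are intentionally generic):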
python -m object_detection.export_tflite_ssd_graph \
--pipeline_config_path=$PATH_TO_CONFIG_FILE \
--trained_checkpoint_prefix=model.ckpt-$CHECKPOINT \
--output_directory=$OUTPUT_DIR \
--add_postprocessing_op=true
Next, I converted the model with toco, built from the latest 1.12 tag of the TensorFlow Git repository (using Bazel 0.21 to avoid a Bazel error):
bazel run -c opt //tensorflow/contrib/lite/toco:toco \
--incompatible_package_name_is_a_function=false \
-- \
--input_file=$OUTPUT_DIR/tflite_graph.pb \
--output_file=$OUTPUT_DIR/detect.tflite \
--input_shapes=1,640,640,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=FLOAT \
--allow_custom_ops
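To make the conversion flags easier to audit, here is a minimal sketch (assuming Python 3.6+) that assembles the same toco arguments as an argv list, for example to pass to `subprocess.run`. `OUTPUT_DIR` and `toco_flags` are hypothetical names, not part of TensorFlow. It also shows that the shell value `'TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1',...` is simply a comma-separated list of the four outputs of the post-processing op. Note what `--allow_custom_ops` does here: besides permitting the genuinely custom TFLite_Detection_PostProcess op, it lets toco write any op it has no builtin for as a custom op instead of failing, which appears consistent with the ResizeNearestNeighbor error below.

```python
from typing import List

# Placeholder for $OUTPUT_DIR in the question (hypothetical path).
OUTPUT_DIR = "/tmp/tflite_export"

# The post-processing op added by export_tflite_ssd_graph has four outputs,
# addressed as NAME, NAME:1, NAME:2 and NAME:3.
POSTPROCESS = "TFLite_Detection_PostProcess"
OUTPUT_ARRAYS = [POSTPROCESS] + ["{}:{}".format(POSTPROCESS, i) for i in range(1, 4)]


def toco_flags(output_dir: str, input_size: int = 640) -> List[str]:
    """Build the toco arguments that follow bazel's `--` separator in the question."""
    return [
        "--input_file={}/tflite_graph.pb".format(output_dir),
        "--output_file={}/detect.tflite".format(output_dir),
        # batch, height, width, channels
        "--input_shapes=1,{0},{0},3".format(input_size),
        "--input_arrays=normalized_input_image_tensor",
        # Same value the shell produces from the adjacent quoted strings.
        "--output_arrays=" + ",".join(OUTPUT_ARRAYS),
        "--inference_type=FLOAT",
        # Lets toco write ops it has no builtin for as custom ops
        # instead of rejecting the graph.
        "--allow_custom_ops",
    ]


if __name__ == "__main__":
    for flag in toco_flags(OUTPUT_DIR):
        print(flag)
```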
Next, I edited the Bazel BUILD file to include my model, which I had copied into the app's assets directory, and edited the DetectorActivity Java file to reference it:
# out of context
assets = [
#"//tensorflow/contrib/lite/examples/android/app/src/main/assets:labels_mobilenet_quant_v1_224.txt",
#"@tflite_mobilenet//:mobilenet_quant_v1_224.tflite",
#"@tflite_conv_actions_frozen//:conv_actions_frozen.tflite",
#"//tensorflow/contrib/lite/examples/android/app/src/main/assets:conv_actions_labels.txt",
#"@tflite_mobilenet_ssd//:mobilenet_ssd.tflite",
"//tensorflow/contrib/lite/examples/android/app/src/main/assets:detect.tflite",
#"//tensorflow/contrib/lite/examples/android/app/src/main/assets:box_priors.txt",
"//tensorflow/contrib/lite/examples/android/app/src/main/assets:pascal_labels.txt",
],
# out of context
private static final int TF_OD_API_INPUT_SIZE = 640;
private static final boolean TF_OD_API_IS_QUANTIZED = false;
private static final String TF_OD_API_MODEL_FILE = "file:///android_asset/detect.tflite";
private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/pascal_labels.txt";
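The Java constants have to agree with the converted model: `TF_OD_API_INPUT_SIZE = 640` matches `--input_shapes=1,640,640,3`, and `TF_OD_API_IS_QUANTIZED = false` matches `--inference_type=FLOAT`. The sketch below checks that agreement and computes the input buffer size, assuming (as in the stock demo) one 4-byte float per channel for float models and one byte per channel for quantized models. The function names are hypothetical.

```python
def input_buffer_bytes(input_size: int, is_quantized: bool) -> int:
    """Bytes needed for one input image: batch 1, input_size x input_size, RGB."""
    bytes_per_channel = 1 if is_quantized else 4  # uint8 vs. float32 per channel
    return 1 * input_size * input_size * 3 * bytes_per_channel


def matches_input_shapes(input_shapes: str, input_size: int) -> bool:
    """Check a toco --input_shapes value such as "1,640,640,3" against TF_OD_API_INPUT_SIZE."""
    batch, height, width, channels = (int(x) for x in input_shapes.split(","))
    return batch == 1 and height == width == input_size and channels == 3


# Values from the question: TF_OD_API_INPUT_SIZE = 640, TF_OD_API_IS_QUANTIZED = false.
print(matches_input_shapes("1,640,640,3", 640))  # True
print(input_buffer_bytes(640, is_quantized=False))  # 4915200
```

These values look consistent, and the crash below happens while the interpreter is being created (`Interpreter.<init>`), before any input is fed, so these constants are unlikely to be the cause.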
Finally, I built and deployed the app with the following commands:
bazel build -c opt --config=android_arm64 --cxxopt='--std=c++11' "//tensorflow/contrib/lite/examples/android:tflite_demo"
adb install -r bazel-bin/tensorflow/contrib/lite/examples/android/tflite_demo.apk
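All of the code I modified is in the repository at https://github.com/tensorflow/models/tree/master/research/object_detection.
I expected the result to be a working app, as it is when I build the demo without any modifications (stock from the repository).
Instead, the app crashes immediately after launch with the following error in Logcat: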
2019-02-09 16:38:28.229 32716-32716/? E/AndroidRuntime: FATAL EXCEPTION: main
Process: org.tensorflow.lite.demo, PID: 32716
java.lang.RuntimeException: java.lang.IllegalArgumentException: Internal error: Cannot create interpreter: Didn't find custom op for name 'ResizeNearestNeighbor' with version 1
Registration failed.
at org.tensorflow.demo.TFLiteObjectDetectionAPIModel.create(TFLiteObjectDetectionAPIModel.java:124)
at org.tensorflow.demo.DetectorActivity.onPreviewSizeChosen(DetectorActivity.java:110)
at org.tensorflow.demo.CameraActivity.onPreviewSizeChosen(CameraActivity.java:362)
at org.tensorflow.demo.CameraConnectionFragment.setUpCameraOutputs(CameraConnectionFragment.java:401)
at org.tensorflow.demo.CameraConnectionFragment.openCamera(CameraConnectionFragment.java:408)
at org.tensorflow.demo.CameraConnectionFragment.access$000(CameraConnectionFragment.java:64)
at org.tensorflow.demo.CameraConnectionFragment.onSurfaceTextureAvailable(CameraConnectionFragment.java:95)
at android.view.TextureView.getHardwareLayer(TextureView.java:390)
at android.view.TextureView.draw(TextureView.java:339)
at android.view.View.updateDisplayListIfDirty(View.java:18150)
at android.view.View.draw(View.java:18928)
at android.view.ViewGroup.drawChild(ViewGroup.java:4240)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4026)
at android.view.View.updateDisplayListIfDirty(View.java:18141)
at android.view.View.draw(View.java:18928)
at android.view.ViewGroup.drawChild(ViewGroup.java:4240)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4026)
at android.view.View.draw(View.java:19203)
at android.view.View.updateDisplayListIfDirty(View.java:18150)
at android.view.View.draw(View.java:18928)
at android.view.ViewGroup.drawChild(ViewGroup.java:4240)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4026)
at android.view.View.updateDisplayListIfDirty(View.java:18141)
at android.view.View.draw(View.java:18928)
at android.view.ViewGroup.drawChild(ViewGroup.java:4240)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4026)
at android.view.View.updateDisplayListIfDirty(View.java:18141)
at android.view.View.draw(View.java:18928)
at android.view.ViewGroup.drawChild(ViewGroup.java:4240)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4026)
at android.view.View.draw(View.java:19203)
at com.android.internal.policy.DecorView.draw(DecorView.java:825)
at android.view.View.updateDisplayListIfDirty(View.java:18150)
at android.view.ThreadedRenderer.updateViewTreeDisplayList(ThreadedRenderer.java:669)
at android.view.ThreadedRenderer.updateRootDisplayList(ThreadedRenderer.java:675)
at android.view.ThreadedRenderer.draw(ThreadedRenderer.java:783)
at android.view.ViewRootImpl.draw(ViewRootImpl.java:3098)
at android.view.ViewRootImpl.performDraw(ViewRootImpl.java:2912)
at android.view.ViewRootImpl.performTraversals(ViewRootImpl.java:2465)
at android.view.ViewRootImpl.doTraversal(ViewRootImpl.java:1453)
at android.view.ViewRootImpl$TraversalRunnable.run(ViewRootImpl.java:6958)
at android.view.Choreographer$CallbackRecord.run(Choreographer.java:911)
at android.view.Choreographer.doCallbacks(Choreographer.java:723)
at android.view.Choreographer.doFrame(Choreographer.java:658)
at android.view.Choreographer$FrameDisplayEventReceiver.run(Choreographer.java:897)
at android.os.Handler.handleCallback(Handler.java:790)
at android.os.Handler.dispatchMessage(Handler.java:99)
at android.os.Looper.loop(Looper.java:164)
at android.app.ActivityThread.main(ActivityThread.java:6626)
at java.lang.reflect.Method.invoke(Native Method)
at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:438)
at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:811)
Caused by: java.lang.IllegalArgumentException: Internal error: Cannot create interpreter: Didn't find custom op for name 'ResizeNearestNeighbor' with version 1
Registration failed.
at org.tensorflow.lite.NativeInterpreterWrapper.createInterpreter(Native Method)
2019-02-09 16:38:28.229 32716-32716/? E/AndroidRuntime: at org.tensorflow.lite.NativeInterpreterWrapper.<init>(NativeInterpreterWrapper.java:70)
at org.tensorflow.lite.Interpreter.<init>(Interpreter.java:175)
at org.tensorflow.lite.Interpreter.<init>(Interpreter.java:163)
at org.tensorflow.demo.TFLiteObjectDetectionAPIModel.create(TFLiteObjectDetectionAPIModel.java:122)
... 51 more
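The line that matters in this trace is the `Caused by`: the interpreter found an op recorded as a custom op named `ResizeNearestNeighbor` and had no registration for it. A standard TensorFlow op name showing up as a custom op suggests the converter did not know it as a builtin. Here is a minimal sketch for pulling such names out of a Logcat dump; the regex is written against the message above, and `missing_custom_ops` is a hypothetical helper.

```python
import re
from typing import List, Tuple

# Matches the interpreter error shown in Logcat, e.g.
# "Didn't find custom op for name 'ResizeNearestNeighbor' with version 1"
MISSING_OP = re.compile(r"Didn't find custom op for name '([^']+)' with version (\d+)")


def missing_custom_ops(logcat: str) -> List[Tuple[str, int]]:
    """Return each (op name, op version) the interpreter could not resolve, once, in order."""
    found: List[Tuple[str, int]] = []
    for name, version in MISSING_OP.findall(logcat):
        pair = (name, int(version))
        if pair not in found:
            found.append(pair)
    return found


log = (
    "java.lang.IllegalArgumentException: Internal error: Cannot create interpreter: "
    "Didn't find custom op for name 'ResizeNearestNeighbor' with version 1\n"
    "Registration failed.\n"
)
print(missing_custom_ops(log))  # [('ResizeNearestNeighbor', 1)]
```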
Here are all the resources I can provide to help debug this issue:
- Training config: https://storage.cloud.google.com/robocubs-ml/debug/config/tpu.config
- Checkpoint files
  - Checkpoint descriptor: https://storage.cloud.google.com/robocubs-ml/debug/checkpoint/checkpoint
  - Graph: https://storage.cloud.google.com/robocubs-ml/debug/checkpoint/graph.pbtxt
  - Checkpoint data: https://storage.cloud.google.com/robocubs-ml/debug/checkpoint/model.ckpt-246400.data-00000-of-00001
  - Checkpoint index: https://storage.cloud.google.com/robocubs-ml/debug/checkpoint/model.ckpt-246400.index
  - Checkpoint meta: https://storage.cloud.google.com/robocubs-ml/debug/checkpoint/model.ckpt-246400.meta
- Saved model
  - Checkpoint descriptor: https://storage.cloud.google.com/robocubs-ml/debug/exported_model/checkpoint
  - Frozen inference graph: https://storage.cloud.google.com/robocubs-ml/debug/exported_model/checkpoint
  - Checkpoint data: https://storage.cloud.google.com/robocubs-ml/debug/exported_model/model.ckpt.data-00000-of-00001
  - Checkpoint index: https://storage.cloud.google.com/robocubs-ml/debug/exported_model/model.ckpt.index
  - Checkpoint meta: https://storage.cloud.google.com/robocubs-ml/debug/exported_model/model.ckpt.meta
  - Pipeline config: https://storage.cloud.google.com/robocubs-ml/debug/exported_model/pipeline.config
  - SavedModel: https://storage.cloud.google.com/robocubs-ml/debug/exported_model/saved_model/saved_model.pb
- TensorFlow Lite file
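I found a solution to the problem: using TensorFlow v1.13.0-rc1 (the current release candidate) appears to fix it.
The cause is that TensorFlow Lite's ResizeNearestNeighbor op did not exist until v1.13. I realized my main mistake was reading the v1.13 documentation even though I was using v1.12.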
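The general lesson is to check op support against the TensorFlow version you actually convert and build with, not against the latest documentation. Below is a minimal sketch of such a check. The `MIN_TFLITE_VERSION` table is hypothetical: its only entry comes from the answer above (ResizeNearestNeighbor from v1.13), and other ops would have to be filled in by hand from each version's documentation.

```python
import re
from typing import Tuple

# Hypothetical table: first TensorFlow release whose TFLite has the op as a builtin.
# Only the ResizeNearestNeighbor entry comes from the answer above.
MIN_TFLITE_VERSION = {"ResizeNearestNeighbor": (1, 13)}


def parse_tf_version(tag: str) -> Tuple[int, int, int]:
    """Parse tags like 'v1.12.0' or 'v1.13.0-rc1' into (major, minor, patch)."""
    m = re.match(r"v?(\d+)\.(\d+)\.(\d+)", tag)
    if m is None:
        raise ValueError("unrecognized version tag: {!r}".format(tag))
    major, minor, patch = (int(g) for g in m.groups())
    return (major, minor, patch)


def supports_builtin(op_name: str, tag: str) -> bool:
    """True if the tagged release is at least the recorded minimum for op_name."""
    required = MIN_TFLITE_VERSION.get(op_name)
    if required is None:
        raise KeyError("no minimum version recorded for {}".format(op_name))
    # Compare (major, minor) only; any rc suffix is ignored.
    return parse_tf_version(tag)[:2] >= required


print(supports_builtin("ResizeNearestNeighbor", "v1.12.0"))      # False
print(supports_builtin("ResizeNearestNeighbor", "v1.13.0-rc1"))  # True
```

Release candidates are treated as the release they precede, which matches the observation that v1.13.0-rc1 works.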