Object Detection on Android with TensorFlow Lite

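I am trying to implement a custom object detection model with TensorFlow Lite in Android Studio. I followed the guidance provided here: Running on mobile with TensorFlow Lite, but without success. The example model runs fine and shows all the detected labels. However, when I use my custom model, I get no labels at all. I also tried other models from the internet, with the same result. It is as if the labels are not being passed in correctly. I copied my detect.tflite and labelmap.txt, and I changed TF_OD_API_INPUT_SIZE and TF_OD_API_IS_QUANTIZED in DetectorActivity.java, but I still get no results (a detected class with a bounding box and a score).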

Logcat shows the following:

2020-10-11 18:37:54.315 31681-31681/org.tensorflow.lite.examples.detection E/HAL: PATH3 /odm/lib64/hw/gralloc.qcom.so
2020-10-11 18:37:54.315 31681-31681/org.tensorflow.lite.examples.detection E/HAL: PATH2 /vendor/lib64/hw/gralloc.qcom.so
2020-10-11 18:37:54.315 31681-31681/org.tensorflow.lite.examples.detection E/HAL: PATH1 /system/lib64/hw/gralloc.qcom.so
2020-10-11 18:37:54.315 31681-31681/org.tensorflow.lite.examples.detection E/HAL: PATH3 /odm/lib64/hw/gralloc.msm8953.so
2020-10-11 18:37:54.315 31681-31681/org.tensorflow.lite.examples.detection E/HAL: PATH2 /vendor/lib64/hw/gralloc.msm8953.so
2020-10-11 18:37:54.315 31681-31681/org.tensorflow.lite.examples.detection E/HAL: PATH1 /system/lib64/hw/gralloc.msm8953.so
2020-10-11 18:37:54.859 31681-31681/org.tensorflow.lite.examples.detection E/tensorflow: CameraActivity: Exception!
    java.lang.IllegalStateException: This model does not contain associated files, and is not a Zip file.
        at org.tensorflow.lite.support.metadata.MetadataExtractor.assertZipFile(MetadataExtractor.java:325)
        at org.tensorflow.lite.support.metadata.MetadataExtractor.getAssociatedFile(MetadataExtractor.java:165)
        at org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel.create(TFLiteObjectDetectionAPIModel.java:118)
        at org.tensorflow.lite.examples.detection.DetectorActivity.onPreviewSizeChosen(DetectorActivity.java:96)
        at org.tensorflow.lite.examples.detection.CameraActivity.onPreviewFrame(CameraActivity.java:200)
        at android.hardware.Camera$EventHandler.handleMessage(Camera.java:1157)
        at android.os.Handler.dispatchMessage(Handler.java:102)
        at android.os.Looper.loop(Looper.java:165)
        at android.app.ActivityThread.main(ActivityThread.java:6375)
        at java.lang.reflect.Method.invoke(Native Method)
        at com.android.internal.os.ZygoteInit$MethodAndArgsCaller.run(ZygoteInit.java:912)
        at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:802)

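How can I get detection working? Do I need an additional file (metadata) related to the labels, or am I doing something wrong? The case above was tested on an Android 7 device. Thanks!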

That looks like a regression. Could you try the following?

<at your TF example repo>
$ git checkout de42482b453de6f7b6488203b20e7eec61ee722e^

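This is a problem with documentation that was not updated.

The main issue is that the sample was updated to use models with Metadata attached, specifically with the labels embedded as an asset of the model.

After adding the label file to the model, everything works.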
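
The "is not a Zip file" message in the stack trace can be checked from Python before the model is deployed. Associated files such as a label file are appended to the .tflite file as a ZIP archive, so the standard zipfile module can list them. This is a minimal sketch; the helper name list_associated_files is hypothetical and not part of tflite-support.

```python
import zipfile


def list_associated_files(model_path):
    """Return the names of files packed into a .tflite model.

    Returns an empty list when nothing has been appended to the model,
    which is the situation the Android sample reports as "not a Zip file".
    """
    if not zipfile.is_zipfile(model_path):
        return []
    with zipfile.ZipFile(model_path) as archive:
        return archive.namelist()


# Example call (the path is an assumption, use your own model file):
# print(list_associated_files("detect.tflite"))
```

If the returned list is empty, the label file has not been packed into the model yet.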

To make the solution suggested by Gusthema easier to follow, here is the code that worked in my case:

pip install tflite-support

import os

from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb


# Creates model info.
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "MobileNetV1 image classifier"
model_meta.description = ("Identify Unesco Monuments Route "
                          "image from a set of 18 categories")
model_meta.version = "v1"
model_meta.author = "TensorFlow"
model_meta.license = ("Apache License. Version 2.0 "
                      "http://www.apache.org/licenses/LICENSE-2.0.")


# Creates input info.
input_meta = _metadata_fb.TensorMetadataT()
input_meta.name = "image"
input_meta.description = (
    "Input image to be classified. The expected image is {0} x {1}, with "
    "three channels (red, blue, and green) per pixel. Each value in the "
    "tensor is a single byte between 0 and 255.".format(300, 300))
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = (
    _metadata_fb.ColorSpaceType.RGB)
input_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.ImageProperties)
input_normalization = _metadata_fb.ProcessUnitT()
input_normalization.optionsType = (
    _metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_normalization.options = _metadata_fb.NormalizationOptionsT()
input_normalization.options.mean = [127.5]
input_normalization.options.std = [127.5]
input_meta.processUnits = [input_normalization]
input_stats = _metadata_fb.StatsT()
input_stats.max = [255]
input_stats.min = [0]
input_meta.stats = input_stats



# Creates output info.
output_meta = _metadata_fb.TensorMetadataT()
output_meta.name = "probability"
output_meta.description = "Probabilities of the 18 labels respectively."
output_meta.content = _metadata_fb.ContentT()
# The generated ContentT class names this attribute contentProperties.
output_meta.content.contentProperties = _metadata_fb.FeaturePropertiesT()
output_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_stats = _metadata_fb.StatsT()
output_stats.max = [1.0]
output_stats.min = [0.0]
output_meta.stats = output_stats
label_file = _metadata_fb.AssociatedFileT()
label_file.name = os.path.basename('/content/gdrive/My Drive/models/research/deploy/labelmap.txt')
label_file.description = "Labels for objects that the model can recognize."
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
output_meta.associatedFiles = [label_file]


# Creates subgraph info.
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
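# One metadata entry per output tensor; this assumes the detection model has
# four outputs and reuses the same entry for each of them.
subgraph.outputTensorMetadata = 4*[output_meta]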
model_meta.subgraphMetadata = [subgraph]

b = flatbuffers.Builder(0)
b.Finish(
    model_meta.Pack(b),
    _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()


# metadata and the label file are written into the TFLite file
populator = _metadata.MetadataPopulator.with_model_file('/content/gdrive/My Drive/models/research/object_detection/exported_model/detect.tflite')
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files(['/content/gdrive/My Drive/models/research/deploy/labelmap.txt'])
populator.populate()

Finally, if you want to create a JSON file that shows the result (the metadata file), you can use:

displayer = _metadata.MetadataDisplayer.with_model_file('/content/gdrive/My Drive/models/research/object_detection/exported_model/detect.tflite')
export_json_file = os.path.join('/content/gdrive/My Drive/models/research/object_detection/exported_model',
                    os.path.splitext('detect.tflite')[0] + ".json")
json_file = displayer.get_metadata_json()
# Optional: write out the metadata as a json file
with open(export_json_file, "w") as f:
  f.write(json_file)

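P.S.: Take care to change the few parts of the code that need to match your setup (for example, if you use 512x512 images, you must change the size in the "input_meta.description" variable).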
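
One way to avoid forgetting that change is to build the description from a single size constant. This is only a sketch: describe_input and INPUT_SIZE are hypothetical names, and 512 is an example value that should be replaced with your model's actual input size.

```python
def describe_input(width, height):
    """Build the text for input_meta.description for a given input size."""
    return (
        "Input image to be classified. The expected image is {0} x {1}, with "
        "three channels (red, blue, and green) per pixel. Each value in the "
        "tensor is a single byte between 0 and 255.".format(width, height))


# Hypothetical value: set this to the input size your model was exported with.
INPUT_SIZE = 512

description = describe_input(INPUT_SIZE, INPUT_SIZE)
```

The resulting string can then be assigned to input_meta.description in place of the hard-coded 300 x 300 text.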