如何从 TFLite 对象检测中获取有用数据 Python

How to get useful data from TFLite Object Detection Python

我有一个 raspberry pi 4,我想以良好的帧率进行对象检测。我尝试了 tensorflow 和 YOLO,但都以 1 fps 的速度 运行。所以我正在尝试 TensorFlow Lite。我已经下载了 tflite 文件和 labelmap.txt 文件。我已经使用 this link 来尝试 运行 推理。在这里我遇到了一个问题。我不明白如何从输出中获取结果(分类、边界框协调和配置)。

这是我的代码:

import tensorflow as tf 
import numpy as np
import cv2

interpreter = tf.lite.Interpreter(model_path="/content/drive/My Drive/detect.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)
print()

input_shape = input_details[0]['shape']
im = cv2.imread("/content/drive/My Drive/doggy.jpg")
im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im_rgb = cv2.resize(im_rgb, (input_shape[1], input_shape[2]))
input_data = np.expand_dims(im_rgb, axis=0)
print(input_data.shape)
print()

interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data.shape)
print()
print(output_data)

这是我的输出:

[{'name': 'normalized_input_image_tensor', 'index': 175, 'shape': array([  1, 300, 300,   3], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.0078125, 128)}]
[{'name': 'TFLite_Detection_PostProcess', 'index': 167, 'shape': array([ 1, 10,  4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 168, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 169, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 170, 'shape': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

(1, 300, 300, 3)

(1, 10, 4)

[[[ 1.66415479e-02  5.48024022e-04  8.67791831e-01  3.35325867e-01]
  [ 7.41335377e-02  3.22245747e-01  9.64617252e-01  9.71388936e-01]
  [-2.11861148e-03  5.41743517e-01  2.60241032e-01  7.02846169e-01]
  [-5.67546487e-03  3.26282382e-01  8.59034657e-01  6.30770981e-01]
  [ 7.27111334e-03  7.90268779e-01  2.86753297e-01  9.56545353e-01]
  [ 2.07318692e-03  7.96441555e-01  5.48386931e-01  9.96111989e-01]
  [-1.04907183e-02  2.38761827e-01  6.75976276e-01  7.01156497e-01]
  [ 3.12007014e-02  1.34294275e-02  5.82291842e-01  3.10949832e-01]
  [-1.95578858e-03  7.05318868e-01  9.18281525e-02  7.96184599e-01]
  [-5.43205580e-03  3.23292404e-01  6.34427786e-01  5.68508685e-01]]]

输出(最后一个列表)似乎是一个非常小的数字数组,我如何从中得到结果?

谢谢

我在 github 的@daverim 的帮助下解决了这个问题,我在那里打开了一个问题。 https://github.com/tensorflow/tensorflow/issues/34761。这是获取有用数据的代码:

detection_boxes = interpreter.get_tensor(output_details[0]['index'])
detection_classes = interpreter.get_tensor(output_details[1]['index'])
detection_scores = interpreter.get_tensor(output_details[2]['index'])
num_boxes = interpreter.get_tensor(output_details[3]['index'])
print(num_boxes)
for i in range(int(num_boxes[0])):
  if detection_scores[0, i] > .5:
       class_id = detection_classes[0, i]
       print(class_id)

使用 labelmap.txt 文件我们可以获得 class 名称。