如何从 TFLite 对象检测中获取有用数据 Python
How to get useful data from TFLite Object Detection Python
我有一个 raspberry pi 4,我想以良好的帧率进行对象检测。我尝试了 tensorflow 和 YOLO,但都以 1 fps 的速度 运行。所以我正在尝试 TensorFlow Lite。我已经下载了 tflite 文件和 labelmap.txt 文件。我已经使用 this link 来尝试 运行 推理。在这里我遇到了一个问题。我不明白如何从输出中获取结果(分类、边界框协调和配置)。
这是我的代码:
import tensorflow as tf
import numpy as np
import cv2
interpreter = tf.lite.Interpreter(model_path="/content/drive/My Drive/detect.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)
print()
input_shape = input_details[0]['shape']
im = cv2.imread("/content/drive/My Drive/doggy.jpg")
im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im_rgb = cv2.resize(im_rgb, (input_shape[1], input_shape[2]))
input_data = np.expand_dims(im_rgb, axis=0)
print(input_data.shape)
print()
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data.shape)
print()
print(output_data)
这是我的输出:
[{'name': 'normalized_input_image_tensor', 'index': 175, 'shape': array([ 1, 300, 300, 3], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.0078125, 128)}]
[{'name': 'TFLite_Detection_PostProcess', 'index': 167, 'shape': array([ 1, 10, 4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 168, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 169, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 170, 'shape': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
(1, 300, 300, 3)
(1, 10, 4)
[[[ 1.66415479e-02 5.48024022e-04 8.67791831e-01 3.35325867e-01]
[ 7.41335377e-02 3.22245747e-01 9.64617252e-01 9.71388936e-01]
[-2.11861148e-03 5.41743517e-01 2.60241032e-01 7.02846169e-01]
[-5.67546487e-03 3.26282382e-01 8.59034657e-01 6.30770981e-01]
[ 7.27111334e-03 7.90268779e-01 2.86753297e-01 9.56545353e-01]
[ 2.07318692e-03 7.96441555e-01 5.48386931e-01 9.96111989e-01]
[-1.04907183e-02 2.38761827e-01 6.75976276e-01 7.01156497e-01]
[ 3.12007014e-02 1.34294275e-02 5.82291842e-01 3.10949832e-01]
[-1.95578858e-03 7.05318868e-01 9.18281525e-02 7.96184599e-01]
[-5.43205580e-03 3.23292404e-01 6.34427786e-01 5.68508685e-01]]]
输出(最后一个列表)似乎是一个非常小的数字数组,我如何从中得到结果?
谢谢
我在 github 的@daverim 的帮助下解决了这个问题,我在那里打开了一个问题。 https://github.com/tensorflow/tensorflow/issues/34761。这是获取有用数据的代码:
detection_boxes = interpreter.get_tensor(output_details[0]['index'])
detection_classes = interpreter.get_tensor(output_details[1]['index'])
detection_scores = interpreter.get_tensor(output_details[2]['index'])
num_boxes = interpreter.get_tensor(output_details[3]['index'])
print(num_boxes)
for i in range(int(num_boxes[0])):
if detection_scores[0, i] > .5:
class_id = detection_classes[0, i]
print(class_id)
使用 labelmap.txt 文件我们可以获得 class 名称。
我有一个 raspberry pi 4,我想以良好的帧率进行对象检测。我尝试了 tensorflow 和 YOLO,但都以 1 fps 的速度 运行。所以我正在尝试 TensorFlow Lite。我已经下载了 tflite 文件和 labelmap.txt 文件。我已经使用 this link 来尝试 运行 推理。在这里我遇到了一个问题。我不明白如何从输出中获取结果(分类、边界框协调和配置)。
这是我的代码:
import tensorflow as tf
import numpy as np
import cv2
interpreter = tf.lite.Interpreter(model_path="/content/drive/My Drive/detect.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)
print()
input_shape = input_details[0]['shape']
im = cv2.imread("/content/drive/My Drive/doggy.jpg")
im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im_rgb = cv2.resize(im_rgb, (input_shape[1], input_shape[2]))
input_data = np.expand_dims(im_rgb, axis=0)
print(input_data.shape)
print()
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data.shape)
print()
print(output_data)
这是我的输出:
[{'name': 'normalized_input_image_tensor', 'index': 175, 'shape': array([ 1, 300, 300, 3], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.0078125, 128)}]
[{'name': 'TFLite_Detection_PostProcess', 'index': 167, 'shape': array([ 1, 10, 4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 168, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 169, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 170, 'shape': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
(1, 300, 300, 3)
(1, 10, 4)
[[[ 1.66415479e-02 5.48024022e-04 8.67791831e-01 3.35325867e-01]
[ 7.41335377e-02 3.22245747e-01 9.64617252e-01 9.71388936e-01]
[-2.11861148e-03 5.41743517e-01 2.60241032e-01 7.02846169e-01]
[-5.67546487e-03 3.26282382e-01 8.59034657e-01 6.30770981e-01]
[ 7.27111334e-03 7.90268779e-01 2.86753297e-01 9.56545353e-01]
[ 2.07318692e-03 7.96441555e-01 5.48386931e-01 9.96111989e-01]
[-1.04907183e-02 2.38761827e-01 6.75976276e-01 7.01156497e-01]
[ 3.12007014e-02 1.34294275e-02 5.82291842e-01 3.10949832e-01]
[-1.95578858e-03 7.05318868e-01 9.18281525e-02 7.96184599e-01]
[-5.43205580e-03 3.23292404e-01 6.34427786e-01 5.68508685e-01]]]
输出(最后一个列表)似乎是一个非常小的数字数组,我如何从中得到结果?
谢谢
我在 github 的@daverim 的帮助下解决了这个问题,我在那里打开了一个问题。 https://github.com/tensorflow/tensorflow/issues/34761。这是获取有用数据的代码:
detection_boxes = interpreter.get_tensor(output_details[0]['index'])
detection_classes = interpreter.get_tensor(output_details[1]['index'])
detection_scores = interpreter.get_tensor(output_details[2]['index'])
num_boxes = interpreter.get_tensor(output_details[3]['index'])
print(num_boxes)
for i in range(int(num_boxes[0])):
if detection_scores[0, i] > .5:
class_id = detection_classes[0, i]
print(class_id)
使用 labelmap.txt 文件我们可以获得 class 名称。