How to visualize detected boxes from TFLite model(How to get category index from TFLite model?)

How to visualize detected boxes from TFLite model (How to get category index from TFLite model?)

我有一个对象检测 TFLite 模型保存为 model.tflite 文件。我可以 运行 它作为

interpreter = tf.lite.Interpreter("model.tflite")

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], input_image)

interpreter.invoke()

然后得到输出为

detection_boxes = interpreter.get_tensor(output_details[0]['index'])
detection_classes = interpreter.get_tensor(output_details[1]['index'])
detection_scores = interpreter.get_tensor(output_details[2]['index'])
num_boxes = interpreter.get_tensor(output_details[3]['index'])

我想在图片中绘制给定 类 的检测到的框。最简单的解决方案似乎是使用工具 viz_utils.visualize_boxes_and_labels_on_image_array as.

viz_utils.visualize_boxes_and_labels_on_image_array(
        image_np_with_detections,
        detection_boxes,
        detection_classes,
        detection_scores,
        category_index,
        use_normalized_coordinates=True,
        max_boxes_to_draw=20,
        min_score_thresh=.1,
        agnostic_mode=False

但是,为此需要 category_index(将 类 索引转换为人类可读的标签)。通常,您可以从包含标签的文件中加载它,对于 .tflite 模型,如果我没记错的话,标签在 .tflite 文件中应该是 included/packed。

但是,我不知道该怎么做,或者我应该使用哪些函数(我也查看了 tflite_support 库,但不知道如何从相关联的库中提取类别文件)。

使用 .tflite 文件可视化带有标签的检测框的正确方法是什么?它不必使用 viz_utils。任何帮助表示赞赏。谢谢。

# labels variable contains the list of the names of the category and
# it generates by reading the labels.txt
with open("labels.txt", "r") as f:
  txt = f.read()

labels = txt.splitlines()

for idx, box in enumerate(detection_boxes[0]):
    if detection_scores[0][idx] > threshold:
        class_name = labels[int(detection_classes[0][idx])]

我根据 https://github.com/tensorflow/models/issues/7458#issuecomment-523904465 创建了这个代码片段。