Tensorflow SSD 对象检测 类 选择
Tensorflow SSD object detection classes selection
我正在研究基本的 tensorflow 对象检测示例 here
它正在检测来自 coco 数据集的所有 90 类。但我只想从中检测到两个 类 。怎么做?
假设您只想检测 摩托车 和 人 。
在文件 visualization_util.py 中,转到 def draw_bounding_box_on_image_array
, 有一个函数 :
draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color,
thickness, display_str_list,
use_normalized_coordinates)
将此函数调用置于这样的条件下
if (display_str_list[0][0:3]=="per" or display_str_list[0][0:3]=="mot"):
draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color,
thickness, display_str_list,
use_normalized_coordinates)
这里"per"是"person"的前三个字母,"mot"是摩托车的前三个字母。这样,您就可以从所有其他对象中检测到您想要的对象
我正在研究基本的 tensorflow 对象检测示例 here 它正在检测来自 coco 数据集的所有 90 类。但我只想从中检测到两个 类 。怎么做?
假设您只想检测 摩托车 和 人 。 在文件 visualization_util.py 中,转到 def draw_bounding_box_on_image_array , 有一个函数 :
draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color,
thickness, display_str_list,
use_normalized_coordinates)
将此函数调用置于这样的条件下
if (display_str_list[0][0:3]=="per" or display_str_list[0][0:3]=="mot"):
draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color,
thickness, display_str_list,
use_normalized_coordinates)
这里"per"是"person"的前三个字母,"mot"是摩托车的前三个字母。这样,您就可以从所有其他对象中检测到您想要的对象