理解改进版的Tensorflow对象检测API

Understanding the improved version of Tensorflow object detection API

我在我的项目中使用 Tensorflow 对象检测 API 并遇到了这个 link: https://github.com/tensorflow/models/issues/3270 代码附在 link 上的 zip 文件中。具体没看懂的部分是这篇:

input_graph = tf.Graph()
with tf.Session(graph=input_graph):
    score = tf.placeholder(tf.float32, shape=(None, 1917, 90), name="Postprocessor/convert_scores")
    expand = tf.placeholder(tf.float32, shape=(None, 1917, 1, 4), name="Postprocessor/ExpandDims_1")
    for node in input_graph.as_graph_def().node:
        if node.name == "Postprocessor/convert_scores":
            score_def = node
        if node.name == "Postprocessor/ExpandDims_1":
            expand_def = node

detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        dest_nodes = ['Postprocessor/convert_scores','Postprocessor/ExpandDims_1']

        edges = {}
        name_to_node_map = {}
        node_seq = {}
        seq = 0
        for node in od_graph_def.node:
            n = _node_name(node.name)
            name_to_node_map[n] = node
            edges[n] = [_node_name(x) for x in node.input]
            node_seq[n] = seq
            seq += 1

        for d in dest_nodes:
            assert d in name_to_node_map, "%s is not in graph" % d

        nodes_to_keep = set()
        next_to_visit = dest_nodes[:]
        while next_to_visit:
            n = next_to_visit[0]
            del next_to_visit[0]
            if n in nodes_to_keep:
                continue
            nodes_to_keep.add(n)
            next_to_visit += edges[n]

        nodes_to_keep_list = sorted(list(nodes_to_keep), key=lambda n: node_seq[n])

        nodes_to_remove = set()
        for n in node_seq:
            if n in nodes_to_keep_list: 
                continue
            nodes_to_remove.add(n)
        nodes_to_remove_list = sorted(list(nodes_to_remove), key=lambda n: node_seq[n])

        keep = graph_pb2.GraphDef()
        for n in nodes_to_keep_list:
            keep.node.extend([copy.deepcopy(name_to_node_map[n])])

        remove = graph_pb2.GraphDef()
        remove.node.extend([score_def])
        remove.node.extend([expand_def])
        for n in nodes_to_remove_list:
            remove.node.extend([copy.deepcopy(name_to_node_map[n])])

        with tf.device('/gpu:0'):
            tf.import_graph_def(keep, name='')
        with tf.device('/cpu:0'):
            tf.import_graph_def(remove, name='')

通过将操作正确分配给 GPU 和 CPU,减少了处理每张图像所需的时间。我的基本想法是它试图在 CPU 和 GPU 上分配操作,但对这两个图的解释、它们的结构和工作方式将非常有帮助。谢谢!

我对这段代码的理解是:

  • 它创建了一个包含 2 个占位符 'Postprocessor/convert_scores''Postprocessor/ExpandDims_1' 的图表。
  • 将其转换为 graph_def 并保留与占位符相对应的节点。

    • 这2个节点对应模型输出的1917个框,第一个是class概率,第二个是框坐标。
  • 它创建第二个图并加载经过训练的模型。

  • 它列出了图中的所有节点以及它们之间的连接方式。
  • 列出连接到'Postprocessor/convert_scores''Postprocessor/ExpandDims_1'的所有节点,并将它们存储在keep list中。
  • 列出所有不在nodes_to_keep_list中的节点,并存储在nodes_to_remove_list中。

  • 然后它创建一个图形定义并用所有 nodes_to_keep_list 节点的副本填充它。

  • 然后是第二个图形定义,其中包含所有 nodes_to_remove_list 节点的副本。

  • 最后加载两个图形定义,第一个带有设备 '/gpu:0',第二个 '/cpu:0'.

正如作者所说,这样做的目的是 运行 GPU 上的 CNN,以及 CPU 上的 post 处理,因为那里的速度要快得多. 如果您查看 mobilenet+SSD,您会看到该模型输出了一堆方框 (1917),然后在这些方框上进行了相当复杂的(至少从图形的角度来看)post 处理以提供最终输出(detection_boxesdetection_scoresdetection_classesnum_detections)。

在这段代码中看不到,但占位符后来用于将keep图的输出插入到remove图中。执行分两步执行(2次调用sess.run()

(score, expand) = sess.run([score_out, expand_out], feed_dict={image_tensor: image_np_expanded})
(boxes, scores, classes, num) = sess.run(
      [detection_boxes, detection_scores, detection_classes, num_detections],
      feed_dict={score_in:score, expand_in: expand})
print 'Iteration %d: %.3f sec'%(i, time.time()-start_time)

编辑

1917 值来自原始图表,不同的模型将是不同的值,但即使是不同的节点等......这就是为什么这个解决方案比真正的解决方案更像是一个 hack,因为它需要针对您要应用的每个新模型进行定制...

我刚才看了这张图,我认为该模型输出了一堆特定尺寸或纵横比的盒子,以及另一堆不同纵横比的盒子,所有这些都合并在一起,你最终得到这个 1917 方框图。

ExpandDims 只是操作的名称,因为它没有在图中命名。 _1 在那里,因为可能在此范围内的图中已经有一个。至于为什么具体是这些节点,那只是作者在调查这些性能问题后随意选择的。基本上缓慢的部分是在这些节点之后。但是他可以选择稍微不同的节点,例如在 ExpandDims 操作之前它会执行相同的操作。这些特定节点的实际目的与他在这里所做的事情无关。在 ExpandDims 的情况下,这是一个非常普通的操作,只需添加一个维度 1.