tensorflow 在计算得到最大框后记住索引

tensorflow remember the index after calculating getting the maximum box

假设我有两个盒子数组,每个盒子的形状分别为(?, b1, 4)(?, b2, 4)(将?视为未知批量大小):

box1: [[[1,2,3,4], [2,3,4,5], [3,4,5,6]...]...]
box2: [[[4,3,2,1], [3,2,5,4], [4,3,5,6]...]...]

(以上数字任意设置)

我想:

  1. 在每个批次中,对于box1中的每个框A,在box2中找到框B A 具有最大的 IOU(并集交集)(当然在同一批中),然后附加元组 (A, B) 到列表 list_max.

  2. 追加到list_nonmax box2中所有与[=14=中的任何框没有最大IOU的框](当然要分批)

您可以假设:

  1. b1b2都是python变量,不是tensorflow张量。

  2. 计算单个框之间或批次框之间的IOU的方法已经存在,可以按字面意思使用:

    iou_single_box(box1, box2)box1box2 的形状都是 (4,).

    iou_multiple_boxes(bbox1, bbox2)bbox1bbox2 的形状分别为 (b1, 4)(b2, 4) .

    iou_batch_boxes(bbbox1, bbbox2) : bbbox1bbbox2 的形状分别是 (?, b1, 4)(?, b2, 4) (将 ? 视为未知批量大小)。

我发现这些在 tensorflow 中特别难,特别是对于 list_nonmax 的情况,因为,虽然使用填充然后 tf.reduce_max() 很容易得到最大iou的框元组,不可能记住它们的索引,然后提取出list_nonmax.

的框

为此你需要 tf.nn.top_k()。它 returns 最大值 它在最后一个维度的索引。

val, idx = tf.nn.top_k( iou_batch_boxes( bbbox1, bbbox2 ), k = 1 )

将为您提供 box2 索引以及每个 box1 和批次的最大 iou。

为了得到你的 list_max 你需要 tf.stack() box1 with box2's entries by idx with tf.gather_nd() 沿轴 1。这是一个带有虚拟 iou 函数的工作代码:

import tensorflow as tf

box1 = tf.reshape( tf.constant( range( 16 ), dtype = tf.float32 ), ( 2, 2, 4 ) )
box2 = tf.reshape( tf.constant( range( 2, 26 ), dtype = tf.float32 ), ( 2, 3, 4 ) )
batch_size = box1.get_shape().as_list()[ 0 ]

def dummy_iou_batch_boxes( box1, box2 ):
    b1s, b2s = box1.get_shape().as_list(), box2.get_shape().as_list()
    return tf.constant( [ [ [9.0,8,7], [1,2,3],
                            [0  ,1,2], [0,5,0] ] ] )

iou = dummy_iou_batch_boxes( box1, box2 )
val, idx = tf.nn.top_k( iou, k = 1 )
idx = tf.reshape( idx, ( batch_size, box1.get_shape().as_list()[ 1 ] ) )
one_hot_idx = tf.one_hot( idx, depth = box2.get_shape().as_list()[ 1 ] )
full_idx = tf.where( tf.equal( 1.0, one_hot_idx ) )
box1_idx = full_idx[ :, 0 : 2 ]
box2_idx = full_idx[ :, 0 : 3 : 2 ]
box12 = tf.gather_nd( box1, box1_idx )
box22 = tf.gather_nd( box2, box2_idx )
list_max = tf.stack( [ box12, box22 ], axis = 1 )

with tf.Session() as sess:
    res = sess.run( [ list_max ] )
    for v in res:
        print( v )
        print( "-----------------------------")

将输出:

[[[ 0. 1. 2. 3.]
[ 2. 3. 4. 5.]]

[[ 4. 5. 6. 7.]
[10. 11. 12. 13.]]

[[ 8. 9. 10. 11.]
[22. 23. 24. 25.]]

[[12. 13. 14. 15.]
[18. 19. 20. 21.]]]

如果你想把它作为列表或元组,你可以在上面list_max.

上使用tf.unstack()

要获得 list_nonmax 你需要的是将索引合并到一个掩码中,我相信我已经已在 中回答,但重要的部分是:

mask = tf.reduce_max( tf.one_hot( idx, depth = num_bbbox2 ), axis = -2 )

这会给你一个形状为 ( batch, num_box2 ) 的掩码,告诉你每个批次和每个box2 如果 box2any 的最大借据 box1.

从这里,您可以使用掩码或使用 tf.where() 获取索引列表,如下所示:

was_never_max_idx = tf.where( tf.equal( 0, mask ) )