Tensorflow 如何检查张量行是否仅为零?

Tensorflow how to check if a tensor row is only zeroes?

我正在训练一个简单的网络来预测单个对象的边界框坐标。然而,有些图片找不到对象。由于网络总是进行预测,因此它还会预测一个介于 0 和 1 之间的置信度值,该值应该表示图片中存在对象的概率。我的预测张量称为 logits,它是一个 (batch_size, 5) 张量(置信度、x、y、宽度和高度)。同样,labels 张量也是 (batch_size, 5).

以前我只训练总是有对象的图像,所以我基本上可以做到

loss = tf.l2_loss(logits - labels)

我也想用没有物体的图片开始训练,当图片中没有物体时,我不希望网络因其预测的坐标而受到惩罚。在这种情况下,重要的是置信度值,它应该接近 0(没有对象)。

我应该如何构建我的标签和损失函数来完成这个?我可以将没有对象的图像标签设置为全零,但如何检查特定行是否只有零?并且在这种情况下,logits中的相应行也需要设置为零(置信度值除外!)以便因坐标而产生的损失也为零。

一种查找张量的一行是否全为零的方法:

import tensorflow as tf

image = tf.fill([8,8], 0)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
image_row = tf.slice(image, [1,0], [1, -1])
total = tf.reduce_sum(tf.abs(image_row))
is_all_zero = tf.equal(total, 0)
print sess.run([total, is_all_zero, image_row])

结果:

[0, True, array([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)]

您可以使用tf.math.count_nonzero() 来检查张量是否全为零。您可以查看指南 here.