tf.where() 在处理张量时表现不佳

tf.where() not behaving as expected for manipulating tensors

我试过以下代码:

a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19))

不会产生与以下相同的结果:

a = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19)

这里 x 是二进制变量 0 或 1。b 是介于 0 和 1 之间的实数。

有什么我遗漏的吗?
我比较 2 个答案的方式是 tf.reduce_sum(a)

找到解决方案:对于 x = 0 或 x = 1,这 2 个确实等价。我使用的数据是一个二维张量,其中有些位不是 0 或 1。 这是通过发现的 tf.unique(tf.reshape(x, (-1,))

代码示例:

# When x = 0.0 

x = 0.0
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.6931472 from (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # 0 + (1-x)*tf.math.log(1 - b + 1e-19) = -0.6931472

# When x = 1.0 

x = 1.0
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.6931472 from x*tf.math.log(b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # x*tf.math.log(b + 1e-19) + 0 = -0.6931472


# When x = 0.4 

x = 0.4
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.41588834 from (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # x*tf.math.log(b + 1e-19) +  (1-x)*tf.math.log(1 - b + 1e-19) = -0.27725 + -0.41588 = -0.6931471824645996.

您问题中提到的两种代码产生相同结果的唯一情况是 x = 1 或 0。

tf.reduce_sum(a)

这里,a 是一个标量,所以这不会改变 a 的值。