tf.where() 在处理张量时表现不佳
tf.where() not behaving as expected for manipulating tensors
我试过以下代码:
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19))
不会产生与以下相同的结果:
a = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19)
这里 x 是二进制变量 0 或 1。b 是介于 0 和 1 之间的实数。
有什么我遗漏的吗?
我比较 2 个答案的方式是 tf.reduce_sum(a)
找到解决方案:对于 x = 0 或 x = 1,这 2 个确实等价。我使用的数据是一个二维张量,其中有些位不是 0 或 1。
这是通过发现的
tf.unique(tf.reshape(x, (-1,))
代码示例:
# When x = 0.0
x = 0.0
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.6931472 from (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # 0 + (1-x)*tf.math.log(1 - b + 1e-19) = -0.6931472
# When x = 1.0
x = 1.0
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.6931472 from x*tf.math.log(b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # x*tf.math.log(b + 1e-19) + 0 = -0.6931472
# When x = 0.4
x = 0.4
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.41588834 from (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) = -0.27725 + -0.41588 = -0.6931471824645996.
您问题中提到的两种代码产生相同结果的唯一情况是 x = 1 或 0。
tf.reduce_sum(a)
这里,a 是一个标量,所以这不会改变 a 的值。
我试过以下代码:
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19))
不会产生与以下相同的结果:
a = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19)
这里 x 是二进制变量 0 或 1。b 是介于 0 和 1 之间的实数。
有什么我遗漏的吗?
我比较 2 个答案的方式是 tf.reduce_sum(a)
找到解决方案:对于 x = 0 或 x = 1,这 2 个确实等价。我使用的数据是一个二维张量,其中有些位不是 0 或 1。
这是通过发现的
tf.unique(tf.reshape(x, (-1,))
代码示例:
# When x = 0.0
x = 0.0
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.6931472 from (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # 0 + (1-x)*tf.math.log(1 - b + 1e-19) = -0.6931472
# When x = 1.0
x = 1.0
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.6931472 from x*tf.math.log(b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # x*tf.math.log(b + 1e-19) + 0 = -0.6931472
# When x = 0.4
x = 0.4
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.41588834 from (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) = -0.27725 + -0.41588 = -0.6931471824645996.
您问题中提到的两种代码产生相同结果的唯一情况是 x = 1 或 0。
tf.reduce_sum(a)
这里,a 是一个标量,所以这不会改变 a 的值。