Tensorflow.py 保护分区

Tensorflow.py Protected division

我正在尝试使用 Tensorflow.where 实现一种受保护的除法,但不知何故它似乎跳过了 where 语句中设置的条件。

主要思想是,当划分 x/y 时,如果 y == 0. 那么划分的结果是 x 而不是抛出和错误。

我的代码如下:

def Pdivide(x,y):
    result = tf.where(y == 0., x, x/y) 
    return result

但是以某种方式跳过了该条件:

>>> a = tf.Variable([1.7, 0.2, 0., 1.1, 0.9, 0.3, 23., -1.])
>>> b = tf.Variable([0., 0., 0., 1., 1., 0., 1., 1.])

>>>Pdivide(a,b)

>>>(inf, inf, nan, 1.1, 0.9, inf, 23, -1)

预期输出:

>>>(1.7, 0.2, 0., 1.1, 0.9, 0.3, 23, -1)

PS:使用eager执行。

好的,答案显然很简单。

由于某些原因,我无法将张量元素与简单的 == 进行比较,但使用 tf.equal(y, 0.) 可以解决问题并产生正确的输出。