Tensorflow.py 保护分区
Tensorflow.py Protected division
我正在尝试使用 Tensorflow.where
实现一种受保护的除法,但不知何故它似乎跳过了 where
语句中设置的条件。
主要思想是,当划分 x/y
时,如果 y == 0.
那么划分的结果是 x
而不是抛出和错误。
我的代码如下:
def Pdivide(x,y):
result = tf.where(y == 0., x, x/y)
return result
但是以某种方式跳过了该条件:
>>> a = tf.Variable([1.7, 0.2, 0., 1.1, 0.9, 0.3, 23., -1.])
>>> b = tf.Variable([0., 0., 0., 1., 1., 0., 1., 1.])
>>>Pdivide(a,b)
>>>(inf, inf, nan, 1.1, 0.9, inf, 23, -1)
预期输出:
>>>(1.7, 0.2, 0., 1.1, 0.9, 0.3, 23, -1)
PS:使用eager
执行。
好的,答案显然很简单。
由于某些原因,我无法将张量元素与简单的 ==
进行比较,但使用 tf.equal(y, 0.)
可以解决问题并产生正确的输出。
我正在尝试使用 Tensorflow.where
实现一种受保护的除法,但不知何故它似乎跳过了 where
语句中设置的条件。
主要思想是,当划分 x/y
时,如果 y == 0.
那么划分的结果是 x
而不是抛出和错误。
我的代码如下:
def Pdivide(x,y):
result = tf.where(y == 0., x, x/y)
return result
但是以某种方式跳过了该条件:
>>> a = tf.Variable([1.7, 0.2, 0., 1.1, 0.9, 0.3, 23., -1.])
>>> b = tf.Variable([0., 0., 0., 1., 1., 0., 1., 1.])
>>>Pdivide(a,b)
>>>(inf, inf, nan, 1.1, 0.9, inf, 23, -1)
预期输出:
>>>(1.7, 0.2, 0., 1.1, 0.9, 0.3, 23, -1)
PS:使用eager
执行。
好的,答案显然很简单。
由于某些原因,我无法将张量元素与简单的 ==
进行比较,但使用 tf.equal(y, 0.)
可以解决问题并产生正确的输出。