相等比较在 TensorFlow 2.0 中不起作用 tf.function()

Equality comparison does not work inside TensorFlow 2.0 tf.function()

TensorFlow 2.0 AutoGraphs 的讨论之后,我一直在玩弄并注意到 >< 等不等式比较是直接指定的,而等式比较则使用 tf.equal.

这里有一个例子来演示。此函数使用 > 运算符,并且 在调用时效果很好

@tf.function
def greater_than_zero(value):
    return value > 0

greater_than_zero(tf.constant(1))
#  <tf.Tensor: id=1377, shape=(), dtype=bool, numpy=True>
greater_than_zero(tf.constant(-1))
# <tf.Tensor: id=1380, shape=(), dtype=bool, numpy=False>

这里还有一个函数使用相等比较,但是不起作用:

@tf.function
def equal_to_zero(value):
    return value == 0

equal_to_zero(tf.constant(1))
# <tf.Tensor: id=1389, shape=(), dtype=bool, numpy=False>  # OK...

equal_to_zero(tf.constant(0))
# <tf.Tensor: id=1392, shape=(), dtype=bool, numpy=False>  # WHAT?

如果我将 == 等式比较更改为 tf.equal,它将起作用。

@tf.function
def equal_to_zero2(value):
    return tf.equal(value, 0)

equal_to_zero2(tf.constant(0))
# <tf.Tensor: id=1402, shape=(), dtype=bool, numpy=True>

我的问题是:为什么在 tf.function 函数中使用不等式比较运算符有效,而等式比较却不行?

我在文章 "Analysing tf.function to discover Autograph strengths and subtleties" 的第 3 部分分析了这种行为(我强烈建议阅读所有 3 部分以了解如何在使用 tf.function 修饰函数之前正确编写函数 - 链接位于答案的底部)。

对于__eq__tf.equal的问题,答案是:

In short: the __eq__ operator (for tf.Tensor) has been overridden, but the operator does not use tf.equal to check for the Tensor equality, it just checks for the Python variable identity (if you are familiar with the Java programming language, this is precisely like the == operator used on string objects). The reason is that the tf.Tensor object needs to be hashable since it is used everywhere in the Tensorflow codebase as key for dict objects.

虽然对于所有其他运算符,答案是 AutoGraph 不会将 Python 运算符转换为 TensorFlow 逻辑运算符。在 How AutoGraph (don’t) converts the operators 部分中,我展示了每个 Python 运算符都会转换为始终被评估为 false 的图形表示。

事实上,下面的例子产生了输出"wat"

@tf.function
def if_elif(a, b):
  if a > b:
    tf.print("a > b", a, b)
  elif a == b:
    tf.print("a == b", a, b)
  elif a < b:
    tf.print("a < b", a, b)
  else:
    tf.print("wat")
x = tf.constant(1)
if_elif(x,x)

实际上,AutoGraph 无法将 Python 代码转换为图形代码;我们必须只使用 TensorFlow 原语来帮助它。在这种情况下,您的代码将按预期工作。

@tf.function
def if_elif(a, b):
  if tf.math.greater(a, b):
    tf.print("a > b", a, b)
  elif tf.math.equal(a, b):
    tf.print("a == b", a, b)
  elif tf.math.less(a, b):
    tf.print("a < b", a, b)
  else:
    tf.print("wat")

我在这里放了所有三篇文章的链接,我想你会发现它们很有用:

part 1, part 2, part 3