为什么 Tensorflow Function 会对函数的不同整数输入执行回溯?

Why does Tensorflow Function perform retracing for different integer inputs to the function?

我正在遵循关于函数的 Tensorflow 指南 here,并且根据我的理解,TF 将为每次调用具有不同输入签名(即数据类型和形状)的函数跟踪并创建一个图输入)。但是,以下示例使我感到困惑。由于两个输入都是整数并且具有完全相同的形状,TF 不应该只执行一次跟踪并构建图形吗?为什么调用函数时两次都发生跟踪?

@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!") # An eager-only side effect.
  return x * x + tf.constant(2)


# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter.
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))

输出:

Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)

数字 2 和 3 被视为不同的整数值,这就是您看到“Tracing!”的原因。两次。您所指的行为:“TF 将为每次调用具有不同输入签名(即数据类型和输入形状)的函数跟踪并创建一个图形”适用于张量而不是简单的数字。您可以通过将两个数字转换为张量常数来验证:

import tensorflow as tf

@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!") # An eager-only side effect.
  return x * x + tf.constant(2)

print(a_function_with_python_side_effect(tf.constant(2)))
print(a_function_with_python_side_effect(tf.constant(3)))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)

这是混合 python 标量和 tf.function 时的副作用。查看跟踪规则 here。在那里你读到:

The cache key generated for a tf.Tensor is its shape and dtype.

The cache key generated for a Python primitive (like int, float, str) is its value.