尝试理解 AutoGraph 和 tf.function:打印损失 tf.function

Try to understand AutoGraph and tf.function: print loss in tf.function

def train_one_step():
    with tf.GradientTape() as tape:
        a = tf.random.normal([1, 3, 1])
        b = tf.random.normal([1, 3, 1])
        loss = mse(a, b)

    tf.print('inner tf print', loss)
    print("inner py print", loss)

    return loss


@tf.function
def train():
    loss = train_one_step()

    tf.print('outer tf print', loss)
    print('outer py print', loss)

    return loss

loss = train()
tf.print('outest tf print', loss)
print("outest py print", loss)

我正在尝试进一步了解 tf.functional。我用不同的方法在四个地方打印了损失。它会产生这样的结果

inner py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
outer py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
inner tf print 1.82858419
outer tf print 1.82858419
outest tf print 1.82858419
outest py print tf.Tensor(1.8285842, shape=(), dtype=float32)
  1. tf.print 和 python 打印有什么区别?
  2. 看起来 python print 将在图形评估期间执行,但 tf 仅在执行时打印?
  3. 以上仅适用于有 tf.function 装饰器的情况?除此之外,tf.print 在 python print?
  4. 之前运行

print 是正常的 python 打印。 tf.print 是张量流图的一部分。 在急切模式下,tensorflow 将直接执行图形。这就是为什么在 @tf.function 函数之外, python 打印的输出是一个数字(tensorflow 直接执行图形并将数字给出给普通打印函数),这也是为什么 tf.print 立即打印。

另一方面,在 @tf.function 函数内,tensorflow 不会立即执行图形。相反,它将 "stack" 您调用的 tensorflow 函数转换成一个更大的图形,我们将在 @tf.function 结束时一次性执行。

这就是为什么 python 打印不给你 @tf.function 函数内的数字的原因(此时图形尚未执行)。但是函数结束后,图形正在执行,连同图形中的tf.print。因此 tf.print 在 python 打印之后打印并为您提供实际损失数字。

我在一篇 three-part 文章中涵盖并回答了您的所有问题:"Analyzing tf.function to discover AutoGraph strengths and subtleties":part 1, part 2, part 3

总结并回答您的 3 个问题:

  • tf.print 和 python 打印有什么区别?

tf.print 是一个 Tensorflow 构造,默认情况下在 标准错误 上打印,更重要的是,它在评估时产生一个操作。

当一个操作是 运行 时,在急切执行中,它或多或少地以与 Tensorflow 1.x.

相同的方式产生 "node"

tf.function能够捕获tf.print的生成操作并将其转换为图形节点。

相反,print 是一个 Python 构造,默认情况下在标准输出上打印,执行时不会 生成任何操作。因此,tf.function 无法将其转换为等效图形,只能在函数跟踪期间执行。

  • 看起来 python 打印将在图形评估期间执行,但 tf 打印仅在执行时?

我已经在上一点回答了这个问题,但是再一次,print 只在函数跟踪期间执行,而 tf.print 在跟踪期间和它的 graph-representation 被执行(在 tf.function 成功将函数转换为图形后)。

  • 以上仅适用于有tf.function装饰器的情况?除此之外,tf.print 运行s before python print?

是的。 tf.print 不会在 print 之前或之后 运行。在 eager execution 中,一旦 Python 解释器找到语句,它们就会被评估。急切执行的唯一区别是输出流。

无论如何,我建议您阅读链接的三篇文章,因为它们详细介绍了 tf.function 的这一点和其他特点。