当两个函数相互调用时,tf.function 是如何工作的

How does tf.function works when two functions call each other

我使用 tensorflow==1.14 和 tf.enable_eager_execution() 构建我的模型,如下所示:

class Model:
  def __init__(self):
    self.embedding = tf.keras.layers.Embedding(10, 15)
    self.dense = tf.keras.layers.Dense(10)

  @tf.function
  def inference(self, inp):
    print('call function: inference')
    inp_em = self.embedding(inp)
    inp_enc = self.dense(inp_em)

    return inp_enc

  @tf.function
  def fun(self, inp):
    print('call function: fun')
    return self.inference(inp)

model = Model()

当我第一次运行下面的代码时:

a = model.fun(np.array([1, 2, 3]))
print('=' * 20)
a = model.inference(np.array([1, 2, 3]))

输出是

call function: fun
call function: inference
call function: inference
====================
call function: inference

好像是tensorflow为推理函数建立了三个图,我怎么只为推理函数建立一个图。 我还想知道当两个函数相互调用时 tf.function 是如何工作的。这是构建模型的正确方法吗?

有时 tf.function 的执行方式会给我们带来一些困惑 - 特别是当我们混合 python 等普通操作时,例如 print()

我们应该记住,当我们用 tf.function 装饰一个函数时,它不再是 只是 一个 python 函数。它的行为略有不同,以便在 TF 中快速高效地使用。绝大多数时候,行为的变化几乎不明显(除了速度提高!)但偶尔我们会遇到像这样的细微差别。

首先要注意的是,如果我们使用 tf.print() 代替 print() 那么我们会得到预期的输出:

class Model:
  def __init__(self):
    self.embedding = tf.keras.layers.Embedding(10, 15)
    self.dense = tf.keras.layers.Dense(10)

  @tf.function
  def inference(self, inp):
    tf.print('call function: inference')
    inp_em = self.embedding(inp)
    inp_enc = self.dense(inp_em)

    return inp_enc

  @tf.function
  def fun(self, inp):
    tf.print('call function: fun')
    return self.inference(inp)

model = Model()

a = model.fun(np.array([1, 2, 3]))
print('=' * 20)
a = model.inference(np.array([1, 2, 3]))

输出:

call function: fun
call function: inference
====================
call function: inference

如果您的问题是现实世界问题的征兆,这可能就是解决方法!

所以这是怎么回事?

好吧,我们第一次调用装饰有 tf.function 的函数时,tensorflow 将构建一个执行图。为了做到这一点,它 "traces" 由 python 函数执行的 tensorflow 操作。

为了进行这种跟踪,tensorflow 可能会多次调用装饰函数

这意味着仅 python 操作(例如 print() 可以执行多次)但 tf 操作(例如 tf.print() 将按您通常预期的方式运行。

这种细微差别的副作用是我们应该知道 tf.function 装饰函数如何处理状态,但这不在您的问题范围内。有关详细信息,请参阅 original RFC and this github issue

And I also want to know how tf.function woks when two functions call each other. Is this the right way to build my model?

一般来说,我们需要只用tf.function装饰"outer"函数(在你的例子中是.fun())但是如果你可以调用内部功能也直接那么你也可以自由装饰它。