使用 tf.function 的 Tensorflow 2.0 模型非常慢,每次列车计数发生变化时都会重新编译。 Eager 的运行速度提高了大约 4 倍

Tensorflow 2.0 model using tf.function very slow and is recompiling every time the train count changes. Eager runs about 4x faster

我有使用未编译的 keras 代码构建的模型,我正在尝试通过自定义训练循环 运行 它们。

TF 2.0 eager(默认)代码 运行s 在 CPU(笔记本电脑)上大约 30s。当我使用包装的 tf.function 调用方法创建一个 keras 模型时,它 运行 慢得多,而且似乎需要很长时间才能启动,尤其是 "first" 时间。

例如,在 tf.function 代码中,10 个样本的初始训练需要 40 秒,10 个样本的后续训练需要 2 秒。

20个样本,初始耗时50s,后续耗时4s。

1 个样本的第一个训练需要 2 秒,后续需要 200 毫秒。

所以看起来每次调用 train 都在创建一个新图,其中复杂性随着 train 数量的增加而增加!?

我正在做这样的事情:

@tf.function
def train(n=10):
    step = 0
    loss = 0.0
    accuracy = 0.0
    for i in range(n):
        step += 1
        d, dd, l = train_one_step(model, opt, data)
        tf.print(dd)
        with tf.name_scope('train'):
            for k in dd:
                tf.summary.scalar(k, dd[k], step=step)
        if tf.equal(step % 10, 0):
            tf.print(dd)
    d.update(dd)
    return d

根据示例,模型 keras.model.Model 使用 @tf.function 修饰 call 方法。

我在这里 Using a Python native type 分析了 @tf.function 的这种行为。

简而言之:tf.function 的设计不会自动将 Python 本机类型装箱到具有明确 dtypetf.Tensor 对象。

如果您的函数接受 tf.Tensor 对象,则在第一次调用该函数时会对其进行分析,并构建图表并将其与该函数相关联。在每个非第一次调用中,如果 tf.Tensor 对象的 dtype 匹配,则重新使用该图。

但是在使用 Python 本机类型的情况下,每次使用不同的值调用函数时都会构建图形

简而言之:如果您打算使用 @tf.function.

,请将您的代码设计为在任何地方都使用 tf.Tensor 而不是 Python 变量

tf.function 不是一个可以神奇地加速在急切模式下运行良好的函数的包装器;是一个包装器,需要设计 eager 函数(正文、输入参数、dytpes)以了解创建图形后会发生什么,以获得真正的加速。