Tensorflow Keras 模型和 Estimator 之间有什么区别?

What's the difference between a Tensorflow Keras Model and Estimator?

Tensorflow Keras 模型和 Tensorflow Estimator 都能够训练神经网络模型并使用它们来预测新数据。它们都是位于低级核心 TensorFlow API 之上的高级 API。那么我应该什么时候使用一个而不是另一个?

背景

Estimators API 在 1.1 版中被添加到 Tensorflow 中,并提供了对较低级别 Tensorflow 核心操作的高级抽象。它与 Estimator 实例一起使用,这是 TensorFlow 对完整模型的高级表示。

Keras is similar to the Estimators API in that it abstracts deep learning model components such as layers, activation functions and optimizers, to make it easier for developers. It is a model-level library, and does not handle low-level operations, which is the job of tensor manipulation libraries, or backends. Keras supports three backends - Tensorflow, Theano and CNTK.

Keras 直到 Release 1.4.0(2017 年 11 月 2 日)才成为 Tensorflow 的一部分。现在,当您使用 tf.keras(或谈论 'Tensorflow Keras')时,您只是简单地使用 Keras 接口和 Tensorflow 后端来构建和训练您的模型。

因此,Estimator API 和 Keras API 都提供了一个高级 API 而不是低级核心 Tensorflow API,您可以使用其中任何一个来训练你的模型。但在大多数情况下,如果您使用的是 Tensorflow,出于以下原因,您会希望使用 Estimators API。

分布

您可以使用 Estimator API 在多个服务器上进行分布式训练,但不能使用 Keras API。

来自Tensorflow Keras Guide,它说:

The Estimators API is used for training models for distributed environments.

Tensorflow Estimators Guide 中可以看出:

You can run Estimator-based models on a local host or on a distributed multi-server environment without changing your model. Furthermore, you can run Estimator-based models on CPUs, GPUs, or TPUs without recoding your model.

预制估算器

虽然 Keras 提供了使构建模型更容易的抽象,但您仍然需要编写代码来构建模型。借助 Estimators,Tensorflow 提供了预制 Estimators,您可以直接使用这些模型,只需插入超参数即可。

预制估算器类似于您使用 scikit-learn 中的 scikit-learn. For example, the tf.estimator.LinearRegressor from Tensorflow is similar to the sklearn.linear_model.LinearRegression 的方式。

与其他 Tensorflow 工具集成

Tensorflow 提供了一个名为 TensorBoard 的可视化工具,可帮助您可视化图表和统计数据。通过使用 Estimator,您可以轻松保存要使用 Tensorboard 可视化的摘要。

将 Keras 模型转换为 Estimator

要将 Keras 模型迁移到 Estimator,请使用 tf.keras.estimator.model_to_estimator 方法。

在我的理解中,estimator是为了大规模训练数据并服务于生产目的,因为cloud ML engine只能接受estimator。

下面来自 tensorflow doc 之一的描述提到了这一点:

” Estimators API 用于训练分布式环境的模型。这针对行业用例,例如可以导出生产模型的大型数据集的分布式训练。 “