在 TensorFlow 中使用 MonitoredTrainingSession 与 Estimator 的原因是什么

What are the reasons to use MonitoredTrainingSession vs Estimator in TensorFlow

我看到很多使用 MonitoredTrainingSessiontf.Estimator 作为训练框架的例子。但是不清楚为什么我会使用一个而不是另一个。两者都可以使用 SessionRunHooks 配置。两者都与 tf.data.Dataset 迭代器集成,并且可以提供 training/val 数据集。我不确定一种设置的好处是什么。

简短的回答是 MonitoredTrainingSession 允许用户访问图形和会话对象以及训练循环,而 Estimator 向用户隐藏图形和会话的详细信息,并且通常更容易运行 培训,尤其是 train_and_evaluate,如果您需要定期评估。

MonitoredTrainingSession 与普通的 tf.Session() 的不同之处在于它处理变量初始化、设置文件编写器以及合并分布式训练的功能。

另一方面,

Estimator APIKeras 一样是一个高级结构。它可能在示例中使用得较少,因为它是后来介绍的。它还允许使用 DistibutedStrategy 分发 training/evaluation,并且它有几个允许快速原型制作的固定估计器。

在模型定义方面,它们非常相似,都允许使用 keras.layers,或者从头开始定义完全自定义的模型。因此,无论出于何种原因,如果您需要访问图形构造或自定义训练循环,请使用 MonitoredTrainingSession。如果您只想定义模型、训练它、运行 验证和预测而无需额外的复杂性和样板代码,请使用 Estimator