Tensorflow Estimator:在单独的脚本中使用 predict() 函数

Tensorflow Estimator: using predict() function in separate script

我已经成功地(我希望)使用 tf.Estimator 训练和评估了一个模型,我达到了大约 83-85% 的 train/eval 准确度。所以现在,我想使用 Estimator class 中的 predict() 函数调用在单独的数据集上测试我的模型。我最好在单独的脚本中执行此操作。

我在 which says that I need to export as a SavedModel, but is this really necessary? Looking at the documentation 的 Estimator class,似乎我可以通过 model_dir 参数将路径传递到我的检查点和图形文件。有没有人有这方面的经验?当我 运行 我的模型在我用于验证的同一数据集上时,我没有获得与验证阶段相同的性能...... :-(

我认为您只需要一个包含您的 model_fn 定义的单独文件。比起在另一个脚本中使用相同的 model_fn 定义和相同的 model_dir.

实例化相同的估计器 class

这是有效的,因为 Estimator API 会自行恢复 tf.Graph 定义和最新的 model.ckpt 文件,因此您可以继续训练、评估和预测。