Tensorflow 估计器:无需每次从检查点加载即可进行预测

Tensorflow estimator: predict without loading from checkpoint everytime

我在 Tensorflow(1.8) 和 python3.6 中使用估算器来为我的强化学习项目构建神经网络。我注意到每次使用 estimator.predict() 时,tensorflow 都会加载 model_dir 下的检查点。但是,如果您必须对同一个检查点多次使用此功能,则效率非常低,例如在强化学习中,我可能需要根据当前状态预测下一个动作,而下一个状态只有在您选择特定动作后才会实现。所以调用这个函数几千次是家常便饭

所以我的问题是,如何在每次不加载检查点(相同检查点)的情况下调用此函数。

谢谢。

好吧,我想我刚刚找到了一个很好的答案来解决我自己的问题。一个很好的解决这个问题的方法是通过generator构造一个tf.dataset。 Link is here.

生成器将使您的 estimator.predict 保持打开状态,这样您就无需继续加载检查点。您唯一需要做的就是在必要时更改此 fastpredict 对象(在本例中为 self.next_feature)中的 yield 对象。

但是,如果您的最终目标是将整个事情变成服务或类似的东西,我需要提一下。您可能需要 tf.serving 之类的东西。所以我建议你直接走那条路。我在这个过程中浪费了很多时间。所以我希望这个答案能帮助你拯救你的。