tensorflow - 使用估算器实现经验回放记忆 api

tensorflow - implementing experience replay memory with the estimator api

我尝试实现一个 experience replay memory with the tf.estimator.Estimator API。但是,我不确定获得至少适用于所有模式(TRAINEVALUATEPREDICT)的结果的最佳方法是什么。我尝试了以下方法:

目前正在尝试:

我在几个方面都失败了,开始相信 tf.estimator.Estimator API 没有为我提供必要的接口来轻松地写下它。

一些代码(第一种方法,因 batch_size 而失败,因为它已固定用于 exp 的切片,我无法使用该模型进行预测或评估):

 def model_fn(self, features, labels, mode, params):
    batch_size = features["matrix"].get_shape()[0].value

    # get prev_exp
    if mode == tf.estimator.ModeKeys.TRAIN:
        erm = tf.get_variable("erm", shape=[30000, 10], initializer=tf.constant_initializer(self.erm.initial_train_erm()), trainable=False)
        prev_exp = tf.slice(erm, [features["index"][0], 0], [batch_size, 10])

    # model
    pred = model(features["matrix"], prev_exp, params) 

但是:最好将 erm 放在功能字典中。但随后我必须在图表之外管理 erm,并写回我使用 SessionRunHook 的最新体验。有没有更好的方法或者我错过了什么?

我通过在图外实施 ERM、使用 tf.data.Dataset.from_generator() 将其反馈回输入管道并使用 SessionRunHooks 回写来解决我的问题。是的,很乏味,但它确实有效。