tensorflow - 使用估算器实现经验回放记忆 api
tensorflow - implementing experience replay memory with the estimator api
我尝试实现一个 experience replay memory with the tf.estimator.Estimator API。但是,我不确定获得至少适用于所有模式(TRAIN
、EVALUATE
、PREDICT
)的结果的最佳方法是什么。我尝试了以下方法:
- 使用
tf.Variable
实现内存,这会导致批处理和输入管道出现问题(我无法在测试或预测阶段输入自定义体验)
目前正在尝试:
- 实现
tf.Graph
之外的内存。使用 tf.train.SessionRunHook
在每个 运行 之后设置值。在训练和测试期间使用 tf.data.Dataset.from_generator()
加载经验。自己管理状态。
我在几个方面都失败了,开始相信 tf.estimator.Estimator API 没有为我提供必要的接口来轻松地写下它。
一些代码(第一种方法,因 batch_size 而失败,因为它已固定用于 exp 的切片,我无法使用该模型进行预测或评估):
def model_fn(self, features, labels, mode, params):
batch_size = features["matrix"].get_shape()[0].value
# get prev_exp
if mode == tf.estimator.ModeKeys.TRAIN:
erm = tf.get_variable("erm", shape=[30000, 10], initializer=tf.constant_initializer(self.erm.initial_train_erm()), trainable=False)
prev_exp = tf.slice(erm, [features["index"][0], 0], [batch_size, 10])
# model
pred = model(features["matrix"], prev_exp, params)
但是:最好将 erm 放在功能字典中。但随后我必须在图表之外管理 erm,并写回我使用 SessionRunHook 的最新体验。有没有更好的方法或者我错过了什么?
我通过在图外实施 ERM、使用 tf.data.Dataset.from_generator() 将其反馈回输入管道并使用 SessionRunHooks 回写来解决我的问题。是的,很乏味,但它确实有效。
我尝试实现一个 experience replay memory with the tf.estimator.Estimator API。但是,我不确定获得至少适用于所有模式(TRAIN
、EVALUATE
、PREDICT
)的结果的最佳方法是什么。我尝试了以下方法:
- 使用
tf.Variable
实现内存,这会导致批处理和输入管道出现问题(我无法在测试或预测阶段输入自定义体验)
目前正在尝试:
- 实现
tf.Graph
之外的内存。使用tf.train.SessionRunHook
在每个 运行 之后设置值。在训练和测试期间使用tf.data.Dataset.from_generator()
加载经验。自己管理状态。
我在几个方面都失败了,开始相信 tf.estimator.Estimator API 没有为我提供必要的接口来轻松地写下它。
一些代码(第一种方法,因 batch_size 而失败,因为它已固定用于 exp 的切片,我无法使用该模型进行预测或评估):
def model_fn(self, features, labels, mode, params):
batch_size = features["matrix"].get_shape()[0].value
# get prev_exp
if mode == tf.estimator.ModeKeys.TRAIN:
erm = tf.get_variable("erm", shape=[30000, 10], initializer=tf.constant_initializer(self.erm.initial_train_erm()), trainable=False)
prev_exp = tf.slice(erm, [features["index"][0], 0], [batch_size, 10])
# model
pred = model(features["matrix"], prev_exp, params)
但是:最好将 erm 放在功能字典中。但随后我必须在图表之外管理 erm,并写回我使用 SessionRunHook 的最新体验。有没有更好的方法或者我错过了什么?
我通过在图外实施 ERM、使用 tf.data.Dataset.from_generator() 将其反馈回输入管道并使用 SessionRunHooks 回写来解决我的问题。是的,很乏味,但它确实有效。