我可以从估算器中获取 tensorflow 会话吗?
Can I get the tensorflow session from the estimator?
我正在使用 tf.estimator 的 LinearRegressor 并想将我的学习率衰减(最初是指数衰减)更改为使用损失的衰减。但是为此,我需要将评估损失传递给学习率衰减张量的一些占位符,并且在这一步中,我需要 tf.session.
我尝试tf.get_default_session()
获取估算器创建的会话,但该会话使用的图表与估算器不同。
def my_decay(learning_rate, global_step, decay_step, loss, decay_rate):
# If loss is not reduced, than decay with decay_rate.
loss = tf.placeholder(tf.float32)
estimator = tf.estimator.LinearRegressor(
feature_columns=feature_columns,
optimizer==lambda: tf.train.FtrlOptimizer(
learning_rate=my_decay(learning_rate=0.1,
global_step=tf.get_global_step(), decay_step=10000,
loss=loss, decay_rate=0.96)),
config=sess_config
)
for _ in range(n_epoches):
metrics = tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
session.run(loss.assign(metrics['loss']))
使用上面的代码,我需要从估算器中获取 session
。
有什么办法可以得到这个吗?
提前致谢!
像这样的事情的预期解决方案是 subclass tf.train.SessionRunHook
and override the before_run
method to return a suitable tf.train.SessionRunArgs
。这将允许您在训练时提供值并将提取添加到 session.run
调用。您的 class 必须携带对占位符和 loss
状态 in-between 调用的引用。
然后您只需实例化 class 并将挂钩添加到 estimator.train
调用中的 hooks
参数,在本例中是您的 train_spec
。如果你想使用评估损失而不是训练损失,那么这可以通过向 eval_spec
添加另一个钩子来实现,该钩子读取 after_run
方法中的值。
我正在使用 tf.estimator 的 LinearRegressor 并想将我的学习率衰减(最初是指数衰减)更改为使用损失的衰减。但是为此,我需要将评估损失传递给学习率衰减张量的一些占位符,并且在这一步中,我需要 tf.session.
我尝试tf.get_default_session()
获取估算器创建的会话,但该会话使用的图表与估算器不同。
def my_decay(learning_rate, global_step, decay_step, loss, decay_rate):
# If loss is not reduced, than decay with decay_rate.
loss = tf.placeholder(tf.float32)
estimator = tf.estimator.LinearRegressor(
feature_columns=feature_columns,
optimizer==lambda: tf.train.FtrlOptimizer(
learning_rate=my_decay(learning_rate=0.1,
global_step=tf.get_global_step(), decay_step=10000,
loss=loss, decay_rate=0.96)),
config=sess_config
)
for _ in range(n_epoches):
metrics = tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
session.run(loss.assign(metrics['loss']))
使用上面的代码,我需要从估算器中获取 session
。
有什么办法可以得到这个吗?
提前致谢!
像这样的事情的预期解决方案是 subclass tf.train.SessionRunHook
and override the before_run
method to return a suitable tf.train.SessionRunArgs
。这将允许您在训练时提供值并将提取添加到 session.run
调用。您的 class 必须携带对占位符和 loss
状态 in-between 调用的引用。
然后您只需实例化 class 并将挂钩添加到 estimator.train
调用中的 hooks
参数,在本例中是您的 train_spec
。如果你想使用评估损失而不是训练损失,那么这可以通过向 eval_spec
添加另一个钩子来实现,该钩子读取 after_run
方法中的值。