Tensorflow 估计器 - warm_start_from 和 model_dir
Tensorflow Estimator - warm_start_from and model_dir
当 tf.estimator
与 warm_start_from
和 model_dir
以及 warm_start_from
目录和 model_dir
目录一起使用时包含有效检查点,实际恢复哪个检查点?
为了提供一些背景信息,我的估算器代码如下所示
est = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=model_dir,
warm_start_from=warm_start_dir)
for epoch in range(num_epochs):
est.train(input_fn=train_input_fn)
est.evaluate(input_fn=eval_input_fn)
(输入函数使用一次性迭代器。)
所以在第一次迭代期间,当 model_dir
为空时,我希望加载热启动检查点,但在下一个时期,我希望从要加载的 model_dir
中的最后一次迭代。但至少从日志来看,warm_start_dir
似乎仍在加载中。
我可能会在下一次迭代中覆盖我的估算器,但我想知道它是否不应该以某种方式构建在估算器中。
我遇到过类似的问题,我通过在会话启动时提供一个 运行 的初始化挂钩并使用 tf.estimator.train_and_evaluate
解决了这个问题(尽管我不能接受整个解决方案的功劳,因为我在其他地方看到了类似的用于另一个目的的东西):
class InitHook(tf.train.SessionRunHook):
"""initializes model from a checkpoint_path
args:
modelPath: full path to checkpoint
"""
def __init__(self, checkpoint_dir):
self.modelPath = checkpoint_dir
self.initialized = False
def begin(self):
"""
Restore encoder parameters if a pre-trained encoder model is available and we haven't trained previously
"""
if not self.initialized:
log = logging.getLogger('tensorflow')
checkpoint = tf.train.latest_checkpoint(self.modelPath)
if checkpoint is None:
log.info('No pre-trained model is available, training from scratch.')
else:
log.info('Pre-trained model {0} found in {1} - warmstarting.'.format(checkpoint, self.modelPath))
tf.train.warm_start(checkpoint)
self.initialized = True
然后,训练:
initHook = InitHook(checkpoint_dir = warm_start_dir)
trainSpec = tf.estimator.TrainSpec(
input_fn = train_input_fn,
max_steps = N_STEPS,
hooks = [initHook]
)
evalSpec = tf.estimator.EvalSpec(
input_fn = eval_input_fn,
steps = None,
name = 'eval',
throttle_secs = 3600
)
tf.estimator.train_and_evaluate(estimator, trainSpec, evalSpec)
这运行s 在开始时初始化一次来自warm_start_dir
的变量。稍后,当估计器 model_dir
中有新的检查点时,它会从那里继续 warm_starting。
当 tf.estimator
与 warm_start_from
和 model_dir
以及 warm_start_from
目录和 model_dir
目录一起使用时包含有效检查点,实际恢复哪个检查点?
为了提供一些背景信息,我的估算器代码如下所示
est = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=model_dir,
warm_start_from=warm_start_dir)
for epoch in range(num_epochs):
est.train(input_fn=train_input_fn)
est.evaluate(input_fn=eval_input_fn)
(输入函数使用一次性迭代器。)
所以在第一次迭代期间,当 model_dir
为空时,我希望加载热启动检查点,但在下一个时期,我希望从要加载的 model_dir
中的最后一次迭代。但至少从日志来看,warm_start_dir
似乎仍在加载中。
我可能会在下一次迭代中覆盖我的估算器,但我想知道它是否不应该以某种方式构建在估算器中。
我遇到过类似的问题,我通过在会话启动时提供一个 运行 的初始化挂钩并使用 tf.estimator.train_and_evaluate
解决了这个问题(尽管我不能接受整个解决方案的功劳,因为我在其他地方看到了类似的用于另一个目的的东西):
class InitHook(tf.train.SessionRunHook):
"""initializes model from a checkpoint_path
args:
modelPath: full path to checkpoint
"""
def __init__(self, checkpoint_dir):
self.modelPath = checkpoint_dir
self.initialized = False
def begin(self):
"""
Restore encoder parameters if a pre-trained encoder model is available and we haven't trained previously
"""
if not self.initialized:
log = logging.getLogger('tensorflow')
checkpoint = tf.train.latest_checkpoint(self.modelPath)
if checkpoint is None:
log.info('No pre-trained model is available, training from scratch.')
else:
log.info('Pre-trained model {0} found in {1} - warmstarting.'.format(checkpoint, self.modelPath))
tf.train.warm_start(checkpoint)
self.initialized = True
然后,训练:
initHook = InitHook(checkpoint_dir = warm_start_dir)
trainSpec = tf.estimator.TrainSpec(
input_fn = train_input_fn,
max_steps = N_STEPS,
hooks = [initHook]
)
evalSpec = tf.estimator.EvalSpec(
input_fn = eval_input_fn,
steps = None,
name = 'eval',
throttle_secs = 3600
)
tf.estimator.train_and_evaluate(estimator, trainSpec, evalSpec)
这运行s 在开始时初始化一次来自warm_start_dir
的变量。稍后,当估计器 model_dir
中有新的检查点时,它会从那里继续 warm_starting。