使用 TensorFlow Estimator 进行迁移学习/再训练
Transfer learning/ retraining with TensorFlow Estimators
我一直无法弄清楚如何在新的 TF Estimator API 中使用 transfer learning/last 层再训练 Estimator API。
Estimator
需要 model_fn
,其中包含 documentation. An example of a model_fn
using a CNN architecture is here 中定义的网络架构、训练和评估操作。
如果我想重新训练最后一层,例如初始架构,我不确定我是否需要在此model_fn
中指定整个模型,然后加载预训练的权重,或者是否有一种方法可以像 'traditional' 方法(示例 here)那样使用保存的图形。
这已作为 issue 提出,但仍然开放,我不清楚答案。
可以在模型定义期间加载元图并使用 SessionRunHook 从 ckpt 文件加载权重。
def model(features, labels, mode, params):
# Create the graph here
return tf.estimator.EstimatorSpec(mode,
predictions,
loss,
train_op,
training_hooks=[RestoreHook()])
SessionRunHook 可以是:
class RestoreHook(tf.train.SessionRunHook):
def after_create_session(self, session, coord=None):
if session.run(tf.train.get_or_create_global_step()) == 0:
# load weights here
这样,权重在第一步加载并在模型检查点训练期间保存。
我一直无法弄清楚如何在新的 TF Estimator API 中使用 transfer learning/last 层再训练 Estimator API。
Estimator
需要 model_fn
,其中包含 documentation. An example of a model_fn
using a CNN architecture is here 中定义的网络架构、训练和评估操作。
如果我想重新训练最后一层,例如初始架构,我不确定我是否需要在此model_fn
中指定整个模型,然后加载预训练的权重,或者是否有一种方法可以像 'traditional' 方法(示例 here)那样使用保存的图形。
这已作为 issue 提出,但仍然开放,我不清楚答案。
可以在模型定义期间加载元图并使用 SessionRunHook 从 ckpt 文件加载权重。
def model(features, labels, mode, params):
# Create the graph here
return tf.estimator.EstimatorSpec(mode,
predictions,
loss,
train_op,
training_hooks=[RestoreHook()])
SessionRunHook 可以是:
class RestoreHook(tf.train.SessionRunHook):
def after_create_session(self, session, coord=None):
if session.run(tf.train.get_or_create_global_step()) == 0:
# load weights here
这样,权重在第一步加载并在模型检查点训练期间保存。