使用 TensorFlow Estimator 进行迁移学习/再训练

Transfer learning/ retraining with TensorFlow Estimators

我一直无法弄清楚如何在新的 TF Estimator API 中使用 transfer learning/last 层再训练 Estimator API

Estimator 需要 model_fn,其中包含 documentation. An example of a model_fn using a CNN architecture is here 中定义的网络架构、训练和评估操作。

如果我想重新训练最后一层,例如初始架构,我不确定我是否需要在此model_fn中指定整个模型,然后加载预训练的权重,或者是否有一种方法可以像 'traditional' 方法(示例 here)那样使用保存的图形。

这已作为 issue 提出,但仍然开放,我不清楚答案。

可以在模型定义期间加载元图并使用 SessionRunHook 从 ckpt 文件加载权重。

def model(features, labels, mode, params):
    # Create the graph here

    return tf.estimator.EstimatorSpec(mode, 
            predictions,
            loss,
            train_op,
            training_hooks=[RestoreHook()])

SessionRunHook 可以是:

class RestoreHook(tf.train.SessionRunHook):

    def after_create_session(self, session, coord=None):
        if session.run(tf.train.get_or_create_global_step()) == 0:
            # load weights here

这样,权重在第一步加载并在模型检查点训练期间保存。