使用 TensorFlow Estimator 仅优化模型的某些变量 API

Optimizing only certain variables of the model working with TensorFlow Estimator API

我需要冻结部分模型并仅训练某些变量。现在,使用低级 API,我可以将 var_list 传递给 tf.train.Optimizer.minimize 方法。但是,当我使用 TensorFlow Estimator 时,我只能传递优化器本身,然后用于最小化 Estimator 内部循环内的损失。

我想到的唯一解决方案是定义自定义优化器并覆盖 Optimizer.minimize 方法。像这样:

def minimize(self, *args, **kwargs):
    print("Inside...")
    if not kwargs['var_list']:
       kwargs['var_list'] = self.var_list

    return super(MyOptimizer, self).minimize(*args, **kwargs)

现在,我希望在每个训练步骤中看到 "Inside..." 短语打印在屏幕上;特别是当我看到模型训练得很好时。这有点告诉我,我的 minimize 函数被完全忽略了,我似乎无法弄清楚为什么。

那么,重写 minimize 是否正确,或者是否有更好的方法使用 Estimator 来执行此操作?

您可以通过指定 model_fn 函数

来简单地制作自定义估算器
    def model_fn(features, labels, mode):
      logits = model_architecture(features)
      loss = loss_function(logits, labels)
      if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = optimizer
        train_op = ontimizer.minimize(loss=loss, 
                                      global_step=global_step,
                                      var_list=variables_to_minimize)

      return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)