pycaffe 中的求解器回调函数是什么,我该如何使用它们?

What are solvers callback functions in pycaffe, and how can I use them?

查看 this PR,我发现可以为 caffe.Solver 对象定义 on_starton_gradient 回调。

import caffe
solver = caffe.AdamSolver('solver.prototxt')
solver.add_callback(on_start, on_gradient)  # <- ??

on_starton_gradient 是什么类型的对象?
这些回调有什么用?
如何使用它们(一个例子会很好...)?

1.在哪里以及如何定义回调?

回调是求解器的一部分,因此在 solver.hpp file. To be exact, there is a Callback class 中定义,如下所示:

  // Invoked at specific points during an iteration
  class Callback {
   protected:
    virtual void on_start() = 0;
    virtual void on_gradients_ready() = 0;

    template <typename T>
    friend class Solver;
  };
  const vector<Callback*>& callbacks() const { return callbacks_; }
  void add_callback(Callback* value) {
    callbacks_.push_back(value);
  }

和一个 protected vector 这样的回调,它是 Solver class.

的成员
  vector<Callback*> callbacks_;

因此,这基本上为 Solver class 提供了一个 add_callback 函数,它允许将 Callback 类型的对象添加到向量中。这是为了确保每个回调都有两个方法:on_start()on_gradients_ready().

2。在哪里调用回调?

这发生在 solver.cpp file, in the step() 函数中,它包含主工作循环。这是主循环部分(为简单起见,删除了很多东西):

while (iter_ < stop_iter) {

    for (int i = 0; i < callbacks_.size(); ++i) {
        callbacks_[i]->on_start();
    }

    // accumulate the loss and gradient
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {
        loss += net_->ForwardBackward();
    }
    loss /= param_.iter_size();

    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    }

    ApplyUpdate();

    ++iter_;
}

3。这个用在什么地方?

此回调功能是在添加 multi-GPU 支持时实现的。唯一使用回调的地方(据我所知)是在多个 GPU 之间同步求解器:

P2PSync class in parallel.hpp inherits from the Solver::Callback class, and implements an on_start() and on_gradients_ready()方法,同步GPU并最终累积所有梯度更新。

4.如何使用来自 Python?

的回调

正如拉取请求 #3020 所解释的那样,

on_start and on_gradient are python functions.

所以应该straight-forward使用。我创建的 this Github Gist 中显示了一个完整的、可运行的示例。

5.这有什么用?

由于这两个回调函数不接受 任何 参数,您不能简单地使用它们来跟踪丢失或类似的事情。为此,您必须围绕 Solver class 创建一个包装函数,并使用两个方法调用 add_callback 作为回调函数。这允许您使用 self.solver.net 从回调中访问网络。在下面的示例中,我使用 on_start 回调将数据加载到网络中,并使用 on_gradients_ready 回调来打印损失函数。

class SolverWithCallback:
    def __init__(self, solver_file):
        self.solver = caffe.SGDSolver(solver_file)
        self.solver.add_callback(self.load, self.loss)

    def solve(self):
        self.solver.solve()

    def load(self):
        inp = np.random.randint(0, 255)
        self.solver.net.blobs['data'].data[...] = inp
        self.solver.net.blobs['labels'].data[...] = 2 * inp

    def loss(self):
        print "loss: " + str(self.solver.net.blobs['loss'].data)

if __name__=='__main__':
    solver = SolverWithCallback('solver.prototxt')
    solver.solve()