如何监控 Chainer 框架中验证集的错误?
How to monitor error on a validation set in Chainer framework?
我是 Chainer 的新手,我编写了一段代码来训练一个简单的前馈神经网络。我有一个验证集和一个训练集,并且想在每次 500 次迭代时对验证集进行测试,如果结果更好,我想保存我的网络权重。谁能告诉我该怎么做?
这是我的代码:
optimizer = optimizers.Adam()
optimizer.setup(model)
updater = training.StandardUpdater(train_iter, optimizer, device=0)
trainer = training.Trainer(updater, (10000, 'epoch'), out='result')
trainer.extend(extensions.Evaluator(validation_iter, model, device=0))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'elapsed_time']))
trainer.run()
- 验证集错误
由Evaluator
报告,由PrintReport
打印。因此,它应该与上面的代码一起显示。为了控制这些扩展的执行频率,您可以在 trainer.extend
函数中指定 trigger
关键字参数。
例如,下面的代码指定每 500 次迭代打印一次。
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'elapsed_time']), trigger=(500, 'iteration'))
您还可以将触发器指定为 Evaluator
。
- 保存网络权重
您可以使用 snapshot_object 扩展名。
默认每个epoch调用一次。
如果你想在损失改善时调用它,我想你可以设置trigger
使用MinValueTrigger
。
http://docs.chainer.org/en/stable/reference/generated/chainer.training.triggers.MinValueTrigger.html
我是 Chainer 的新手,我编写了一段代码来训练一个简单的前馈神经网络。我有一个验证集和一个训练集,并且想在每次 500 次迭代时对验证集进行测试,如果结果更好,我想保存我的网络权重。谁能告诉我该怎么做?
这是我的代码:
optimizer = optimizers.Adam()
optimizer.setup(model)
updater = training.StandardUpdater(train_iter, optimizer, device=0)
trainer = training.Trainer(updater, (10000, 'epoch'), out='result')
trainer.extend(extensions.Evaluator(validation_iter, model, device=0))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'elapsed_time']))
trainer.run()
- 验证集错误
由Evaluator
报告,由PrintReport
打印。因此,它应该与上面的代码一起显示。为了控制这些扩展的执行频率,您可以在 trainer.extend
函数中指定 trigger
关键字参数。
例如,下面的代码指定每 500 次迭代打印一次。
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'elapsed_time']), trigger=(500, 'iteration'))
您还可以将触发器指定为 Evaluator
。
- 保存网络权重
您可以使用 snapshot_object 扩展名。
默认每个epoch调用一次。
如果你想在损失改善时调用它,我想你可以设置trigger
使用MinValueTrigger
。
http://docs.chainer.org/en/stable/reference/generated/chainer.training.triggers.MinValueTrigger.html