chainer StandardUpdater 迭代器参数

chainer StandardUpdater iterator Parameters

在 chainer 文档中,它显示在 https://docs.chainer.org/en/stable/reference/core/generated/chainer.training.StandardUpdater.html#chainer.training.StandardUpdater

参数: 迭代器——训练数据集的数据集迭代器。它也可以是将字符串映射到迭代器的字典。如果这只是一个迭代器,则迭代器由名称 'main'.

注册

但实际上在链接器的代码中,我发现了

def update_core(self):
    batch = self._iterators['main'].next()

这意味着它只使用名称为 'main' 的迭代器字典?

是的,默认情况下 StandardUpdater 仅使用 'main' 迭代器。

我认为只有当您定义自己的 Updater class 时,多重迭代器的功能才有用,它是 StandardUpdater.[=16 的子 class =]

例如:

import chainer
from chainer import training

class MyUpdater(training.updaters.StandardUpdater):

    # override `update_core`
    def update_core(self):
        # You can use several iterators here!
        batch0 = self.get_iterator('0').next()
        batch1 = self.get_iterator('1').next()

        # do what ever you want with `batch0` and `batch1` etc.
        ...


...

train_iter0 = chainer.iterators.SerialIterator(train0, args.batchsize)
train_iter1 = chainer.iterators.SerialIterator(train1, args.batchsize)

# You can pass several iterators in dict format to your own Updater class.
updater = MyUpdater(
    {'0': train_iter0, '1': train_iter1}, optimizer, device=args.gpu)

请注意,我还没有测试过上面的代码是否有效。

对于其他参考,DCGAN 示例代码显示了另一个覆盖 update_core 以定义您自己的更新方案的示例(但它也没有使用多个迭代器)。 https://github.com/chainer/chainer/blob/master/examples/dcgan/updater.py#L30