chainer StandardUpdater 迭代器参数
chainer StandardUpdater iterator Parameters
在 chainer 文档中,它显示在 https://docs.chainer.org/en/stable/reference/core/generated/chainer.training.StandardUpdater.html#chainer.training.StandardUpdater
参数:
迭代器——训练数据集的数据集迭代器。它也可以是将字符串映射到迭代器的字典。如果这只是一个迭代器,则迭代器由名称 'main'.
注册
但实际上在链接器的代码中,我发现了
def update_core(self):
batch = self._iterators['main'].next()
这意味着它只使用名称为 'main' 的迭代器字典?
是的,默认情况下 StandardUpdater
仅使用 'main' 迭代器。
我认为只有当您定义自己的 Updater
class 时,多重迭代器的功能才有用,它是 StandardUpdater
.[=16 的子 class =]
例如:
import chainer
from chainer import training
class MyUpdater(training.updaters.StandardUpdater):
# override `update_core`
def update_core(self):
# You can use several iterators here!
batch0 = self.get_iterator('0').next()
batch1 = self.get_iterator('1').next()
# do what ever you want with `batch0` and `batch1` etc.
...
...
train_iter0 = chainer.iterators.SerialIterator(train0, args.batchsize)
train_iter1 = chainer.iterators.SerialIterator(train1, args.batchsize)
# You can pass several iterators in dict format to your own Updater class.
updater = MyUpdater(
{'0': train_iter0, '1': train_iter1}, optimizer, device=args.gpu)
请注意,我还没有测试过上面的代码是否有效。
对于其他参考,DCGAN 示例代码显示了另一个覆盖 update_core
以定义您自己的更新方案的示例(但它也没有使用多个迭代器)。
https://github.com/chainer/chainer/blob/master/examples/dcgan/updater.py#L30
在 chainer 文档中,它显示在 https://docs.chainer.org/en/stable/reference/core/generated/chainer.training.StandardUpdater.html#chainer.training.StandardUpdater
参数: 迭代器——训练数据集的数据集迭代器。它也可以是将字符串映射到迭代器的字典。如果这只是一个迭代器,则迭代器由名称 'main'.
注册但实际上在链接器的代码中,我发现了
def update_core(self):
batch = self._iterators['main'].next()
这意味着它只使用名称为 'main' 的迭代器字典?
是的,默认情况下 StandardUpdater
仅使用 'main' 迭代器。
我认为只有当您定义自己的 Updater
class 时,多重迭代器的功能才有用,它是 StandardUpdater
.[=16 的子 class =]
例如:
import chainer
from chainer import training
class MyUpdater(training.updaters.StandardUpdater):
# override `update_core`
def update_core(self):
# You can use several iterators here!
batch0 = self.get_iterator('0').next()
batch1 = self.get_iterator('1').next()
# do what ever you want with `batch0` and `batch1` etc.
...
...
train_iter0 = chainer.iterators.SerialIterator(train0, args.batchsize)
train_iter1 = chainer.iterators.SerialIterator(train1, args.batchsize)
# You can pass several iterators in dict format to your own Updater class.
updater = MyUpdater(
{'0': train_iter0, '1': train_iter1}, optimizer, device=args.gpu)
请注意,我还没有测试过上面的代码是否有效。
对于其他参考,DCGAN 示例代码显示了另一个覆盖 update_core
以定义您自己的更新方案的示例(但它也没有使用多个迭代器)。
https://github.com/chainer/chainer/blob/master/examples/dcgan/updater.py#L30