在 MXNet 中使用高级 API 时如何将额外数据输入网络
How to input additional data into the network when high-level APIs are being used in MXNet
我正在研究 MXNet 框架,我需要在每次迭代期间向网络输入一个矩阵。矩阵存储在外部存储器中,它不是训练数据,并且在每次迭代结束时由网络的输出更新。在迭代过程中,矩阵必须输入到网络中。
如果我使用高级 API,即
model = mx.mod.Module(context=ctx, symbol=sym)
... ...
model.fit(train_data_iter, begin_epoch=begin_epoch,
end_epoch=end_epoch, ......)
可以实现吗?
model.fit()
没有提供您正在寻找的功能。但是,您想要实现的目标在 Apache MXNet 的 Gluon API 中非常容易实现。使用 Gluon API,您可以为训练循环编写 7 行代码,而不是使用单个 model.fit()
。这是一个典型的训练循环代码:
for epoch in range(10):
for data, label in train_data:
# forward + backward
with autograd.record():
output = net(data)
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(batch_size) # update parameters
因此,如果您想将网络的输出反馈回输入,您可以轻松实现。要开始使用 Gluon,我推荐 60-minute Gluon Crash Course. To become an expert in Gluon, I recommend Deep Learning - The Straight Dope book as well as a comprehensive set of tutorials on the main MXNet website: http://mxnet.apache.org/tutorials/index.html
我正在研究 MXNet 框架,我需要在每次迭代期间向网络输入一个矩阵。矩阵存储在外部存储器中,它不是训练数据,并且在每次迭代结束时由网络的输出更新。在迭代过程中,矩阵必须输入到网络中。
如果我使用高级 API,即
model = mx.mod.Module(context=ctx, symbol=sym)
... ...
model.fit(train_data_iter, begin_epoch=begin_epoch,
end_epoch=end_epoch, ......)
可以实现吗?
model.fit()
没有提供您正在寻找的功能。但是,您想要实现的目标在 Apache MXNet 的 Gluon API 中非常容易实现。使用 Gluon API,您可以为训练循环编写 7 行代码,而不是使用单个 model.fit()
。这是一个典型的训练循环代码:
for epoch in range(10):
for data, label in train_data:
# forward + backward
with autograd.record():
output = net(data)
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(batch_size) # update parameters
因此,如果您想将网络的输出反馈回输入,您可以轻松实现。要开始使用 Gluon,我推荐 60-minute Gluon Crash Course. To become an expert in Gluon, I recommend Deep Learning - The Straight Dope book as well as a comprehensive set of tutorials on the main MXNet website: http://mxnet.apache.org/tutorials/index.html