使用 MXnet 时如何保存模型

How to save a model when using MXnet

我正在使用 MXnet 训练 CNN(在 R 中),我可以使用以下代码毫无错误地训练模型:

model <- mx.model.FeedForward.create(symbol=network,
                                     X=train.iter,
                                     ctx=mx.gpu(0),
                                     num.round=20,
                                     array.batch.size=batch.size,
                                     learning.rate=0.1,
                                     momentum=0.1,  
                                     eval.metric=mx.metric.accuracy,
                                     wd=0.001,
                                     batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
    )

但由于这个过程比较耗时,所以我运行晚上在服务器上,想保存模型,以便训练完成后使用。

我用过:

save(list = ls(), file="mymodel.RData")

mx.model.save("mymodel", 10)

但是其中none个可以保存模型!例如,当我加载 "mymodel.RData" 时,我无法预测测试集的标签!

另一个例子是当我加载 "mymodel.RData" 并尝试使用以下代码绘制它时:

graph.viz(model$symbol$as.json())

我收到以下错误:

Error in model$symbol$as.json() : external pointer is not valid

任何人都可以给我一个保存然后加载此模型以供将来使用的解决方案吗?

谢谢

保存训练进度快照的最佳做法是在每个时期训练后使用 save_snapshot(http://mxnet.io/api/python/module.html#mxnet.module.Module.save_checkpoint)作为 回调 的一部分.在 R 中,等效命令可能是 mx.callback.save.checkpoint,但我没有使用 R,也不确定它的用法。

使用这些快照还可以让您利用使用 AWS Spot 市场 (https://aws.amazon.com/ec2/spot/pricing/ ) 的低成本选项,例如,现在以 3.8 美元/小时的价格提供 16 个 K80 GPU 实例,相比之下按需价格为 14.4 美元。这种 80%-90% 的折扣在现货市场很常见,只要您正确使用这些快照,就可以优化您的训练速度和成本。

您可以通过

保存模型
model <- mx.model.FeedForward.create(symbol=network,
                                 X=train.iter,
                                 ctx=mx.gpu(0),
                                 num.round=20,
                                 array.batch.size=batch.size,
                                 learning.rate=0.1,
                                 momentum=0.1,  
                                 eval.metric=mx.metric.accuracy,
                                 wd=0.001,
                                 epoch.end.callback=mx.callback.save.checkpoint("model_prefix")
                                 batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
)

mxnet 模型是一个 R 列表,但它的第一个组件不是 R 对象而是 C++ 指针,不能作为 R 对象保存和重新加载。因此,模型需要 序列化 才能表现得像一个实际的 R 对象。序列化后的对象也是一个列表,但它的第一个对象是包含模型信息的文本字符串。

要保存模型:

modelR <- mx.serialize(model)
save(modelR, file="~/model1.RData")

取回并再次使用:

load("~/model1.RData", verbose=TRUE)
model <- mx.unserialize(modelR)