Deeplearning4j - how to use saved model?

I am working with Deeplearning4j (version 1.0.0-M1.1) to build neural networks.

I am using Deeplearning4j's IrisClassifier as an example.

//First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
int numLinesToSkip = 0;
char delimiter = ',';
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
recordReader.initialize(new FileSplit(new File(DownloaderUtility.IRISDATA.Download(),"iris.txt")));

//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
int batchSize = 150;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = iterator.next();
allData.shuffle();
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training

DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();

//We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.transform(trainingData);     //Apply normalization to the training data
normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set

final int numInputs = 4;
int outputNum = 3;
long seed = 6;

log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(seed)
    .activation(Activation.TANH)
    .weightInit(WeightInit.XAVIER)
    .updater(new Sgd(0.1))
    .l2(1e-4)
    .list()
    .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
        .build())
    .layer(new DenseLayer.Builder().nIn(3).nOut(3)
        .build())
    .layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .activation(Activation.SOFTMAX) //Override the global TANH activation with softmax for this layer
        .nIn(3).nOut(outputNum).build())
    .build();

//run the model
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
//record score once every 100 iterations
model.setListeners(new ScoreIterationListener(100));

for(int i=0; i<1000; i++ ) {
    model.fit(trainingData);
}

//evaluate the model on the test set
Evaluation eval = new Evaluation(3);
INDArray output = model.output(testData.getFeatures());

eval.eval(testData.getLabels(), output);
log.info(eval.stats());

The input used to train the model looks like this:

5.1,3.5,1.4,0.2,0
...
7.0,3.2,4.7,1.4,1
...
6.3,3.3,6.0,2.5,2

The last item in each row is the class of that input.
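
To make the row layout concrete, here is a minimal plain-Java sketch of what `labelIndex = 4` and `numClasses = 3` mean for one line of the file: the value at the label index becomes the class, and the remaining values become the features. The class `IrisRow` and its methods are hypothetical helpers for illustration, not part of Deeplearning4j, and the sketch assumes the label index is valid for the row. The one-hot step mirrors how a classification label is usually presented to a softmax output layer.

```java
import java.util.Arrays;

class IrisRow {
    final double[] features;   // the 4 measurements
    final int label;           // class index 0, 1 or 2

    IrisRow(double[] features, int label) {
        this.features = features;
        this.label = label;
    }

    // Splits one CSV line such as "5.1,3.5,1.4,0.2,0" at the given label index.
    // Assumes labelIndex points at an existing column holding an integer.
    static IrisRow parse(String line, int labelIndex) {
        String[] parts = line.trim().split(",");
        double[] features = new double[parts.length - 1];
        int label = -1;
        int f = 0;
        for (int i = 0; i < parts.length; i++) {
            if (i == labelIndex) {
                label = Integer.parseInt(parts[i].trim());
            } else {
                features[f++] = Double.parseDouble(parts[i].trim());
            }
        }
        return new IrisRow(features, label);
    }

    // One-hot encodes the class index: class 1 of 3 becomes [0.0, 1.0, 0.0].
    double[] oneHotLabel(int numClasses) {
        double[] oneHot = new double[numClasses];
        oneHot[label] = 1.0;
        return oneHot;
    }

    public static void main(String[] args) {
        IrisRow row = IrisRow.parse("7.0,3.2,4.7,1.4,1", 4);
        System.out.println(Arrays.toString(row.features));
        System.out.println(Arrays.toString(row.oneHotLabel(3)));
    }
}
```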

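This works well: I can train and test the model.

Now I want to use the trained model to predict the classes of new inputs, but I don't know how to do it.

OK, I can save the model and then reload it: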

// Save the Model
File locationToSave = new File("C:/Projects/deeplearning4j/trained_iris_model.zip");
ModelSerializer.writeModel(model, locationToSave, false);

// Open the model
File locationToLoad = new File("C:/Projects/deeplearning4j/trained_iris_model.zip");
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(locationToLoad);
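
One point about this step: the three-argument `writeModel` call stores the network only. The restored model still expects inputs that were standardized with the mean and standard deviation collected from the training set, so those statistics need to be kept as well. Deeplearning4j appears to provide `ModelSerializer.writeModel(model, file, saveUpdater, normalizer)` and `ModelSerializer.restoreNormalizerFromFile(file)` for this; check both against the Javadoc for your version before relying on them. The plain-Java sketch below only illustrates the idea of fitting once and reusing the stored statistics. The class `Standardizer` is hypothetical, uses the population standard deviation (the exact formula in `NormalizerStandardize` may differ), and assumes no column is constant.

```java
import java.util.Arrays;

class Standardizer {
    private double[] mean;
    private double[] std;

    // Collects the per-column mean and population standard deviation from the training rows.
    void fit(double[][] rows) {
        int cols = rows[0].length;
        mean = new double[cols];
        std = new double[cols];
        for (double[] row : rows) {
            for (int c = 0; c < cols; c++) {
                mean[c] += row[c];
            }
        }
        for (int c = 0; c < cols; c++) {
            mean[c] /= rows.length;
        }
        for (double[] row : rows) {
            for (int c = 0; c < cols; c++) {
                double d = row[c] - mean[c];
                std[c] += d * d;
            }
        }
        for (int c = 0; c < cols; c++) {
            std[c] = Math.sqrt(std[c] / rows.length);
        }
    }

    // Returns a standardized copy of one row, using the statistics stored by fit().
    double[] transform(double[] row) {
        double[] out = new double[row.length];
        for (int c = 0; c < row.length; c++) {
            out[c] = (row[c] - mean[c]) / std[c];
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] train = { {1.0, 10.0}, {3.0, 30.0} };
        Standardizer s = new Standardizer();
        s.fit(train);
        // A new input is scaled with the training statistics, not with its own.
        System.out.println(Arrays.toString(s.transform(new double[] {2.0, 40.0})));
    }
}
```

Fitting a second normalizer on the inputs to be predicted would give different means and standard deviations, and therefore different features than the network saw during training.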

Next, as an example, I load the same data that was used for training, but without the classes.

5.1,3.5,1.4,0.2
...
7.0,3.2,4.7,1.4
...
6.3,3.3,6.0,2.5

int numLinesToSkip = 0;
char delimiter = ',';
CSVRecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);  //skip no lines at the top - i.e. no header
recordReader.initialize(new FileSplit(new File("C:/Projects/deeplearning4j/iris-to-predict.txt")));

But what comes next?

How do I get the predictions?

Thanks!

Adding this code solved my problem:

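int batchSize = 150;
//The file to predict has no label column, so labelIndex and numClasses are not passed
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize);
DataSet allData = iterator.next();

//Caution: this fits a new normalizer on the data being predicted, so its mean/stdev are not
//the ones used in training. They roughly match here because the file holds the training rows,
//but genuinely new inputs should be transformed with the normalizer fitted on the training data
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(allData);
normalizer.transform(allData);

//One row per input; each row holds the softmax probabilities of the 3 classes
INDArray output = model.output(allData.getFeatures());

// Output
System.out.println(output);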
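
The printed `output` has one row per input and one column per class, each entry being a softmax probability. To get a class prediction, take the index of the largest value in each row. Below is a minimal plain-Java sketch of that step, assuming the probabilities have been copied into a `double[][]`; the class `ArgMax` and the sample numbers are made up for illustration. `MultiLayerNetwork.predict(INDArray)` should return the same indices directly as an `int[]`, though that is worth confirming in the Javadoc for your version.

```java
import java.util.Arrays;

class ArgMax {
    // Returns, for each row of class probabilities, the index of the largest value.
    static int[] predictClasses(double[][] probabilities) {
        int[] classes = new int[probabilities.length];
        for (int r = 0; r < probabilities.length; r++) {
            int best = 0;
            for (int c = 1; c < probabilities[r].length; c++) {
                if (probabilities[r][c] > probabilities[r][best]) {
                    best = c;
                }
            }
            classes[r] = best;
        }
        return classes;
    }

    public static void main(String[] args) {
        // Made-up probabilities for three inputs, not real model output.
        double[][] output = {
            {0.97, 0.02, 0.01},
            {0.05, 0.90, 0.05},
            {0.01, 0.14, 0.85}
        };
        System.out.println(Arrays.toString(predictClasses(output)));
    }
}
```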