How to use an existing DL4J trained model to classify new input
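I have a DL4J LSTM model that performs binary classification of sequential input. I have trained and tested the model and am happy with the precision/recall. Now I want to use it to predict the binary class of new input. How do I do that? That is, how do I give the trained network a single input (a file containing a sequence of feature rows) and get back the binary classification for that file?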
```java
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(0, ","); // skip no header lines
try {
    trainFeatures.initialize(new NumberedFileInputSplit(featureBaseDir + "/s_%d.csv", 0,
            this._modelDefinition.getNB_TRAIN_EXAMPLES() - 1));
} catch (IOException e) {
    throw new IOException(String.format("IO error %s during trainFeatures", e.getMessage()), e);
} catch (InterruptedException e) {
    throw new IOException(String.format("Interrupted exception error %s during trainFeatures", e.getMessage()), e);
}
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
try {
    trainLabels.initialize(new NumberedFileInputSplit(labelBaseDir + "/s_%d.csv", 0,
            this._modelDefinition.getNB_TRAIN_EXAMPLES() - 1));
} catch (InterruptedException e) {
    throw new IOException(String.format("Interrupted exception error %s during trainLabels initialise", e.getMessage()), e);
}
// ALIGN_END: the label sequence is aligned with the end of the feature sequence
DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels,
        this._modelDefinition.getBATCH_SIZE(), this._modelDefinition.getNUM_LABEL_CLASSES(), false,
        SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(this._modelDefinition.getRANDOM_SEED()) // Random number generator seed for improved repeatability. Optional.
        .updater(new Nesterovs(this._modelDefinition.getLEARNING_RATE()))
        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) // Not always required, but helps with this data set
        .list()
        .layer(0, new LSTM.Builder().activation(Activation.TANH)
                .nIn(this._modelDefinition.getNB_INPUTS()).nOut(this._modelDefinition.getLSTM_LAYER_SIZE()).build())
        .layer(1, new LSTM.Builder().activation(Activation.TANH)
                .nIn(this._modelDefinition.getLSTM_LAYER_SIZE()).nOut(this._modelDefinition.getLSTM_LAYER_SIZE()).build())
        .layer(2, new DenseLayer.Builder()
                .nIn(this._modelDefinition.getLSTM_LAYER_SIZE()).nOut(this._modelDefinition.getLSTM_LAYER_SIZE()).build())
        // Softmax over the label classes, so each time step outputs class probabilities
        .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX)
                .nIn(this._modelDefinition.getLSTM_LAYER_SIZE()).nOut(this._modelDefinition.getNUM_LABEL_CLASSES()).build())
        // Lets DL4J insert the RNN <-> feed-forward preprocessors around the DenseLayer
        .setInputType(InputType.recurrent(this._modelDefinition.getNB_INPUTS()))
        .build();
```
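I trained the model for N epochs to get my best score and then saved it. Now I want to open the saved model and get the classification for a new sequential feature file.
If there is an example of this somewhere, please tell me where.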
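The answer is to feed the model input in exactly the same form as during training, except that the labels are set to -1. The output is an INDArray that, for each example, holds the probability of class 0 in one row and the probability of class 1 in the other. The probabilities appear in the column for the last row of that sequence; all other time steps are 0.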
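An alternative sketch, not part of the original answer: if you would rather not create dummy label files, you can read the single feature file yourself. DL4J's recurrent layers take input shaped `[miniBatch, nIn, timeSteps]`, so each CSV row (one time step) has to become a column. The helper below (the class and method names are illustrative, not a DL4J API) does only the parsing and transposition in plain Java, assuming the same CSV layout as in training: comma-separated, no header line, same feature order. Wrapping the result into a `[1, nIn, T]` INDArray (for example via an `Nd4j` factory method) and passing it to `MultiLayerNetwork.output(INDArray)` is left out because it needs the DL4J dependencies.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper: turns one feature CSV (one time step per row, no header) into [nIn][timeSteps]. */
final class SingleSequenceLoader {

    private SingleSequenceLoader() {}

    /** Parses comma-separated rows and transposes them so that result[feature][timeStep] holds each value. */
    static double[][] toFeatureMajor(List<String> rows) {
        List<double[]> steps = new ArrayList<>();
        for (String row : rows) {
            if (row.trim().isEmpty()) {
                continue; // ignore blank lines, e.g. a trailing newline
            }
            String[] cells = row.split(",");
            double[] values = new double[cells.length];
            for (int i = 0; i < cells.length; i++) {
                values[i] = Double.parseDouble(cells[i].trim());
            }
            if (!steps.isEmpty() && values.length != steps.get(0).length) {
                throw new IllegalArgumentException(
                        "Row has " + values.length + " columns, expected " + steps.get(0).length);
            }
            steps.add(values);
        }
        if (steps.isEmpty()) {
            throw new IllegalArgumentException("No feature rows found");
        }
        int nIn = steps.get(0).length;
        int timeSteps = steps.size();
        double[][] featureMajor = new double[nIn][timeSteps];
        for (int t = 0; t < timeSteps; t++) {
            for (int f = 0; f < nIn; f++) {
                featureMajor[f][t] = steps.get(t)[f];
            }
        }
        return featureMajor;
    }

    /** Reads a single feature file (e.g. one of the features/s_*.csv files) from disk. */
    static double[][] load(Path csvFile) throws IOException {
        return toFeatureMajor(Files.readAllLines(csvFile, StandardCharsets.UTF_8));
    }

    public static void main(String[] args) throws IOException {
        double[][] x = args.length > 0
                ? load(Paths.get(args[0]))
                : toFeatureMajor(Arrays.asList("0.1,0.2,0.3", "0.4,0.5,0.6"));
        System.out.println("nIn=" + x.length + ", timeSteps=" + x[0].length);
        System.out.println(Arrays.deepToString(x));
    }
}
```

If the training pipeline applied any preprocessing to the features, the same preprocessing would have to be applied to this array before calling the network.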
```java
// Imports used by this method (they belong at the top of the enclosing class's file)
import java.io.File;
import java.io.IOException;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public void getOutputsForTheseInputsUsingThisNet(String netFilePath, String inputFileDir) throws Exception {
    // Open the saved network file
    File locationToSave = new File(netFilePath);
    MultiLayerNetwork nNet;
    logger.info("Trying to open the model");
    try {
        nNet = ModelSerializer.restoreMultiLayerNetwork(locationToSave);
        logger.info("Success: Model opened");
    } catch (IOException e) {
        throw new Exception(String.format("Unable to open model from %s because of error %s",
                locationToSave.getAbsolutePath(), e.getMessage()), e);
    }

    logger.info("Loading test data");
    SequenceRecordReader testFeatures = new CSVSequenceRecordReader(0, ","); // no header line to skip
    try {
        testFeatures.initialize(new NumberedFileInputSplit(inputFileDir + "/features/s_4180%d.csv", 0, 4));
    } catch (InterruptedException e) {
        throw new Exception(String.format("Interrupted exception error %s during testFeatures initialise", e.getMessage()), e);
    }

    logger.info("Loading label data");
    // Dummy label files (labels set to -1): the iterator needs a label reader even for prediction
    SequenceRecordReader testLabels = new CSVSequenceRecordReader();
    try {
        testLabels.initialize(new NumberedFileInputSplit(inputFileDir + "/labels/s_4180%d.csv", 0, 4));
    } catch (InterruptedException e) {
        throw new IOException(String.format("Interrupted exception error %s during testLabels initialise", e.getMessage()), e);
    }

    logger.info("Creating iterator");
    // Same batch size, class count and alignment mode as in training
    DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels,
            this._modelDefinition.getBATCH_SIZE(), this._modelDefinition.getNUM_LABEL_CLASSES(), false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Now use the network to classify the data
    logger.info("Classifying examples");
    INDArray output = nNet.output(testData);
    if (output == null) {
        throw new Exception("There is no output");
    }
    logger.info("Outputting the classifications");
    logger.info(output.toString());

    // Sample output, shape [5 examples, 2 classes, 9 time steps]:
    // [[[ 0, 0, 0, 0, 0.9882, 0, 0, 0, 0],
    //   [ 0, 0, 0, 0, 0.0118, 0, 0, 0, 0]],
    //  [[ 0, 0.1443, 0, 0, 0, 0, 0, 0, 0],
    //   [ 0, 0.8557, 0, 0, 0, 0, 0, 0, 0]],
    //  [[ 0, 0, 0, 0, 0, 0, 0, 0, 0.9975],
    //   [ 0, 0, 0, 0, 0, 0, 0, 0, 0.0025]],
    //  [[ 0, 0, 0, 0, 0, 0, 0.8482, 0, 0],
    //   [ 0, 0, 0, 0, 0, 0, 0.1518, 0, 0]],
    //  [[ 0, 0, 0, 0.8760, 0, 0, 0, 0, 0],
    //   [ 0, 0, 0, 0.1240, 0, 0, 0, 0, 0]]]
}
```
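Reading the predicted class off the sample output by eye works for five files; for more, a small helper can do it. The sketch below is plain Java and hypothetical (the class and method names are illustrative). It assumes the INDArray values have already been copied into a `double[examples][classes][timeSteps]` array, for instance element by element with `output.getDouble(example, cls, t)`. Because the non-final time steps come out as 0, it picks the time step with the largest total probability and returns the most likely class at that step.

```java
import java.util.Arrays;

/** Hypothetical helper: reads the predicted class out of an RNN output shaped [examples][classes][timeSteps]. */
final class LastStepClassifier {

    private LastStepClassifier() {}

    /**
     * For each example, finds the time step carrying the probabilities (the other steps are all 0,
     * so the column with the largest total wins) and returns the index of the most likely class there.
     */
    static int[] classify(double[][][] output) {
        int[] predictions = new int[output.length];
        for (int ex = 0; ex < output.length; ex++) {
            double[][] probs = output[ex];
            int timeSteps = probs[0].length;
            int bestStep = 0;
            double bestMass = -1.0;
            for (int t = 0; t < timeSteps; t++) {
                double mass = 0.0;
                for (double[] classRow : probs) {
                    mass += classRow[t];
                }
                if (mass > bestMass) {
                    bestMass = mass;
                    bestStep = t;
                }
            }
            int bestClass = 0;
            for (int c = 1; c < probs.length; c++) {
                if (probs[c][bestStep] > probs[bestClass][bestStep]) {
                    bestClass = c;
                }
            }
            predictions[ex] = bestClass;
        }
        return predictions;
    }

    public static void main(String[] args) {
        // First two examples from the sample output above (9 time steps each)
        double[][][] sample = {
            {{0, 0, 0, 0, 0.9882, 0, 0, 0, 0}, {0, 0, 0, 0, 0.0118, 0, 0, 0, 0}},
            {{0, 0.1443, 0, 0, 0, 0, 0, 0, 0}, {0, 0.8557, 0, 0, 0, 0, 0, 0, 0}}
        };
        System.out.println(Arrays.toString(classify(sample))); // prints [0, 1]
    }
}
```

Row 0 of each example corresponds to label 0 and row 1 to label 1, matching the one-hot encoding of the training labels.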
