为什么 DeepLearning4J CNN 在 INDArray 输出中返回的不是概率而是 0 和 1
Why DeepLearning4J CNN is returning not probabilities but only 0s and 1s in the INDArray output
我正在试用 DL4J 1.0.0-beta3 版本，并尝试创建一个卷积神经网络来识别 32x32 像素的棋子图像。
这是我用来创建和训练网络的代码:
public class BuildNetwork1 {
    /** Number of output classes; label directories must be named 0 .. NUM_CLASSES - 1. */
    private static final int NUM_CLASSES = 13;

    /**
     * Trains a classifier on 32x32 single-channel images and saves it, together with its
     * updater state, to {@code ./CNNinput/chesscom1/trained.chesscom1.bin}.
     *
     * <p>Training images are read from {@code ./CNNinput/chesscom1/training/<label>/}, where
     * {@code <label>} is the integer class index. The convolution/pooling layers are currently
     * commented out, so the network is a single softmax output layer over the raw pixels.
     *
     * <p>Note on the saturated 0/1 outputs: with Adam at learning rate 0.01, 100 epochs and no
     * regularization, the class scores can grow far enough apart that softmax rounds to exactly
     * 0 and 1 (presumably in float precision). That is overconfidence, not an inference bug.
     * Regularization (e.g. l2), a lower learning rate or fewer epochs would yield softer
     * probabilities.
     *
     * @throws Exception if the training directory is missing/unreadable, a label directory is not
     *     a valid class index, or image decoding / model serialization fails
     */
    public static void main(String[] args) throws Exception {
        File rootDir = new File("./CNNinput/chesscom1");
        File locationToSave = new File(rootDir, "trained.chesscom1.bin");
        int height = 32;
        int width = 32;
        int channels = 1;
        int rngseed = 777;
        int numEpochs = 100;
        File trainData = new File(rootDir, "training");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngseed)
                .updater(new Adam.Builder().learningRate(0.01).build())
                .activation(Activation.IDENTITY)
                .weightInit(WeightInit.XAVIER)
                .list()
                //.layer(new ConvolutionLayer.Builder(new int[] {5, 5}, new int[] {1, 1}, new int[]{0, 0}).name("cnn1").nIn(1).nOut(64).biasInit(0).build())
                //.layer(new SubsamplingLayer.Builder(new int[] {2, 2}, new int[] {2, 2}).name("maxpool1").build())
                //.layer(new ConvolutionLayer.Builder(new int[] {5, 5}, new int[] {1, 1}, new int[]{0, 0}).name("cnn2").nIn(64).nOut(16).biasInit(0).build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(NUM_CLASSES)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutional(height, width, channels))
                .build();
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();
        network.setListeners(new ScoreIterationListener(10));

        // Decode and scale every image once instead of re-reading the disk on every epoch.
        java.util.List<DataSet> examples = loadExamples(trainData, height, width, channels);
        java.util.Random shuffleRng = new java.util.Random(rngseed);
        for (int e = 0; e < numEpochs; e++) {
            // The directory walk yields all examples of one class, then the next, and so on. With
            // batch size 1 that makes long runs of updates pull toward a single class; shuffling
            // every epoch gives the optimizer a class-balanced stream of examples.
            java.util.Collections.shuffle(examples, shuffleRng);
            for (DataSet example : examples) {
                network.fit(example.getFeatures(), example.getLabels());
            }
        }
        boolean saveUpdater = true;
        ModelSerializer.writeModel(network, locationToSave, saveUpdater);
    }

    /**
     * Loads every image under {@code dataDir/<label>/} as a single-example DataSet whose features
     * are scaled to [0, 1] and whose label is a one-hot row of width {@link #NUM_CLASSES}.
     * Non-directory entries in {@code dataDir} and files ImageIO cannot decode are skipped.
     */
    private static java.util.List<DataSet> loadExamples(File dataDir, int height, int width, int channels)
            throws Exception {
        File[] labelDirs = dataDir.listFiles(File::isDirectory);
        if (labelDirs == null) {
            throw new java.io.FileNotFoundException("Training directory not found or unreadable: " + dataDir);
        }
        // listFiles order is unspecified; sort so the seeded shuffle is reproducible.
        java.util.Arrays.sort(labelDirs);
        ImageLoader loader = new ImageLoader(height, width, channels);
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        java.util.List<DataSet> examples = new java.util.ArrayList<>();
        for (File labelDir : labelDirs) {
            int classIndex = parseClassIndex(labelDir);
            File[] images = labelDir.listFiles(File::isFile);
            if (images == null) {
                throw new java.io.FileNotFoundException("Label directory unreadable: " + labelDir);
            }
            java.util.Arrays.sort(images);
            for (File imageFile : images) {
                BufferedImage image = ImageIO.read(imageFile);
                if (image == null) {
                    // ImageIO.read returns null when no registered reader understands the file.
                    System.err.println("Skipping non-image file: " + imageFile);
                    continue;
                }
                INDArray input = loader.asMatrix(image).reshape(1, channels, height, width);
                scaler.fit(new DataSet(input, null));
                scaler.transform(input);
                double[][] outputArray = new double[1][NUM_CLASSES];
                outputArray[0][classIndex] = 1d;
                examples.add(new DataSet(input, Nd4j.create(outputArray)));
            }
        }
        return examples;
    }

    /** Parses a label directory name into a class index in [0, NUM_CLASSES). */
    private static int parseClassIndex(File labelDir) {
        int classIndex;
        try {
            classIndex = Integer.parseInt(labelDir.getName());
        } catch (NumberFormatException ex) {
            throw new IllegalArgumentException(
                    "Label directory name must be an integer class index: " + labelDir, ex);
        }
        if (classIndex < 0 || classIndex >= NUM_CLASSES) {
            throw new IllegalArgumentException(
                    "Class index " + classIndex + " out of range [0, " + NUM_CLASSES + "): " + labelDir);
        }
        return classIndex;
    }
}
以及我用来获取结果的代码：
public class CalcNetworkAll {
    /**
     * Runs the saved model over every image under {@code ./CNNinput/chesscom1/testing/<label>/}
     * and prints {@code "<label> => <output row>"} for each image.
     *
     * <p>The printed row is the output layer's activation. If the model was built by
     * BuildNetwork1 that is a softmax; a row of exact 0s and a single 1 means the model is
     * saturated (overconfident), not that {@code output} is broken.
     *
     * @throws Exception if the model or test directory cannot be read, or image decoding fails
     */
    public static void main(String[] args) throws Exception {
        int height = 32;
        int width = 32;
        int channels = 1;
        File rootDir = new File("./CNNinput/chesscom1");
        // Same file BuildNetwork1 writes: ./CNNinput/chesscom1/trained.chesscom1.bin
        File locationToLoad = new File(rootDir, "trained.chesscom1.bin");
        File testData = new File(rootDir, "testing");
        MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(locationToLoad);
        ImageLoader loader = new ImageLoader(height, width, channels);
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        // Only sub-directories are label folders; stray files (e.g. .DS_Store) are ignored.
        File[] labels = testData.listFiles(File::isDirectory);
        if (labels == null) {
            throw new java.io.FileNotFoundException("Test directory not found or unreadable: " + testData);
        }
        // listFiles order is unspecified; sort for a stable, reproducible report.
        java.util.Arrays.sort(labels);
        for (File label : labels) {
            File[] images = label.listFiles(File::isFile);
            if (images == null) {
                throw new java.io.FileNotFoundException("Label directory unreadable: " + label);
            }
            java.util.Arrays.sort(images);
            for (File imageFile : images) {
                BufferedImage image = ImageIO.read(imageFile);
                if (image == null) {
                    // ImageIO.read returns null when no registered reader understands the file.
                    System.err.println("Skipping non-image file: " + imageFile);
                    continue;
                }
                INDArray input = loader.asMatrix(image).reshape(1, channels, height, width);
                scaler.fit(new DataSet(input, null));
                scaler.transform(input);
                INDArray output = network.output(input, false);
                System.out.println(label.getName() + " => " + output);
            }
        }
    }
}
程序运行正常，也给出了预期的分类结果，但问题在于输出只包含 0 和 1，而不是概率：
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000,8.1707e-37, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
你知道为什么会这样吗？
非常感谢！
您的模型对其输出非常有信心。当您向模型输入它在训练中可能已经见过的数据，并且模型已被训练得与这些数据拟合得过于紧密（通常称为过拟合）时，就可能出现这种情况。此时输出层各类别的得分（logit）相差极大，softmax 的结果在浮点精度下会被舍入为恰好的 1 和 0。
我正在试用 DL4J 1.0.0-beta3 版本，并尝试创建一个卷积神经网络来识别 32x32 像素的棋子图像。这是我用来创建和训练网络的代码：
public class BuildNetwork1 {
    /** Number of output classes; label directories must be named 0 .. NUM_CLASSES - 1. */
    private static final int NUM_CLASSES = 13;

    /**
     * Trains a classifier on 32x32 single-channel images and saves it, together with its
     * updater state, to {@code ./CNNinput/chesscom1/trained.chesscom1.bin}.
     *
     * <p>Training images are read from {@code ./CNNinput/chesscom1/training/<label>/}, where
     * {@code <label>} is the integer class index. The convolution/pooling layers are currently
     * commented out, so the network is a single softmax output layer over the raw pixels.
     *
     * <p>Note on the saturated 0/1 outputs: with Adam at learning rate 0.01, 100 epochs and no
     * regularization, the class scores can grow far enough apart that softmax rounds to exactly
     * 0 and 1 (presumably in float precision). That is overconfidence, not an inference bug.
     * Regularization (e.g. l2), a lower learning rate or fewer epochs would yield softer
     * probabilities.
     *
     * @throws Exception if the training directory is missing/unreadable, a label directory is not
     *     a valid class index, or image decoding / model serialization fails
     */
    public static void main(String[] args) throws Exception {
        File rootDir = new File("./CNNinput/chesscom1");
        File locationToSave = new File(rootDir, "trained.chesscom1.bin");
        int height = 32;
        int width = 32;
        int channels = 1;
        int rngseed = 777;
        int numEpochs = 100;
        File trainData = new File(rootDir, "training");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngseed)
                .updater(new Adam.Builder().learningRate(0.01).build())
                .activation(Activation.IDENTITY)
                .weightInit(WeightInit.XAVIER)
                .list()
                //.layer(new ConvolutionLayer.Builder(new int[] {5, 5}, new int[] {1, 1}, new int[]{0, 0}).name("cnn1").nIn(1).nOut(64).biasInit(0).build())
                //.layer(new SubsamplingLayer.Builder(new int[] {2, 2}, new int[] {2, 2}).name("maxpool1").build())
                //.layer(new ConvolutionLayer.Builder(new int[] {5, 5}, new int[] {1, 1}, new int[]{0, 0}).name("cnn2").nIn(64).nOut(16).biasInit(0).build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(NUM_CLASSES)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutional(height, width, channels))
                .build();
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();
        network.setListeners(new ScoreIterationListener(10));

        // Decode and scale every image once instead of re-reading the disk on every epoch.
        java.util.List<DataSet> examples = loadExamples(trainData, height, width, channels);
        java.util.Random shuffleRng = new java.util.Random(rngseed);
        for (int e = 0; e < numEpochs; e++) {
            // The directory walk yields all examples of one class, then the next, and so on. With
            // batch size 1 that makes long runs of updates pull toward a single class; shuffling
            // every epoch gives the optimizer a class-balanced stream of examples.
            java.util.Collections.shuffle(examples, shuffleRng);
            for (DataSet example : examples) {
                network.fit(example.getFeatures(), example.getLabels());
            }
        }
        boolean saveUpdater = true;
        ModelSerializer.writeModel(network, locationToSave, saveUpdater);
    }

    /**
     * Loads every image under {@code dataDir/<label>/} as a single-example DataSet whose features
     * are scaled to [0, 1] and whose label is a one-hot row of width {@link #NUM_CLASSES}.
     * Non-directory entries in {@code dataDir} and files ImageIO cannot decode are skipped.
     */
    private static java.util.List<DataSet> loadExamples(File dataDir, int height, int width, int channels)
            throws Exception {
        File[] labelDirs = dataDir.listFiles(File::isDirectory);
        if (labelDirs == null) {
            throw new java.io.FileNotFoundException("Training directory not found or unreadable: " + dataDir);
        }
        // listFiles order is unspecified; sort so the seeded shuffle is reproducible.
        java.util.Arrays.sort(labelDirs);
        ImageLoader loader = new ImageLoader(height, width, channels);
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        java.util.List<DataSet> examples = new java.util.ArrayList<>();
        for (File labelDir : labelDirs) {
            int classIndex = parseClassIndex(labelDir);
            File[] images = labelDir.listFiles(File::isFile);
            if (images == null) {
                throw new java.io.FileNotFoundException("Label directory unreadable: " + labelDir);
            }
            java.util.Arrays.sort(images);
            for (File imageFile : images) {
                BufferedImage image = ImageIO.read(imageFile);
                if (image == null) {
                    // ImageIO.read returns null when no registered reader understands the file.
                    System.err.println("Skipping non-image file: " + imageFile);
                    continue;
                }
                INDArray input = loader.asMatrix(image).reshape(1, channels, height, width);
                scaler.fit(new DataSet(input, null));
                scaler.transform(input);
                double[][] outputArray = new double[1][NUM_CLASSES];
                outputArray[0][classIndex] = 1d;
                examples.add(new DataSet(input, Nd4j.create(outputArray)));
            }
        }
        return examples;
    }

    /** Parses a label directory name into a class index in [0, NUM_CLASSES). */
    private static int parseClassIndex(File labelDir) {
        int classIndex;
        try {
            classIndex = Integer.parseInt(labelDir.getName());
        } catch (NumberFormatException ex) {
            throw new IllegalArgumentException(
                    "Label directory name must be an integer class index: " + labelDir, ex);
        }
        if (classIndex < 0 || classIndex >= NUM_CLASSES) {
            throw new IllegalArgumentException(
                    "Class index " + classIndex + " out of range [0, " + NUM_CLASSES + "): " + labelDir);
        }
        return classIndex;
    }
}
以及我用来获取结果的代码：
public class CalcNetworkAll {
    /**
     * Runs the saved model over every image under {@code ./CNNinput/chesscom1/testing/<label>/}
     * and prints {@code "<label> => <output row>"} for each image.
     *
     * <p>The printed row is the output layer's activation. If the model was built by
     * BuildNetwork1 that is a softmax; a row of exact 0s and a single 1 means the model is
     * saturated (overconfident), not that {@code output} is broken.
     *
     * @throws Exception if the model or test directory cannot be read, or image decoding fails
     */
    public static void main(String[] args) throws Exception {
        int height = 32;
        int width = 32;
        int channels = 1;
        File rootDir = new File("./CNNinput/chesscom1");
        // Same file BuildNetwork1 writes: ./CNNinput/chesscom1/trained.chesscom1.bin
        File locationToLoad = new File(rootDir, "trained.chesscom1.bin");
        File testData = new File(rootDir, "testing");
        MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(locationToLoad);
        ImageLoader loader = new ImageLoader(height, width, channels);
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        // Only sub-directories are label folders; stray files (e.g. .DS_Store) are ignored.
        File[] labels = testData.listFiles(File::isDirectory);
        if (labels == null) {
            throw new java.io.FileNotFoundException("Test directory not found or unreadable: " + testData);
        }
        // listFiles order is unspecified; sort for a stable, reproducible report.
        java.util.Arrays.sort(labels);
        for (File label : labels) {
            File[] images = label.listFiles(File::isFile);
            if (images == null) {
                throw new java.io.FileNotFoundException("Label directory unreadable: " + label);
            }
            java.util.Arrays.sort(images);
            for (File imageFile : images) {
                BufferedImage image = ImageIO.read(imageFile);
                if (image == null) {
                    // ImageIO.read returns null when no registered reader understands the file.
                    System.err.println("Skipping non-image file: " + imageFile);
                    continue;
                }
                INDArray input = loader.asMatrix(image).reshape(1, channels, height, width);
                scaler.fit(new DataSet(input, null));
                scaler.transform(input);
                INDArray output = network.output(input, false);
                System.out.println(label.getName() + " => " + output);
            }
        }
    }
}
程序运行正常，也给出了预期的分类结果，但问题在于输出只包含 0 和 1，而不是概率：
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000,8.1707e-37, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
你知道为什么会这样吗？非常感谢！
您的模型对其输出非常有信心。当您向模型输入它在训练中可能已经见过的数据，并且模型已被训练得与这些数据拟合得过于紧密（通常称为过拟合）时，就可能出现这种情况。此时输出层各类别的得分（logit）相差极大，softmax 的结果在浮点精度下会被舍入为恰好的 1 和 0。