Neural Network returning NaN as output

I'm trying to write my first neural network to play Connect Four. I'm using Java and deeplearning4j. I'm trying to implement a genetic algorithm, but when I train the networks for a while their output jumps to NaN, and I can't tell where I messed up for that to happen. I'll post all 3 classes below: Game holds the game logic and rules, VGFrame the UI, and Main all the NN stuff.

I have a pool of 35 neural networks, and each iteration I let the best 5 survive and breed, slightly randomizing the newly created ones. To evaluate the networks I let them fight each other, giving points to the winner and points to the loser based on how long it held out. Since I penalize dropping a stone into a column that is already full, I expected the networks to at least play by the rules after a while, but they don't. I googled the NaN issue and it seems to be a gradient problem, but as far as I understand that shouldn't occur with a genetic algorithm. Any ideas where I could look for the error, or what is generally wrong with my implementation?

Main

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;

public class Main {
    final int numRows = 7;
    final int numColums = 6;
    final int randSeed = 123;
    MultiLayerNetwork[] models;

    static Random random = new Random();
    private static final Logger log = LoggerFactory.getLogger(Main.class);
    final float learningRate = .8f;
    int batchSize = 64; // Test batch size
    int nEpochs = 1; // Number of training epochs
    // --
    public static Main current;
    Game mainGame = new Game();

    public static void main(String[] args) {
        current = new Main();
        current.frame = new VGFrame();
        current.loadWeights();
    }

    private VGFrame frame;
    private final double mutationChance = .05;

    public Main() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
                .activation(Activation.RELU).seed(randSeed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Nesterovs(0.1, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build())
                .layer(new DenseLayer.Builder().nIn(30).nOut(15).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(15).nOut(7)
                        .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build())
                .build();
        models = new MultiLayerNetwork[35];
        for (int i = 0; i < models.length; i++) {
            models[i] = new MultiLayerNetwork(conf);
            models[i].init();
        }

    }

    public void addChip(int i, boolean b) {
        if (mainGame.gameState == 0)
            mainGame.addChip(i, b);
        if (mainGame.gameState == 0) {
            float[] f = Main.rowsToInput(mainGame.rows);
            INDArray input = Nd4j.create(f);
            INDArray output = models[0].output(input);
            for (int i1 = 0; i1 < 7; i1++) {
                System.out.println(i1 + ": " + output.getDouble(i1));
            }
            System.out.println("----------------");
            mainGame.addChip(Main.getHighestOutput(output), false);
        }
        getFrame().paint(getFrame().getGraphics());
    }

    public void newGame() {
        mainGame = new Game();
        getFrame().paint(getFrame().getGraphics());
    }

    public void startTraining(int iterations) {

        // --------------------------
        for (int gameNumber = 0; gameNumber < iterations; gameNumber++) {
            System.out.println("Iteration " + gameNumber + " of " + iterations);
            float[] evaluation = new float[models.length];
            for (int i = 0; i < models.length; i++) {
                for (int j = 0; j < models.length; j++) {
                    if (i != j) {
                        Game g = new Game();
                        g.playFullGame(models[i], models[j]);
                        if (g.gameState == 1) {
                            evaluation[i] += 45;
                            evaluation[j] += g.turnNumber;
                        }
                        if (g.gameState == 2) {
                            evaluation[j] += 45;
                            evaluation[i] += g.turnNumber;
                        }
                    }
                }
            }

            float[] evaluationSorted = evaluation.clone();
            Arrays.sort(evaluationSorted);
            // keep the best 5
            int n1 = 0, n2 = 0, n3 = 0, n4 = 0, n5 = 0;
            for (int i = 0; i < evaluation.length; i++) {
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 1])
                    n1 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 2])
                    n2 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 3])
                    n3 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 4])
                    n4 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 5])
                    n5 = i;
            }
            models[0] = models[n1];
            models[1] = models[n2];
            models[2] = models[n3];
            models[3] = models[n4];
            models[4] = models[n5];

            for (int i = 3; i < evaluationSorted.length; i++) {
                // pick a random parent / keep the weights
                double r = Math.random();
                if (r > .3) {
                    models[i] = models[random.nextInt(3)].clone();

                } else if (r > .1) {
                    models[i].setParams(breed(models[random.nextInt(3)], models[random.nextInt(3)]));
                }
                // Mutate
                INDArray params = models[i].params();
                models[i].setParams(mutate(params));
            }
        }
    }

    private INDArray mutate(INDArray params) {
        double[] d = params.toDoubleVector();
        for (int i = 0; i < d.length; i++) {
            if (Math.random() < mutationChance)
                d[i] += (Math.random() - .5) * learningRate;

        }
        return Nd4j.create(d);
    }

    private INDArray breed(MultiLayerNetwork m1, MultiLayerNetwork m2) {
        double[] d = m1.params().toDoubleVector();
        double[] d2 = m2.params().toDoubleVector();
        for (int i = 0; i < d.length; i++) {
            if (Math.random() < .5)
                d[i] += d2[i];
        }
        return Nd4j.create(d);
    }

    static int getHighestOutput(INDArray output) {
        int x = 0;
        for (int i = 0; i < 7; i++) {
            if (output.getDouble(i) > output.getDouble(x))
                x = i;
        }
        return x;
    }

    static float[] rowsToInput(byte[][] rows) {
        float[] f = new float[7 * 6];
        for (int i = 0; i < 6; i++) {
            for (int j = 0; j < 7; j++) {
                // f[j + i * 7] = rows[j][i] / 2f;
                f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);
            }
        }
        return f;
    }

    public void saveWeights() {
        log.info("Saving model");
        for (int i = 0; i < models.length; i++) {
            File resourcesDirectory = new File("src/resources/model" + i);
            try {
                models[i].save(resourcesDirectory, true);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    public void loadWeights() {
        if (new File("src/resources/model0").exists()) {
            for (int i = 0; i < models.length; i++) {
                File resourcesDirectory = new File("src/resources/model" + i);
                try {

                    models[i] = MultiLayerNetwork.load(resourcesDirectory, true);
                } catch (IOException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }
        }
        System.out.println("col: " + models[0].params().shapeInfoToString());
    }

    public VGFrame getFrame() {
        return frame;
    }

}

VGFrame

import java.awt.Color;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import javax.swing.BorderFactory;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.JTextField;

public class VGFrame extends JFrame {
    JTextField iterations;
    /**
     * 
     */
    private static final long serialVersionUID = 1L;

    public VGFrame() {
        super("Vier Gewinnt");
        this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        this.setSize(1300, 800);
        this.setVisible(true);
        JPanel panelGame = new JPanel();
        panelGame.setBorder(BorderFactory.createLineBorder(Color.black, 2));
        this.add(panelGame);

        var handler = new Handler();
        var menuHandler = new MenuHandler();

        JButton b1 = new JButton("1");
        JButton b2 = new JButton("2");
        JButton b3 = new JButton("3");
        JButton b4 = new JButton("4");
        JButton b5 = new JButton("5");
        JButton b6 = new JButton("6");
        JButton b7 = new JButton("7");
        b1.addActionListener(handler);
        b2.addActionListener(handler);
        b3.addActionListener(handler);
        b4.addActionListener(handler);
        b5.addActionListener(handler);
        b6.addActionListener(handler);
        b7.addActionListener(handler);
        panelGame.add(b1);
        panelGame.add(b2);
        panelGame.add(b3);
        panelGame.add(b4);
        panelGame.add(b5);
        panelGame.add(b6);
        panelGame.add(b7);

        JButton buttonTrain = new JButton("Train");
        JButton buttonNewGame = new JButton("New Game");
        JButton buttonSave = new JButton("Save Weights");
        JButton buttonLoad = new JButton("Load Weights");

        iterations = new JTextField("1000");

        buttonTrain.addActionListener(menuHandler);
        buttonNewGame.addActionListener(menuHandler);
        buttonSave.addActionListener(menuHandler);
        buttonLoad.addActionListener(menuHandler);
        iterations.addActionListener(menuHandler);

        panelGame.add(iterations);
        panelGame.add(buttonTrain);
        panelGame.add(buttonNewGame);
        panelGame.add(buttonSave);
        panelGame.add(buttonLoad);

        this.validate();
    }

    @Override
    public void paint(Graphics g) {
        super.paint(g);
        if (Main.current.mainGame.rows == null)
            return;
        var rows = Main.current.mainGame.rows;
        for (int i = 0; i < rows.length; i++) {
            for (int j = 0; j < rows[0].length; j++) {
                if (rows[i][j] == 0)
                    break;

                g.setColor((rows[i][j] == 1 ? Color.yellow : Color.red));
                g.fillOval(80 + 110 * i, 650 - 110 * j, 100, 100);
            }
        }
    }

    public void update() {
    }
}

class Handler implements ActionListener {

    @Override
    public void actionPerformed(ActionEvent event) {
        if (Main.current.mainGame.playersTurn)
            Main.current.addChip(Integer.parseInt(event.getActionCommand()) - 1, true);
    }
}

class MenuHandler implements ActionListener {

    @Override
    public void actionPerformed(ActionEvent event) {
        switch (event.getActionCommand()) {
        case "New Game":
            Main.current.newGame();
            break;
        case "Train":
            Main.current.startTraining(Integer.parseInt(Main.current.getFrame().iterations.getText()));
            break;
        case "Save Weights":
            Main.current.saveWeights();
            break;
        case "Load Weights":
            Main.current.loadWeights();
            break;
        }

    }
}

Game

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class Game {

    int turnNumber = 0;
    byte[][] rows = new byte[7][6];
    boolean playersTurn = true;

    int gameState = 0; // 0:running, 1:Player1, 2:Player2, 3:Draw

    public boolean isRunning() {
        return this.gameState == 0;
    }

    public void addChip(int x, boolean player1) {
        turnNumber++;
        byte b = nextRow(x);
        if (b == 6) {
            gameState = player1 ? 2 : 1;
            return;
        }
        rows[x][b] = (byte) (player1 ? 1 : 2);
        gameState = checkWinner(x, b);
    }

    private byte nextRow(int x) {
        for (byte i = 0; i < rows[x].length; i++) {
            if (rows[x][i] == 0)
                return i;
        }
        return 6;
    }

    // 0 continue, 1 Player won, 2 ai won, 3 Draw
    private int checkWinner(int x, int y) {
        int color = rows[x][y];
        // Vertikal
        if (getCount(x, y, 1, 0) + getCount(x, y, -1, 0) >= 3)
            return rows[x][y];

        // Horizontal
        if (getCount(x, y, 0, 1) + getCount(x, y, 0, -1) >= 3)
            return rows[x][y];

        // Diagonal1
        if (getCount(x, y, 1, 1) + getCount(x, y, -1, -1) >= 3)
            return rows[x][y];
        // Diagonal2
        if (getCount(x, y, -1, 1) + getCount(x, y, 1, -1) >= 3)
            return rows[x][y];
        
        for (byte[] bs : rows) {
            for (byte s : bs) {
                if (s == 0)
                    return 0;
            }
        }
        return 3; // Draw
    }

    private int getCount(int x, int y, int dirX, int dirY) {
        int color = rows[x][y];
        int count = 0;
        while (true) {
            x += dirX;
            y += dirY;
            if (x < 0 | x > 6 | y < 0 | y > 5)
                break;
            if (color != rows[x][y])
                break;
            count++;
        }
        return count;
    }

    public void playFullGame(MultiLayerNetwork m1, MultiLayerNetwork m2) {
        boolean player1 = true;
        while (this.gameState == 0) {
            float[] f = Main.rowsToInput(this.rows);
            INDArray input = Nd4j.create(f);
            this.addChip(Main.getHighestOutput(player1 ? m1.output(input) : m2.output(input)), player1);
            player1 = !player1;
        }
    }
}

At a quick glance, and based on an analysis of your multiplier variants, the NaN seems to be produced by an arithmetic underflow, caused by gradients that are too small (too close to absolute 0).

This is the most suspicious part of the code:

 f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);

If rows[j][i] == 1, a 0f is stored. I don't know how the neural network (or even Java) manages it, but mathematically speaking, finite-size floats cannot cope with values arbitrarily close to zero.

Even if your code altered the 0f with some extra salt, the results computed from those array values would still risk getting too close to zero. With the limited precision available for representing real numbers, values very close to zero cannot be represented, hence the NaN.

These values have a very friendly name: subnormal numbers.

Any non-zero number with magnitude smaller than the smallest normal number is subnormal.

IEEE_754

As with IEEE 754-1985, the standard recommends 0 for signaling NaNs, 1 for quiet NaNs, so that a signaling NaN can be quieted by changing only this bit to 1, while the reverse could yield the encoding of an infinity.

The text above is important here: according to the standard, you may effectively be specifying a NaN with any 0f value you store.


Even if its name is misleading, Float.MIN_VALUE is a positive value, higher than 0:

The real minimum float value is actually -Float.MAX_VALUE.

Is floating point math subnormal?
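
To get a feel for where the float range actually bottoms out, here is a minimal, self-contained sketch (illustrative only, not part of the original answer):

public class FloatLimitsDemo {
    public static void main(String[] args) {
        System.out.println(Float.MIN_NORMAL);     // 1.17549435E-38, smallest *normal* positive float
        System.out.println(Float.MIN_VALUE);      // 1.4E-45, smallest positive (subnormal) float
        System.out.println(Float.MIN_VALUE / 2f); // 0.0 -- underflows below the representable range
        System.out.println(-Float.MAX_VALUE);     // -3.4028235E38, the true minimum float value
        System.out.println(0f / 0f);              // NaN -- one of the few operations that really yields NaN
    }
}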


Normalizing the gradients

If you verify that the issue is only due to the 0f values, you can change them to other values that represent something similar: Float.MIN_VALUE, Float.MIN_NORMAL, etc. Do the same in any other part of the code where this could happen. Take these as examples and play with these ranges:

rows[j][i] == 1 ? Float.MIN_VALUE : 1f;

rows[j][i] == 1 ?  Float.MIN_NORMAL : Float.MAX_VALUE/2;

rows[j][i] == 1 ? -Float.MAX_VALUE/2 : Float.MAX_VALUE/2;
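
Applied to the rowsToInput method from the question, the first variant would look like this (just a sketch to show where the change goes; which replacement value works best is something to experiment with):

static float[] rowsToInput(byte[][] rows) {
    float[] f = new float[7 * 6];
    for (int i = 0; i < 6; i++) {
        for (int j = 0; j < 7; j++) {
            // empty -> .5f, player 1 -> Float.MIN_VALUE instead of exactly 0f, player 2 -> 1f
            f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? Float.MIN_VALUE : 1f);
        }
    }
    return f;
}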

Even then, this could lead to NaN depending on how those values are altered. If so, you should normalize the values. You can try applying a GradientNormalizer for this. In your network initialization, something like this should be defined for each layer (or for those that are problematic):

new NeuralNetConfiguration
  .Builder()
  .weightInit(WeightInit.XAVIER)
  (...)
  .layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
        .weightInit(WeightInit.XAVIER)
        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) //this   
        .build())
  
  (...)

There are different normalizers, so choose the one that best fits your architecture, and decide which layers should include one. The options are:

GradientNormalization

  • RenormalizeL2PerLayer

    Rescale gradients by dividing by the L2 norm of all gradients for the layer.

  • RenormalizeL2PerParamType

    Rescale gradients by dividing by the L2 norm of the gradients, separately for each type of parameter within the layer. This differs from RenormalizeL2PerLayer in that here, each parameter type (weight, bias etc) is normalized separately. For example, in a MLP/FeedForward network (where G is the gradient vector), the output is as follows:

    GOut_weight = G_weight / l2(G_weight)
    GOut_bias = G_bias / l2(G_bias)

  • ClipElementWiseAbsoluteValue

    Clip the gradients on a per-element basis. For each gradient g, set g <- sign(g) * min(maxAllowedValue, |g|); i.e., if a parameter gradient has absolute value greater than the threshold, truncate it. For example, if threshold = 5, then values in range -5<g<5 are unmodified; values <-5 are set to -5; values >5 are set to 5.

  • ClipL2PerLayer

    Conditional renormalization. Somewhat similar to RenormalizeL2PerLayer, this strategy scales the gradients if and only if the L2 norm of the gradients (for entire layer) exceeds a specified threshold. Specifically, if G is gradient vector for the layer, then:

    GOut = G if l2Norm(G) < threshold (i.e., no change)
    GOut = threshold * G / l2Norm(G) otherwise

  • ClipL2PerParamType

    Conditional renormalization. Very similar to ClipL2PerLayer, however instead of clipping per layer, do clipping on each parameter type separately. For example in a recurrent neural network, input weight gradients, recurrent weight gradients and bias gradient are all clipped separately.


Here you can find a complete example of these GradientNormalizers in action.
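
Put together with the network from the question, the whole configuration would look roughly like this (a sketch using RenormalizeL2PerLayer on every layer; which normalizer to use, and on which layers, is your call):

import org.deeplearning4j.nn.conf.GradientNormalization;
// remaining imports as in the question's Main class

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .weightInit(WeightInit.XAVIER)
        .activation(Activation.RELU)
        .seed(randSeed)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(new Nesterovs(0.1, 0.9))
        .list()
        .layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .build())
        .layer(new DenseLayer.Builder().nIn(30).nOut(15).activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .build())
        .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(15).nOut(7)
                .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .build())
        .build();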

I think I finally figured it out. I was trying to visualize the network using deeplearning4j-ui, but got some version-incompatibility errors. After changing the versions I got a new error stating that the network's input needs to be a two-dimensional array, which, as I found out on the internet, is required in all versions.

So I changed

float[] f = new float[7 * 6];
Nd4j.create(f);

to

float[][] f = new float[1][7 * 6];
Nd4j.createFromArray(f);

and the NaN values finally disappeared. @aran So I guess assuming incorrect inputs was definitely the right direction. Thank you so much for your help :)
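
For anyone hitting the same problem: the difference is easy to see by printing the shapes (a sketch; the exact shape returned for a 1D array can vary between ND4J versions, which is presumably why only the 2D form behaves consistently):

INDArray v1 = Nd4j.create(new float[7 * 6]);             // may come back as a rank-1 vector: shape [42]
INDArray v2 = Nd4j.createFromArray(new float[1][7 * 6]); // rank 2: shape [1, 42], i.e. one example per row
System.out.println(Arrays.toString(v1.shape()));         // needs java.util.Arrays
System.out.println(Arrays.toString(v2.shape()));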