Snake AI that uses an MLP and a GA to learn doesn't exhibit intelligent behavior even after thousands of generations

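I am a high school student working on a project for my CS research class (I am very fortunate to have the opportunity to be in such a class)! The project is to make an AI learn the popular game Snake, using a multilayer perceptron (MLP) that learns through a genetic algorithm (GA). This project is heavily inspired by many videos I've seen on YouTube that accomplish what I've just described, as you can see here and here. I've written the project described above using JavaFX and an AI library called Neuroph.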

This is what my program currently looks like:

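The name is irrelevant: I have a list of nouns and adjectives that I generate the names from (I thought it would make things more interesting). The number in parentheses after Score is the best score of that generation, since only 1 snake is shown at a time.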

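When breeding, I set x% of the snakes to be parents (in this case, 20). The number of children is then divided evenly among the pairs of parent snakes. The "genes" in this case are the weights of the MLP. Since my library doesn't really support biases, I added a bias neuron to the input layer and connected it to all of the other neurons in every layer so that its weights act as biases instead (as described in a thread here). For every gene, each child has a 50/50 chance of getting either parent's gene. There is also a 5% chance for a gene to mutate, in which case it is set to a random number between -1.0 and 1.0.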
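As a minimal sketch of the breeding scheme described above (not the project's actual code; `CrossoverSketch` and `breedOne` are made-up names, and the weights are treated as a flat `double[]`), the per-gene logic could look roughly like this:

```java
import java.util.Random;

/** Sketch of uniform crossover with mutation over flat weight arrays (hypothetical helper). */
class CrossoverSketch {

    /**
     * Builds one child: each gene mutates with probability mutationRate,
     * otherwise it is copied from parent1 or parent2 with equal chance.
     */
    static double[] breedOne(double[] parent1, double[] parent2, double mutationRate, Random rng) {
        if (parent1.length != parent2.length) {
            throw new IllegalArgumentException("Parents must have the same number of genes");
        }
        double[] child = new double[parent1.length];
        for (int g = 0; g < child.length; g++) {
            if (rng.nextDouble() < mutationRate) {
                // Mutation: replace the gene with a fresh value in [-1.0, 1.0)
                child[g] = -1.0 + 2.0 * rng.nextDouble();
            } else {
                // Crossover: 50/50 pick between the two parents
                child[g] = rng.nextBoolean() ? parent1[g] : parent2[g];
            }
        }
        return child;
    }

    public static void main(String[] args) {
        double[] p1 = {0.1, 0.2, 0.3, 0.4};
        double[] p2 = {-0.1, -0.2, -0.3, -0.4};
        double[] child = breedOne(p1, p2, 0.05, new Random(42));
        System.out.println(java.util.Arrays.toString(child));
    }
}
```

Comparing `rng.nextDouble()` against the rate directly gives exactly the stated mutation probability, and passing the `Random` in makes a run reproducible with a fixed seed.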

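Each snake's MLP has 3 layers: 18 input neurons, 14 hidden neurons, and 4 output neurons (one per direction). The inputs I feed it are the x of the head, the y of the head, the x of the food, the y of the food, and the steps left. It also looks in 4 directions and checks the distance to the food, the wall, and itself (if it doesn't see one of them, that input is set to -1.0). That is 5 + 12 = 17 inputs, and the bias neuron I mentioned brings the total to 18.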

I calculate a snake's score with my fitness function, which is (apples consumed × 5 + seconds alive / 2).
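To make the formula concrete, here is a small worked sketch (`FitnessSketch` is a made-up name; the formula itself is the one stated above):

```java
/** Sketch of the fitness formula described above (hypothetical helper class). */
class FitnessSketch {

    /** apples consumed * 5 + seconds alive / 2, using floating-point division. */
    static double fitness(int applesConsumed, int secondsAlive) {
        return applesConsumed * 5 + secondsAlive / 2.0;
    }

    public static void main(String[] args) {
        // 3 apples and 9 seconds: 3 * 5 + 9 / 2.0 = 19.5
        System.out.println(fitness(3, 9));
        // 0 apples and 40 seconds: 0 * 5 + 40 / 2.0 = 20.0
        System.out.println(fitness(0, 40));
    }
}
```

With these weights, one apple is worth the same as ten seconds of survival.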

Here is my GAMLPAgent.java, where all the MLP and GA stuff happens.

package agents;

import graphics.Snake;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Stream;
import javafx.scene.shape.Rectangle;
import org.neuroph.core.Layer;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.comp.neuron.BiasNeuron;
import org.neuroph.util.NeuralNetworkType;
import org.neuroph.util.TransferFunctionType;
import util.Direction;

/**
 *
 * @author Preston Tang
 *
 * GAMLPAgent stands for Genetic Algorithm Multi-Layer Perceptron Agent
 */
public class GAMLPAgent implements Comparable<GAMLPAgent> {

    public Snake mask;

    private final MultiLayerPerceptron mlp;

    private final int width;
    private final int height;
    private final double size;

    private final double mutationRate = 0.05;

    public GAMLPAgent(Snake mask, int width, int height, double size) {
        this.mask = mask;
        this.width = width;
        this.height = height;
        this.size = size;

        //Input: x of head, y of head, x of food, y of food, steps left
        //Input: 4 directions, check for distance to food, wall, and self  + 1 bias neuron (18 total)
        //14 hidden neurons (1 hidden layer)
        //Output: A direction, 4 possibilities
        mlp = new MultiLayerPerceptron(TransferFunctionType.SIGMOID, 18, 14, 4);
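        //Wire the bias neuron (the last input neuron) into every neuron of the later layers.
        //target.addInputConnection(source) creates a connection from source into target,
        //so it has to be called on the target neuron, not on the bias neuron itself.
        List<Layer> layers = mlp.getLayers();

        for (int r = 1; r < layers.size(); r++) {
            for (int c = 0; c < layers.get(r).getNeuronsCount(); c++) {
                layers.get(r).getNeuronAt(c).addInputConnection(mlp.getInputNeurons().get(mlp.getInputsCount() - 1));
            }
        }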

//        System.out.println(mlp.getInputNeurons().get(17).getInputConnections() + " " + mlp.getInputNeurons().get(17).getOutConnections());
        mlp.randomizeWeights();

//        System.out.println(Arrays.toString(mlp.getInputNeurons().get(17).getWeights()));
    }

    public void compute() {
        if (mask.isAlive()) {
            Rectangle head = mask.getSnakeParts().get(0);
            Rectangle food = mask.getFood();

            double headX = head.getX();
            double headY = head.getY();
            double foodX = mask.getFood().getX();
            double foodY = mask.getFood().getY();
            int stepsLeft = mask.getSteps();

            double foodL = -1.0, wallL, selfL = -1.0;
            double foodR = -1.0, wallR, selfR = -1.0;
            double foodU = -1.0, wallU, selfU = -1.0;
            double foodD = -1.0, wallD, selfD = -1.0;

            //The 4 directions
            //Left Direction
            if (head.getY() == food.getY() && head.getX() > food.getX()) {
                foodL = head.getX() - food.getX();
            }

            wallL = head.getX() - size;

            for (Rectangle part : mask.getSnakeParts()) {
                if (head.getY() == part.getY() && head.getX() > part.getX()) {
                    selfL = head.getX() - part.getX();
                    break;
                }
            }

            //Right Direction
            if (head.getY() == food.getY() && head.getX() < food.getX()) {
                foodR = food.getX() - head.getX();
            }

            wallR = size * width - head.getX();

            for (Rectangle part : mask.getSnakeParts()) {
                if (head.getY() == part.getY() && head.getX() < part.getX()) {
                    selfR = part.getX() - head.getX();
                    break;
                }
            }

            //Up Direction
            if (head.getX() == food.getX() && head.getY() < food.getY()) {
                foodU = food.getY() - head.getY();
            }

            wallU = size * height - head.getY();

            for (Rectangle part : mask.getSnakeParts()) {
                if (head.getX() == part.getX() && head.getY() < part.getY()) {
                    selfU = part.getY() - head.getY();
                    break;
                }
            }

            //Down Direction
            if (head.getX() == food.getX() && head.getY() > food.getY()) {
                foodD = head.getY() - food.getY();
            }

            wallD = head.getY() - size;

            for (Rectangle part : mask.getSnakeParts()) {
                if (head.getX() == part.getX() && head.getY() > part.getY()) {
                    //Distance to the body part, not to the food
                    selfD = head.getY() - part.getY();
                    break;
                }
            }

            mlp.setInput(
                    headX, headY, foodX, foodY, stepsLeft,
                    foodL, wallL, selfL,
                    foodR, wallR, selfR,
                    foodU, wallU, selfU,
                    foodD, wallD, selfD, 1);

            mlp.calculate();

            //Evaluate the outputs once and steer toward the strongest one
            int best = getIndexOfLargest(mlp.getOutput());
            if (best == 0) {
                mask.setDirection(Direction.UP);
            } else if (best == 1) {
                mask.setDirection(Direction.DOWN);
            } else if (best == 2) {
                mask.setDirection(Direction.LEFT);
            } else if (best == 3) {
                mask.setDirection(Direction.RIGHT);
            }
        }
    }

    public double[][] breed(GAMLPAgent agent, int num) {
        //Converts Double[] to double[]
        //
        double[] parent1 = Stream.of(mlp.getWeights()).mapToDouble(Double::doubleValue).toArray();
        double[] parent2 = Stream.of(agent.getMLP().getWeights()).mapToDouble(Double::doubleValue).toArray();

        double[][] childGenes = new double[num][parent1.length];

        for (int r = 0; r < num; r++) {
            for (int c = 0; c < childGenes[r].length; c++) {
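                //nextInt(100) <= 5 would fire for 6 of 100 values; compare a uniform double instead
                if (ThreadLocalRandom.current().nextDouble() < mutationRate) {
                    //Mutation: replace the gene with a random value in [-1.0, 1.0)
                    childGenes[r][c] = ThreadLocalRandom.current().nextDouble(-1.0, 1.0);
//childGenes[r][c] += childGenes[r][c] * 0.1;
                } else {
                    //Crossover: 50/50 chance of taking either parent's gene
                    childGenes[r][c] = ThreadLocalRandom.current().nextBoolean() ? parent1[c] : parent2[c];
                }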
            }
        }

        return childGenes;
    }

    public MultiLayerPerceptron getMLP() {
        return mlp;
    }

    public void setMask(Snake mask) {
        this.mask = mask;
    }

    public Snake getMask() {
        return mask;
    }

    public int getIndexOfLargest(double[] array) {
        if (array == null || array.length == 0) {
            return -1; // null or empty
        }
        int largest = 0;
        for (int i = 1; i < array.length; i++) {
            if (array[i] > array[largest]) {
                largest = i;
            }
        }
        return largest; // position of the first largest found
    }

    @Override
    public int compareTo(GAMLPAgent t) {
        if (this.getMask().getScore() < t.getMask().getScore()) {
            return -1;
        } else if (t.getMask().getScore() < this.getMask().getScore()) {
            return 1;
        }
        return 0;
    }

    public void debugLocation() {
        Rectangle head = mask.getSnakeParts().get(0);
        Rectangle food = mask.getFood();
        System.out.println(head.getX() + " " + head.getY() + " " + food.getX() + " " + food.getY());
        System.out.println(mask.getName() + ": " + Arrays.toString(mlp.getOutput()));
    }

    public void debugInput() {
        String s = "";
        for (int i = 0; i < mlp.getInputNeurons().size(); i++) {
            s += mlp.getInputNeurons().get(i).getOutput() + " ";
        }
        System.out.println(s);
    }

    public double[] getOutput() {
        return mlp.getOutput();
    }
}

Here is the main class of my code, GeneticSnake2.java, which contains the game loop and is where I assign genes to the child snakes (I know it could be done more cleanly).

package main;

import agents.GAMLPAgent;
import ui.InfoBar;
import graphics.Snake;
import graphics.SnakeGrid;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Random;
import java.util.Scanner;
import javafx.animation.AnimationTimer;
import javafx.application.Application;
import static javafx.application.Application.launch;
import javafx.scene.Scene;
import javafx.scene.control.Slider;
import javafx.scene.layout.Pane;
import javafx.scene.paint.Color;
import javafx.stage.Stage;

/**
 *
 * @author Preston Tang
 */
public class GeneticSnake2 extends Application {

    private final int width = 45;
    private final int height = 40;

    private final double displaySize = 120;
    private final double size = 12;

    private final Color pathColor = Color.rgb(120, 120, 120);
    private final Color wallColor = Color.rgb(50, 50, 50);

    private final int initSnakeLength = 2;

    private final int populationSize = 1000;

    private int generation = 0;

    private int initSteps = 100;
    private int stepsIncrease = 50;

    private double parentPercentage = 0.2;

    private final ArrayList<Color> snakeColors = new ArrayList<>(Arrays.asList(
            Color.GREEN, Color.RED, Color.YELLOW, Color.BLUE, Color.MAGENTA,
            Color.PINK, Color.ORANGERED, Color.BLACK, Color.GOLDENROD, Color.WHITE));

    private final ArrayList<Snake> snakes = new ArrayList<>();

    private final ArrayList<GAMLPAgent> agents = new ArrayList<>();

    private long initTime = System.nanoTime();

    @Override
    public void start(Stage stage) {
        Pane root = new Pane();
        Pane graphics = new Pane();
        graphics.setPrefHeight(height * size);
        graphics.setPrefWidth(width * size);
        graphics.setTranslateX(0);
        graphics.setTranslateY(displaySize);

        Pane display = new Pane();
        display.setStyle("-fx-background-color: BLACK");
        display.setPrefHeight(displaySize);
        display.setPrefWidth(width * size);
        display.setTranslateX(0);
        display.setTranslateY(0);

        root.getChildren().add(display);

        SnakeGrid sg = new SnakeGrid(pathColor, wallColor, width, height, size);

        //Parsing "adjectives.txt" and "nouns.txt" to form possible names
        ArrayList<String> adjectives = new ArrayList<>(Arrays.asList(readFile(new File(getClass().getClassLoader().getResource("resources/adjectives.txt").getFile())).split("\n")));
        ArrayList<String> nouns = new ArrayList<>(Arrays.asList(readFile(new File(getClass().getClassLoader().getResource("resources/nouns.txt").getFile())).split("\n")));

        //Initializing the population
        for (int i = 0; i < populationSize; i++) {
            //Get random String from lists and capitalize first letter
            String adj = adjectives.get(new Random().nextInt(adjectives.size()));
            adj = adj.substring(0, 1).toUpperCase() + adj.substring(1);

            String noun = nouns.get(new Random().nextInt(nouns.size()));
            noun = noun.substring(0, 1).toUpperCase() + noun.substring(1);

            Color color = snakeColors.get(new Random().nextInt(snakeColors.size()));

            //We want to see the first snake
            if (i == 0) {
                InfoBar bar = new InfoBar();
                bar.getStatusText().setText("Status: Alive");
                bar.getStatusText().setFill(Color.GREENYELLOW);
                bar.getSizeText().setText("Population Size: " + populationSize);

                Snake snake = new Snake(bar, adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
                bar.getNameText().setText("Name: " + snake.getName());

                snakes.add(snake);
                agents.add(new GAMLPAgent(snake, width, height, size));

            } else {
                Snake snake = new Snake(adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);

                snakes.add(snake);
                agents.add(new GAMLPAgent(snake, width, height, size));
            }
        }

        //Focused on original snake
        display.getChildren().add(snakes.get(0).getInfoBar());

        graphics.getChildren().addAll(sg);

        graphics.getChildren().addAll(snakes.get(0));

        root.getChildren().add(graphics);

        //Add the speed controller (slider)
        Slider slider = new Slider(1, 10, 10);
        slider.setTranslateX(205);
        slider.setTranslateY(75);
        slider.setDisable(true);

        root.getChildren().add(slider);

        Scene scene = new Scene(root, width * size, height * size + displaySize);
        stage.setScene(scene);

        //Fixes the setResizable bug
        //
        stage.setTitle("21-GeneticSnake2 Cause the First Version Got Deleted ;-; Started on 6/8/2020");
        stage.setResizable(false);
        stage.sizeToScene();
        stage.show();

        AnimationTimer timer = new AnimationTimer() {
            private long lastUpdate = 0;

            @Override
            public void handle(long now) {
                if (now - lastUpdate >= (10 - (int) slider.getValue()) * 50_000_000) {
                    lastUpdate = now;

                    int alive = populationSize;
                    for (int i = 0; i < snakes.size(); i++) {
                        Snake snake = snakes.get(i); //Current snake

                        if (i == 0) {
                            Collections.sort(agents);
                            snake.getInfoBar().getScoreText().setText("Score: " + snake.getScore() + " (" + agents.get(agents.size() - 1).getMask().getScore() + ")");
                        }

                        if (!snake.isAlive()) {
                            alive--;

                            //Update graphics for main snake
                            if (i == 0) {
                                snake.getInfoBar().getStatusText().setText("Status: Dead");
                                snake.getInfoBar().getStatusText().setFill(Color.RED);
                                graphics.getChildren().remove(snake);
                            }

                        } else {
                            //If out of steps
                            if (snake.getSteps() <= 0) {
                                snake.setAlive(false);
                            }

                            //Bounds Detection (left right up down)
                            if (snake.getSnakeParts().get(0).getX() >= width * size
                                    || snake.getSnakeParts().get(0).getX() <= 0
                                    || snake.getSnakeParts().get(0).getY() >= height * size
                                    || snake.getSnakeParts().get(0).getY() <= 0) {
                                snake.setAlive(false);
                            }

                            //Self-Collision Detection
                            //Compare the current snake's head against its own body parts
                            //(indexing snakes.get(o) here would test a different snake on every iteration)
                            for (int o = 1; o < snake.getSnakeParts().size(); o++) {
                                if (snake.getSnakeParts().get(0).getX() == snake.getSnakeParts().get(o).getX()
                                        && snake.getSnakeParts().get(0).getY() == snake.getSnakeParts().get(o).getY()) {
                                    snake.setAlive(false);
                                }
                            }

                            int rate = (int) slider.getValue();
                            int seconds = (int) ((System.nanoTime() - initTime) * rate / 1_000_000_000);

                            agents.get(i).compute();
                            snake.manageMovement();
                            snake.setSecondsAlive(seconds);

//                            agents.get(0);
//                            System.out.println(Arrays.toString(agents.get(0).getOutput()));
//                            
//                            System.out.println("\n\n\n\n\n\n\n");
                            //Expression to calculate score
                            double exp = (snake.getConsumed() * 5 + snake.getSecondsAlive() / 2.0D);
//double exp = snake.getSteps() + (Math.pow(2, snake.getConsumed()) + Math.pow(snake.getConsumed(), 2.1) * 500)
//        - (Math.pow(snake.getConsumed(), 1.2) * Math.pow(0.25 * snake.getSteps(), 1.3));

                            snake.setScore(Math.round(exp * 100.0) / 100.0);

                            //Update graphics for main snake
                            if (i == 0) {
                                snake.getInfoBar().getTimeText().setText("Time Survived: " + snake.getSecondsAlive() + "s");
                                snake.getInfoBar().getFoodText().setText("Food Consumed: " + snake.getConsumed());
                                snake.getInfoBar().getGenerationText().setText("Generation: " + generation);
                                snake.getInfoBar().getStepsText().setText("Steps Remaining: " + snake.getSteps());
                            }
                        }
                    }

                    //Reset and breed
                    if (alive == 0) {
                        //Ascending order
                        initTime = System.nanoTime();
                        generation++;
                        graphics.getChildren().clear();
                        graphics.getChildren().addAll(sg);
                        snakes.clear();

                        //x% of snakes are parents
                        int parentNum = (int) (populationSize * parentPercentage);

                        //Faster odd number check
                        if ((parentNum & 1) != 0) {
                            //If odd make even
                            parentNum += 1;
                        }

                        for (int i = 0; i < parentNum; i += 2) {
                            //Get the 2 parents, sorted by score
                            GAMLPAgent p1 = agents.get(populationSize - (i + 2));
                            GAMLPAgent p2 = agents.get(populationSize - (i + 1));

                            //Produce the next generation
                            double[][] childGenes = p1.breed(p2, ((populationSize - parentNum) / parentNum) * 2);

                            //Debugs Genes
//                            System.out.println(Arrays
//                                    .stream(childGenes)
//                                    .map(Arrays::toString)
//                                    .collect(Collectors.joining(System.lineSeparator())));
                            //Write each child's genes into the agent that already occupies the child's slot.
                            //A copy made with new ArrayList<>(agents) shares its agent objects with the original list,
                            //so routing the genes through it would overwrite agents 0..childGenes.length - 1 for every
                            //pair and then place those same objects into many slots.
                            //The list is sorted ascending, so the child slots hold the lowest scorers and never a parent.
                            for (int o = 0; o < childGenes.length; o++) {
                                //Useful debug message
//                                System.out.println("ParentNum: " + parentNum
//                                        + " ChildPerParent: " + (populationSize - parentNum) / parentNum
//                                        + " Index: " + (o + (i / 2 * childGenes.length))
//                                        + " ChildGenesNum: " + childGenes.length
//                                        + " Var O: " + o);

                                agents.get(o + (i / 2 * childGenes.length)).getMLP().setWeights(childGenes[o]);
                            }
//                            System.out.println("\n\n\n\n\n\n");
                        }

                        //Debugging the snakes' genes to a file
//                        String str = "";
//                        for (int i = 0; i < agents.size(); i++) {
//                            str += "Index: " + i + "\t" + Arrays.toString(agents.get(i).getMLP().getWeights())+  "\n\n\n";
//                        }
//
//                        printToFile(str, "gen" + generation);

                        for (int i = 0; i < populationSize; i++) {
                            //Get random String from lists and capitalize first letter
                            String adj = adjectives.get(new Random().nextInt(adjectives.size()));
                            adj = adj.substring(0, 1).toUpperCase() + adj.substring(1);

                            String noun = nouns.get(new Random().nextInt(nouns.size()));
                            noun = noun.substring(0, 1).toUpperCase() + noun.substring(1);

                            Color color = snakeColors.get(new Random().nextInt(snakeColors.size()));

                            //We want to see the first snake
                            if (i == 0) {
                                InfoBar bar = new InfoBar();
                                bar.getStatusText().setText("Status: Alive");
                                bar.getStatusText().setFill(Color.GREENYELLOW);
                                bar.getSizeText().setText("Population Size: " + populationSize);

                                Snake snake = new Snake(bar, adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
                                bar.getNameText().setText("Name: " + snake.getName());
                                snakes.add(snake);
                                agents.get(i).setMask(snake);
                            } else {
                                Snake snake = new Snake(adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
                                snakes.add(snake);
                                agents.get(i).setMask(snake);
                            }
                        }

                        graphics.getChildren().add(snakes.get(0));
                        display.getChildren().clear();

                        //Focused on original snake at first
                        display.getChildren().add(snakes.get(0).getInfoBar());
                    }
                }
            }
        };
        //Starts the infinite loop
        timer.start();
    }

    public String readFile(File f) {
        String content = "";
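        //try-with-resources closes the Scanner; "\\Z" (end of input) reads the whole file as one token
        try (Scanner scanner = new Scanner(f)) {
            content = scanner.useDelimiter("\\Z").next();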
        } catch (FileNotFoundException ex) {
            System.err.println("Error: Unable to read " + f.getName());
        }
        return content;
    }

    public void printToFile(String str, String name) {
        FileWriter fileWriter;
        try {
            fileWriter = new FileWriter(name + ".txt");
            try (BufferedWriter bufferedWriter = new BufferedWriter(fileWriter)) {
                bufferedWriter.write(str);
            }

        } catch (IOException ex) {
            ex.printStackTrace();
        }
    }

    public static void main(String[] args) {
        launch(args);
    }
}

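The main problem is that, even after thousands of generations, the snakes still simply suicide into the walls. In the videos I linked above, the snakes were dodging walls and getting food by around generation 5. I suspect the problem lies in the main class, where I assign genes to snakes that have already been born.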
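One thing worth checking when genes are assigned to agents that already exist is whether several list slots end up pointing at the same object. The sketch below is not the project's code: `Agent` is a made-up stand-in with a bare weight array. It only illustrates that copying a list with `new ArrayList<>(list)` copies references, not the objects themselves.

```java
import java.util.ArrayList;
import java.util.List;

/** Sketch: a copied list still shares its elements with the original (hypothetical Agent class). */
class ShallowCopySketch {

    /** Stand-in for an agent that owns a weight array. */
    static class Agent {
        double[] weights;

        Agent(double... weights) {
            this.weights = weights;
        }
    }

    static List<Agent> demo() {
        List<Agent> agents = new ArrayList<>();
        agents.add(new Agent(0.1, 0.2));
        agents.add(new Agent(0.3, 0.4));

        // Copies the references only: both lists point at the same Agent objects
        List<Agent> temp = new ArrayList<>(agents);
        temp.get(0).weights = new double[] {9.0, 9.0};

        // Storing temp.get(0) in another slot makes two slots share one agent
        agents.set(1, temp.get(0));
        return agents;
    }

    public static void main(String[] args) {
        List<Agent> agents = demo();
        System.out.println(agents.get(0).weights[0]);       // 9.0: the "original" changed too
        System.out.println(agents.get(0) == agents.get(1)); // true: one object, two slots
    }
}
```

If a population shows this pattern, the number of distinct agents shrinks each generation even though the list keeps its size.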

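I've actually been stuck on this for a few weeks now. Previously, one of the problems I suspected was a lack of inputs, since I had far fewer of them back then. But now I don't think that's the case anymore. If needed, I can try looking in the 4 diagonal directions as well, which would add another 12 inputs to the snake's MLP. I've also asked for help on the Artificial Intelligence Discord, but I haven't found a solution.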

If needed, I am willing to send my entire code so you can run the simulation yourself.

If you've read this far, thank you for taking the time to help me! I really appreciate it.

I'm not surprised that your snakes are dying.

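Let's take a step back. What is AI, exactly? Well, it's a search problem. We're searching through some parameter space to find the set of parameters that solves Snake given the current state of the game. You can imagine a parameter space that has a global minimum: the best possible snake, the snake that makes the fewest mistakes.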

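All learning algorithms start at some point in this parameter space and attempt to find that global minimum over time. First, let's think about MLPs. MLPs learn by trying a set of weights, computing a loss function, and then taking a step in the direction that further reduces the loss (gradient descent). It's fairly obvious that an MLP will find a minimum, but whether it finds a good enough minimum is an open question, and there are a lot of training techniques that exist to improve that chance.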
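To make "taking a step in the direction that reduces the loss" concrete, here is a minimal sketch on a one-parameter toy loss. The loss function, the learning rate, and all names are assumptions chosen for illustration; a real MLP does the same thing over thousands of weights using backpropagation.

```java
/** Sketch of plain gradient descent on a one-parameter loss (illustrative only). */
class GradientDescentSketch {

    /** Toy loss with its minimum at w = 3: L(w) = (w - 3)^2. */
    static double loss(double w) {
        return (w - 3.0) * (w - 3.0);
    }

    /** Derivative of the toy loss: dL/dw = 2 (w - 3). */
    static double gradient(double w) {
        return 2.0 * (w - 3.0);
    }

    /** Repeatedly steps against the gradient, starting from w0. */
    static double descend(double w0, double learningRate, int steps) {
        double w = w0;
        for (int i = 0; i < steps; i++) {
            w -= learningRate * gradient(w);
        }
        return w;
    }

    public static void main(String[] args) {
        double w = descend(-5.0, 0.1, 100);
        System.out.println("w = " + w + ", loss = " + loss(w));
    }
}
```

Every step has a measurable effect on a defined quantity (the loss), which is the property the crossover step of a GA lacks.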

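Genetic algorithms, on the other hand, have very poor convergence characteristics. First, let's stop calling them genetic algorithms. Let's call them smorgasbord algorithms instead. A smorgasbord algorithm takes two sets of parameters from two parents, jumbles them, and yields a new smorgasbord. What makes you think the result will be better than either of the two? What are you minimizing here? How do you know it's approaching anything better? If you attach a loss function, how do you know you're in a space that can actually be minimized?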

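The point I'm trying to make is that genetic algorithms are unprincipled, unlike nature. Nature does not just put codons in a blender to make a new strand of DNA, but that is exactly what genetic algorithms do. There are techniques that add some hill climbing, but genetic algorithms still have tons of problems.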
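For what "adding some hill climbing" can mean in practice, here is a minimal sketch of a mutate-and-keep-if-better loop. It is one simple variant, not a standard GA component; the toy fitness function and all names are assumptions made for the demo.

```java
import java.util.Random;
import java.util.function.ToDoubleFunction;

/** Sketch of a mutate-and-keep-if-better loop (simple stochastic hill climbing). */
class HillClimbSketch {

    /** Returns the best candidate found; fitness is "higher is better". */
    static double[] climb(double[] start, ToDoubleFunction<double[]> fitness,
                          int iterations, double stepSize, Random rng) {
        double[] best = start.clone();
        double bestFitness = fitness.applyAsDouble(best);
        for (int i = 0; i < iterations; i++) {
            // Perturb one randomly chosen gene
            double[] candidate = best.clone();
            int g = rng.nextInt(candidate.length);
            candidate[g] += stepSize * rng.nextGaussian();
            double candidateFitness = fitness.applyAsDouble(candidate);
            // Keep the change only when it strictly improves fitness
            if (candidateFitness > bestFitness) {
                best = candidate;
                bestFitness = candidateFitness;
            }
        }
        return best;
    }

    /** Toy fitness (an assumption for the demo): negative squared distance to (1, -2). */
    static double toyFitness(double[] x) {
        double dx = x[0] - 1.0;
        double dy = x[1] + 2.0;
        return -(dx * dx + dy * dy);
    }

    public static void main(String[] args) {
        double[] result = climb(new double[] {0.0, 0.0}, HillClimbSketch::toyFitness,
                2000, 0.1, new Random(7));
        System.out.println(java.util.Arrays.toString(result));
    }
}
```

Because a candidate is accepted only when it scores higher, fitness can never decrease, but the loop can still stall in a local optimum.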

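The point is, don't get swept up in the name. Genetic algorithms are simply smorgasbord algorithms. My view is that your approach isn't working because GAs have no guarantee of converging, even after infinitely many iterations, and MLPs have no guarantee of converging to a good global minimum.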

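What to do? Well, a better approach is to use a learning paradigm that fits your problem, and here that paradigm is reinforcement learning. There's a very good course on Udacity from Georgia Tech on the subject.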