A* 寻路算法正在 semi-greedy

A* pathfinding algorithm is being semi-greedy

所以我一直在尝试为 Java 中的 2D tilemap 实现 A* 寻路算法,来自这个视频:https://www.youtube.com/watch?v=-L-WgKMFuhE。我尝试遵循伪代码,但我觉得没有足够的细节得到很好的解释,尤其是 G 成本的想法。因此,我决定深入研究视频创作者编写的实际代码,并使用了他的很多想法和结构,因为我花了很多时间查看多个伪代码,但没有得到想要的结果。

我完全理解代码并且按部就班地进行是有道理的,但由于某种原因,算法有时无法生成最短路径。其实就是贪心,想尽早到达终点。这里有两个例子(绿点是起点,红点是终点。灰色方块代表路径,黑色方块代表墙壁):

我反复看了作者的代码和我自己的,就是找不到问题。

这是我的寻路方法:

Node start, end; //these are initialized through a GUI

List<Node> open = new ArrayList<Node>();
Set<Node> closed = new HashSet<Node>();
    
boolean pathExists = true;

public void findPath() {
        open.add(start);

        while(!open.isEmpty()) {
            Map<Node, Double> costs = new HashMap<Node, Double>();
            for(Node node: open) {
                costs.put(node, node.cost());
            }
            Node current = Collections.min(costs.entrySet(), Map.Entry.comparingByValue()).getKey();
            
            open.remove(current);
            closed.add(current);
            
            if(current.equals(end)) {
                end = current;
                return;
            }
            
            for(Node adjacent: current.getAdjacentNodes()) {
                if(contains(closed, adjacent) || !inBounds(adjacent.getPoint()))
                    continue;
                
                double newG = current.g + adjacent.getDistanceTo(current);
                if(newG < adjacent.g || !contains(open, adjacent)) {
                    adjacent.g = newG;
                    adjacent.h = adjacent.getDistanceTo(end);
                    adjacent.parent = current;
                    
                    if(!contains(open, adjacent))
                        open.add(adjacent);
                }
            }
        }
        pathExists = false;
    }

这是我的节点 class:

private class Node {
        public Point point;
        public Node parent;
        
        public double g, h;

        public Node(Point point) {
            this.point = point;
        }
        
        public Node getParent() {
            return parent;
        }
        
        public double cost() {
            return g + h;
        }
        
        public double getDistanceTo(Node node) {
            int x = Math.abs((point.x - node.point.x)/cellWidth); //cell width and height are the dimensions of each tile
            int y = Math.abs((point.y - node.point.y)/cellHeight);
            
            if(x > y)
                return 14*y + 10*(x - y);
            return 14*x + 10*(y - x);
        }
        
        public List<Node> getAdjacentNodes() {
            return List.of(
                    new Node(new Point(point.x + cellWidth, point.y + cellHeight)),
                    new Node(new Point(point.x + cellWidth, point.y)),
                    new Node(new Point(point.x + cellWidth, point.y - cellHeight)),
                    new Node(new Point(point.x, point.y - cellHeight)),
                    new Node(new Point(point.x - cellWidth, point.y - cellHeight)),
                    new Node(new Point(point.x - cellWidth, point.y)),
                    new Node(new Point(point.x - cellWidth, point.y + cellHeight)),
                    new Node(new Point(point.x, point.y + cellHeight))
                    );
        }
        
        public Point getPoint() {
            return point;
        }
        
        public boolean equals(Node node) {
            return point.x == node.point.x && point.y == node.point.y;
        }
        
        public String toString() {
            return point.toString();
        }
    }

我还使用 Java swing 设置了一个 GUI。代码很乱,但真正重要的是寻路方法和节点 class 在一起。 如果您希望尝试调整这两个部分并验证您的结果,我将在此处留下 GUI 代码。 (注意:使用开始单选按钮选择一个开始图块并单击图块,与结束图块相同。在图块上 Right-click 使其成为墙。然后单击“查找路径”按钮。没有用户错误处理。)

import javax.swing.*;

import java.awt.Color;
import java.awt.Graphics;
import java.awt.event.MouseListener;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.awt.event.MouseEvent;
import java.awt.Point;

public class Test extends JPanel {
    
    private final int width = 600, height = 450;
    private final int cellWidth = 25, cellHeight = 25;
    
    Map<Point, Color> tiles = new HashMap<Point, Color>();
    
    Node start, end;
    
    public Test() {
        JFrame frame = new JFrame();
        frame.setSize(800, 450);
        frame.setLocationRelativeTo(null);
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.setResizable(false);
        
        ButtonGroup group = new ButtonGroup();
        
        JRadioButton startSelect = new JRadioButton("Start");
        JRadioButton endSelect = new JRadioButton("End");
        JRadioButton coord = new JRadioButton("Coord");
        
        startSelect.setSelected(true);
        
        group.add(startSelect);
        group.add(endSelect);
        group.add(coord);
        
        JButton find = new JButton("Find Path");
        find.addActionListener(event -> {
            findPath();
            drawPath();
            repaint();
        });
        
        SpringLayout layout = new SpringLayout();
        this.setLayout(layout);
        layout.putConstraint(SpringLayout.EAST, startSelect, -130, SpringLayout.EAST, this);
        layout.putConstraint(SpringLayout.NORTH, startSelect, 30, SpringLayout.NORTH, this);
        this.add(startSelect);
        layout.putConstraint(SpringLayout.WEST, endSelect, 10, SpringLayout.EAST, startSelect);
        layout.putConstraint(SpringLayout.NORTH, endSelect, 0, SpringLayout.NORTH, startSelect);
        this.add(endSelect);
        layout.putConstraint(SpringLayout.WEST, coord, 10, SpringLayout.EAST, endSelect);
        layout.putConstraint(SpringLayout.NORTH, coord, 0, SpringLayout.NORTH, endSelect);
        this.add(coord);
        layout.putConstraint(SpringLayout.EAST, find, -60, SpringLayout.EAST, this);
        layout.putConstraint(SpringLayout.NORTH, find, 120, SpringLayout.NORTH, this);
        this.add(find);
        
        this.addMouseListener(new MouseListener() {
            @Override
            public void mousePressed(MouseEvent e) {
                double x = e.getX();
                double y = e.getY();
                
                Node tile = new Node(new Point(cellWidth*((int) x/cellWidth), cellHeight*((int) y/cellHeight)));
                
                switch(e.getButton()) {
                case MouseEvent.BUTTON1:
                    if(startSelect.isSelected()) {
                        start = tile;
                        tiles.put(tile.getPoint(), Color.GRAY);
                    }
                    else if(endSelect.isSelected()) {
                        end = tile;
                        tiles.put(tile.getPoint(), Color.GRAY);
                    }
                    else if(coord.isSelected()) {
                        System.out.println(tile.getPoint());
                    }
                    break;
                case MouseEvent.BUTTON3:
                    tiles.put(tile.getPoint(), Color.BLACK);
                    closed.add(tile);
                    break;
                }
                repaint();
            }
            
            @Override
            public void mouseClicked(MouseEvent e) {

            }
            
            @Override
            public void mouseReleased(MouseEvent e) {

            }

            @Override
            public void mouseEntered(MouseEvent e) {
                
            }

            @Override
            public void mouseExited(MouseEvent e) {
                
            }
        });
        
        this.setPreferredSize(frame.getSize());
        frame.add(this);
        frame.pack();
        
        for(int i=0; i<width; i+=cellWidth) {
            for(int j=0; j<height; j+=cellHeight) {
                tiles.put(new Point(i, j), Color.WHITE);
            }
        }
        
        frame.setVisible(true);
    }
    
    public boolean inBounds(Point point) {
        return (point.x >= 0 && point.y >= 0) && (point.x <= width && point.y <= height);
    }
    
    public boolean contains(List<Node> list, Node node) {
        for(Node n: list) {
            if(n.equals(node))
                return true;
        }
        return false;
    }
    
    public boolean contains(Set<Node> list, Node node) {
        for(Node n: list) {
            if(n.equals(node))
                return true;
        }
        return false;
    }
    
    Set<Node> open = new HashSet<Node>();
    Set<Node> closed = new HashSet<Node>();
    
    boolean pathExists = true;
    
    /*
     * 
     * 
     * 
     * PATHFINDING METHOD*/
    public void findPath() {
        open.add(start);

        while(!open.isEmpty()) {
            Map<Node, Double> costs = new HashMap<Node, Double>();
            for(Node node: open) {
                costs.put(node, node.cost());
            }
            Node current = Collections.min(costs.entrySet(), Map.Entry.comparingByValue()).getKey();
            
            open.remove(current);
            closed.add(current);
            
            if(current.equals(end)) {
                end = current;
                return;
            }
            
            for(Node adjacent: current.getAdjacentNodes()) {
                if(contains(closed, adjacent) || !inBounds(adjacent.getPoint()))
                    continue;
                
                double newG = current.g + adjacent.getDistanceTo(current);
                if(newG < adjacent.g || !contains(open, adjacent)) {
                    adjacent.g = newG;
                    adjacent.h = adjacent.getDistanceTo(end);
                    adjacent.parent = current;
                    
                    if(!contains(open, adjacent))
                        open.add(adjacent);
                }
            }
        }
        pathExists = false;
    }
    
    public void drawPath() {
        if(pathExists) {
            Node current = end;
            
            while(!current.equals(start)) {
                tiles.put(current.getParent().getPoint(), Color.GRAY);
                current = current.getParent();
            }
        }
    }
    
    /*
     * 
     * 
     * 
     * 
     * NODE CLASS*/
    private class Node {
        public Point point;
        public Node parent;
        
        public double g, h;

        public Node(Point point) {
            this.point = point;
        }
        
        public Node getParent() {
            return parent;
        }
        
        public double cost() {
            return g + h;
        }
        
        /*public double getH() {
            return 10*Math.sqrt(Math.pow((end.getPoint().x - point.x)/cellWidth, 2) + Math.pow((end.getPoint().y - point.y)/cellHeight, 2));
        }*/
        
        public double getDistanceTo(Node node) {
            int x = Math.abs((point.x - node.point.x)/cellWidth);
            int y = Math.abs((point.y - node.point.y)/cellHeight);
            
            if(x > y)
                return 14*y + 10*(x - y);
            return 14*x + 10*(y - x);       
        }
        
        /*public double getPositionTo(Node node) {
            if((point.y == node.point.y && (point.x > node.point.x || point.x < node.point.x)) ||
                    (point.x == node.point.x && (point.y < node.point.y || point.y > node.point.y)))
                return 10;
            else
                return 14;
        }*/
        
        public List<Node> getAdjacentNodes() {
            return List.of(
                    new Node(new Point(point.x + cellWidth, point.y + cellHeight)),
                    new Node(new Point(point.x + cellWidth, point.y)),
                    new Node(new Point(point.x + cellWidth, point.y - cellHeight)),
                    new Node(new Point(point.x, point.y - cellHeight)),
                    new Node(new Point(point.x - cellWidth, point.y - cellHeight)),
                    new Node(new Point(point.x - cellWidth, point.y)),
                    new Node(new Point(point.x - cellWidth, point.y + cellHeight)),
                    new Node(new Point(point.x, point.y + cellHeight))
                    );
        }
        
        public Point getPoint() {
            return point;
        }
        
        public boolean equals(Node node) {
            return point.x == node.point.x && point.y == node.point.y;
        }
        
        public String toString() {
            return point.toString();
        }
    }
    
    public void paintComponent(Graphics tool) {
        super.paintComponent(tool);
        
        for(Map.Entry<Point, Color> tile: tiles.entrySet()) {
            tool.setColor(tile.getValue());
            tool.fillRect(tile.getKey().x, tile.getKey().y, cellWidth, cellHeight);
        }
        tool.setColor(Color.BLACK);
        for(int i=0; i<width; i+=cellWidth) {
            tool.drawLine(i, 0, i, height);
        }
        for(int i=0; i<height; i+=cellHeight) {
            tool.drawLine(0, i, width, i);
        }
        
        if(start != null) {
            tool.setColor(Color.GREEN);
            tool.fillOval(start.point.x + 8, start.point.y + 8, 10, 10);
        }
        if(end != null) {
            tool.setColor(Color.RED);
            tool.fillOval(end.point.x + 8, end.point.y + 8, 10, 10);
        }
    }
    
    public static void main(String[] args) {
        new Test();
    }
}

编辑 我为开集和闭集写了自己的包含方法,只是为了确保我没有弄乱那里的东西:

public boolean contains(Set<Node> list, Node node) {
        for(Node n: list) {
            if(n.equals(node))
                return true;
        }
        return false;
    }

感谢@thatotherguy 的解决方案

如果它对以后的任何人有帮助,问题是每当我为当前节点获取相邻节点时,我正在创建全新的节点,所有节点的 g 值为 0。

为了纠正这个问题,我首先检查了开放列表中是否存在相邻节点。如果是这样,获取该节点而不是创建新节点。如果该节点不存在于开放集中,则创建一个具有正确 g 值的新节点。

public List<Node> getAdjacentNodes() {
            List<Node> nodes = new LinkedList<Node>();
            nodes.add(new Node(new Point(point.x + cellWidth, point.y + cellHeight)));
            nodes.add(new Node(new Point(point.x + cellWidth, point.y)));
            nodes.add(new Node(new Point(point.x + cellWidth, point.y - cellHeight)));
            nodes.add(new Node(new Point(point.x, point.y - cellHeight)));
            nodes.add(new Node(new Point(point.x - cellWidth, point.y - cellHeight)));
            nodes.add(new Node(new Point(point.x - cellWidth, point.y)));
            nodes.add(new Node(new Point(point.x - cellWidth, point.y + cellHeight)));
            nodes.add(new Node(new Point(point.x, point.y + cellHeight)));
            
            List<Node> correctNodes = new ArrayList<Node>();
            List<Node> remove = new ArrayList<Node>();
            
            for(Node node: nodes) {
                for(Node openNode: open) {
                    if(node.equals(openNode)) {
                        correctNodes.add(openNode);
                        remove.add(node);
                        break;
                    }
                }
            }
            
            nodes.removeAll(remove);
            for(Node node: nodes) {
                node.g = this.g + getDistanceTo(node);
                correctNodes.add(node);
            }
            return correctNodes;
        }