关于 leetcode 1091 二进制矩阵中的最短路径的问题

Question with regard leetcode 1091 Shortest Path in Binary Matrix

我正在尝试使用启发式算法和优先级队列来解决 leetcode 1091 二进制矩阵中的最短路径。但是,我无法通过所有测试。你知道我代码中的错误吗?

例如,输入是[[0,0,0],[1,1,0],[1,1,0]],输出应该是4。但是,我的代码得到了输出5. 我使用的启发式是当前节点到目标节点之间的直接距离。

class Solution {
    
    public int shortestPathBinaryMatrix(int[][] grid) {
        
        int side_length = grid.length;
        
        // if the s left top corner is 1 then, no path exist and return -1
        if(grid[0][0]== 1 || grid[side_length - 1][side_length - 1]== 1)
        {
            return -1;
        }
        
        if(side_length == 1)
        {
            return 1;
        }
        
        
        // 2D array for 8 directions
        int[][] directions = new int[][]{{1,0},{-1,0},{0,1},{0,-1},{-1,-1},{-1,1},{1,-1},{1,1}};
        
        
        PriorityQueue<Node> pqueue = new PriorityQueue<Node>(10,  new Comparator<Node>()
       {
            public int compare(Node i, Node j) {
                if(Double.compare(i.heuristic, j.heuristic) < 0){
                    return 100;
                }
                else
                {
                    return -100;
                }
            }
       });

        double heuristic = e_distance(0, 0, side_length - 1, side_length - 1);
        
        Node start_point = new Node(0, 0, heuristic);
        
        pqueue.add(start_point);
        
        
        boolean explored[][] = new boolean[side_length][side_length];
        
        explored[0][0] = true;
        
        int output = 1;
        
        while(!pqueue.isEmpty())
        {

            Node curr_point = pqueue.poll();
        
            
            int x = curr_point.x;
            int y = curr_point.y;
            
            explored[x][y] = true;
            
            if(x == side_length - 1 && y == side_length - 1)
            {
                    return output;
            }
            


            for(int[] successor : directions)
            {
                
                int successor_x = x + successor[0];
                int successor_y = y +  + successor[1];
                heuristic = e_distance(successor_x, successor_y, side_length - 1, side_length - 1);
                
                Node successor_point = new Node(successor_x, successor_y, heuristic);
                
                if (pqueue.contains(successor_point))
                {
                    continue;
                }
                
                
                if(successor_x >= 0 && successor_x < side_length && successor_y >= 0 
                   && successor_y < side_length && grid[successor_x][successor_y] == 0 
                   && !explored[successor_x][successor_y])
                {
                    if(successor_x == side_length - 1 && successor_y == side_length - 1)
                    {
                        return output + 1;
                    }
                    
                    pqueue.add(successor_point);
                }

                else
                {
                    continue;
                }

            }
            
            output++;

        }

        return -1;
    }

    
    public double e_distance(int x, int y, int target_x, int target_y)
    {
        return Math.sqrt(Math.abs(target_x - x) * Math.abs(target_x - x) + Math.abs(target_y - y)* Math.abs(target_y - y));
    }
    

}
public class Node{
        
        public int x;
        public int y;
        public double heuristic;
    
    
        public Node(int x, int y, double heuristic)
        {
            this.x = x;
            this.y = y;
            this.heuristic = heuristic;
        }
        
    
    }

以下是基于您的代码的 BFS 解决方案。它正在工作,尽管它可能需要进一步调试:

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

public class Main   {

    public static void main(String[] args) {

        List<int[][]>grids = new ArrayList<>();
        grids.add(  new int[][] {{0,1},{1,0}} );//2
        grids.add(  new int[][]{{0,0,0},{1,1,0},{1,1,0}} ); //4
        grids.add( new int[][] {{1,0,0},{1,1,0},{1,1,0}} );//-1

        Solution s = new Solution();
        for (int[][]grid : grids) {
            System.out.println(s.shortestPathBinaryMatrix(grid));
        }
    }
}

class Solution {
    // 2D array for 8 directions
    public static int[][] DIRECTIONS = new int[][]{{1,0},{-1,0},{0,1},{0,-1},{-1,-1},{-1,1},{1,-1},{1,1}};

    public int shortestPathBinaryMatrix(int[][] grid) {

        int side_length = grid.length;

        // if the s left top corner is 1 then, no path exist and return -1
        if(grid[0][0]== 1 || grid[side_length - 1][side_length - 1]== 1)    return -1;

        if(side_length == 1)    return 1;

        Queue<Node> pqueue =  new LinkedList<>();

        Node start_point = new Node(0, 0);

        pqueue.add(start_point);

        boolean explored[][] = new boolean[side_length][side_length];//you can use grid values to mark explored
        //int output = 1; use Node.parent to mark the path

        while(!pqueue.isEmpty()){

            Node curr_point = pqueue.poll();

            int x = curr_point.x;
            int y = curr_point.y;

            explored[x][y] = true;

            if(x == side_length - 1 && y == side_length - 1)    return pathLength(curr_point);

            for(int[] successor :  DIRECTIONS)  {

                int successor_x = x + successor[0];
                int successor_y = y +  + successor[1];

                Node successor_point = new Node(successor_x, successor_y);

                if (pqueue.contains(successor_point))
                {
                    continue;
                }

                if(successor_x >= 0 && successor_x < side_length && successor_y >= 0
                        && successor_y < side_length
                        && grid[successor_y][successor_x] == 0 //NOT grid[successor_x][successor_y] == 0
                        && !explored[successor_x][successor_y])
                {
                    //if(successor_x == side_length - 1 && successor_y == side_length - 1)
                    //  return output + 1;

                    explored[successor_x][successor_y] = true; //mark as explored
                    successor_point.setParent(curr_point);     //mark as child of current node
                    pqueue.add(successor_point);
                }

                else //this else does nothing
                {
                    continue;
                }
            }
        }
        return -1;
    }

    private int pathLength(Node node) {
        if(node == null) return 0;
        int pathLength = 1;
        while (node.getParent() !=null){
            node = node.getParent();
            pathLength++;
        }
        return pathLength;
    }
}

class Node{

    public int x, y;
    public double cost;
    public Node parent = null;

    public Node(int x, int y){
        this(x, y, 0);
    }

    public Node(int x, int y, double cost)
    {
        this.x = x;  this.y = y;
        this.cost = cost;
    }

    public Node getParent() {
        return parent;
    }

    public void setParent(Node parent) {
        this.parent = parent;
    }

    //todo implement equals and hashCode 
}

shortestPathBinaryMatrix 的更好实现:

public int shortestPathBinaryMatrix(int[][] grid) {

    int side_length = grid.length;

    // if the s left top corner is 1 then, no path exist and return -1
    if(grid[0][0]== 1 || grid[side_length - 1][side_length - 1]== 1)    return -1;

    if(side_length == 1)    return 1;

    Queue<Node> queue =  new LinkedList<>();
    queue.add(new Node(0, 0));

    while(!queue.isEmpty()){

        Node curr_point = queue.poll();
        int x = curr_point.x;   int y = curr_point.y;

        if(x == side_length - 1 && y == side_length - 1) return pathLength(curr_point);
        grid[y][x] = 1;

        for(int[] successor :  DIRECTIONS)  {

            int successor_x = x + successor[0];
            int successor_y = y +  + successor[1];

            if(successor_x >= 0 && successor_x < side_length && successor_y >= 0
                    && successor_y < side_length
                    && grid[successor_y][successor_x] == 0) {

                Node successor_point = new Node(successor_x, successor_y);

                if (queue.contains(successor_point)){
                    continue;
                }

                grid[successor_y][successor_x] = 1; //mark as explored
                successor_point.setParent(curr_point);     //mark as child of current node
                queue.add(successor_point);
            }
        }
    }
    return -1;
}

当不同的边具有不同的成本(加权图)时,您可能需要实施 Dijkstra 算法。 Dijkstra 算法基本上是一种增强的 BFS。这是需要成本和 PriorityQueue 的地方。
shortestPathBinaryMatrix 变为:

    //Dijkstra's Algorithm
    public int shortestPathBinaryMatrix(int[][] grid) {

        int side_length = grid.length;

        // if the s left top corner is 1 then, no path exist and return -1
        if(grid[0][0]== 1 || grid[side_length - 1][side_length - 1]== 1)    return -1;

        if(side_length == 1)    return 1;

        PriorityQueue<Node> queue = new PriorityQueue<>(10,(i,j)-> Double.compare(i.cost, j.cost));
        queue.add(new Node(0, 0, 0));

        while(!queue.isEmpty()){

            Node curr_point = queue.poll();
            int x = curr_point.x;   int y = curr_point.y;

            if(x == side_length - 1 && y == side_length - 1) return pathLength(curr_point);

            grid[y][x] = 1;
            for(int[] successor :  DIRECTIONS)  {

                int successor_x = x + successor[0];
                int successor_y = y +  + successor[1];

                if(successor_x >= 0 && successor_x < side_length
                        && successor_y >= 0 && successor_y < side_length
                        && grid[successor_y][successor_x] == 0) {

                    double cost = curr_point.cost+1;
                    Node successor_point = new Node(successor_x, successor_y, cost);
                    if (queue.contains(successor_point)) {
                        continue;
                    }

                    grid[successor_y][successor_x] = 1; //mark as explored
                    successor_point.setParent(curr_point);     //mark as child of current node
                    queue.add(successor_point);
                }
            }
        }
        return -1;
    }

实现 A* 算法时需要启发式算法。 A*算法基本上是一种增强的Dijkstra算法。
要实现它,您只需将成本计算修改为:
double cost = curr_point.cost+1 + heuristic ; 所以shortestPathBinaryMatrix 变成:

    //A* algorithm
    public int shortestPathBinaryMatrix(int[][] grid) {

        int side_length = grid.length;

        // if the s left top corner is 1 then, no path exist and return -1
        if(grid[0][0]== 1 || grid[side_length - 1][side_length - 1]== 1)    return -1;

        if(side_length == 1)    return 1;

        PriorityQueue<Node> queue = new PriorityQueue<>(10,(i,j)-> Double.compare(i.cost, j.cost));
        queue.add(new Node(0, 0, 0));

        while(!queue.isEmpty()){

            Node curr_point = queue.poll();
            int x = curr_point.x;   int y = curr_point.y;

            if(x == side_length - 1 && y == side_length - 1) return pathLength(curr_point);

            grid[y][x] = 1;
            for(int[] successor :  DIRECTIONS)  {

                int successor_x = x + successor[0];
                int successor_y = y +  + successor[1];

                if(successor_x >= 0 && successor_x < side_length
                        && successor_y >= 0 && successor_y < side_length
                        && grid[successor_y][successor_x] == 0) {

                    double cost = curr_point.cost+1 + distance(successor_x, successor_y, x, y);
                    Node successor_point = new Node(successor_x, successor_y, cost);
                    if (queue.contains(successor_point)) {
                        continue;
                    }

                    grid[successor_y][successor_x] = 1; //mark as explored
                    successor_point.setParent(curr_point);     //mark as child of current node
                    queue.add(successor_point);
                }
            }
        }
        return -1;
    }

distance定义为:

    public double distance(int x, int y, int targetX, int targetY)   {
       return Math.sqrt(Math.pow(targetX - x,2) + Math.pow(targetY - y,2));
    }