Java:使用 Node 类实现一致代价搜索(Uniform Cost Search)

Java: Uniform Cost Search with Node class

下面的代码应该读取一幅图像,创建一个包含该图像像素值的二维数组,并求出图像内从点 A 到点 B 的代价最低的路径。但运行代码时出现了以下错误:

Exception in thread "main" java.lang.ClassCastException: UniformCostSearch$Node
cannot be cast to java.lang.Comparable
        at java.util.PriorityQueue.siftUpComparable(Unknown Source)
        at java.util.PriorityQueue.siftUp(Unknown Source)
        at java.util.PriorityQueue.offer(Unknown Source)
        at java.util.PriorityQueue.add(Unknown Source)
        at UniformCostSearch.uCostSearch(UniformCostSearch.java:159)
        at UniformCostSearch.main(UniformCostSearch.java:216)

为什么会发生这种情况?我该如何解决?我看过其他类似问题的解答,但我认为我遇到的是不同的问题。

public class UniformCostSearch
{   
    //------------------------------------
    // Node Class
    //------------------------------------
    public static class Node implements Comparator<Node>
    {
        public double cost;
        Node parent;
        public int x; 
        public int y;

        public Node(){}

        public Node(double cost, Node parent, int x, int y)
        {
            this.cost = cost;
            this.parent = parent;
            this.x = x;
            this.y = y;
        }

        @Override
        public int compare(Node node1, Node node2)
        {
            if(node1.cost < node2.cost)
                return -1;
            if(node1.cost > node2.cost)
                return 1;
            return 0;
        }

        @Override
        public boolean equals(Object obj)
        {
            if(obj instanceof Node)
            {
                Node node = (Node) obj;
                if (this.parent == node.parent)
                    return true;
            }
            return false;
        }
    }

    public static void uCostSearch(Node startState, Node endNode, 
        BufferedImage img){

        int adjacencyMatrix[][] = null;
        PriorityQueue<Node> pq = new PriorityQueue<Node>();
        startState.cost = 0.0;
        startState.parent = null;
        int inc = 0;

        pq.add(startState);

        Set<Point> visited = new HashSet<Point>();
        Point startPoint = new Point(startState.x, startState.y);
        visited.add(startPoint);

        byte[] terrain =
            ((DataBufferByte)img.getRaster().getDataBuffer()).getData();
        int width = img.getWidth();
        int height = img.getHeight();
        final boolean hasAlphaChannel = img.getAlphaRaster() != null;
        int channel_count = 4;
        int green_channel = 2; // 0=alpha, 1=blue, 2=green, 3=red

        int[][] result = new int[height][width];

// Source: 
         if (hasAlphaChannel) {
         final int pixelLength = 4;
         for (int pixel = 0, row = 0, col = 0; pixel < terrain.length; pixel += pixelLength) {
            int argb = 0;
            argb += (((int) terrain[pixel] & 0xff) << 24); // alpha
            argb += ((int) terrain[pixel + 1] & 0xff); // blue
            argb += (((int) terrain[pixel + 2] & 0xff) << 8); // green
            argb += (((int) terrain[pixel + 3] & 0xff) << 16); // red
            result[row][col] = argb;
            col++;
            if (col == width) {
               col = 0;
               row++;
            } 
         } 
      } else { 
         final int pixelLength = 3;
         for (int pixel = 0, row = 0, col = 0; pixel < terrain.length; pixel += pixelLength) {
            int argb = 0;
            argb += -16777216; // 255 alpha
            argb += ((int) terrain[pixel] & 0xff); // blue
            argb += (((int) terrain[pixel + 1] & 0xff) << 8); // green
            argb += (((int) terrain[pixel + 2] & 0xff) << 16); // red
            result[row][col] = argb;
            col++;
            if (col == width) {
               col = 0;
               row++;
            } 
         }
      }

        while(!pq.isEmpty()){
            Node s = pq.remove();
            if(s == endNode)
                return;
            else{
                if(inc++ % 5000 < 1000){
                    // set to Green
                    img.setRGB(s.y, s.x, 0xFF00FF00);       
                }
            ArrayList<Point> directions = new ArrayList<Point>();
            directions.add(new Point(s.x, s.y+1));
            directions.add(new Point(s.x, s.y-1));
            directions.add(new Point(s.x+1, s.y));
            directions.add(new Point(s.x-1, s.y));
            for(Point step: directions){
                int x = step.x;
                int y = step.y;

                if(0 <= y && y < height && 0 <= x && x < width 
                    && !(visited.contains(step))){
                        visited.add(step);
                        Node temp = new Node(s.cost + channel_count * 
                        (y * width + x) + green_channel, s, x, y);
                        pq.add(temp);   
                    }
            }
            }
            inc++;
            System.out.print(inc);
        }       
    }

    public static BufferedImage GetImage(String filename)
    {
        BufferedImage bImage = null;
        try{
            bImage = ImageIO.read(new File(filename));
        }
        catch(Exception e)
        {
            e.printStackTrace();
        }
        return bImage;
    }

    /*
    public createPath(Node node, BufferedImage image){
        while(node.parent != null){
            //set pixel to red;
            node = node.parent;
        }
    }
    */

    // main method
    public static void main(String[] args)
    {
        int startX = 0, startY = 0, endX = 0, endY = 0;
        BufferedImage bufferedImage = null;
        // Get filename and start and end points from command line
        try{
            String filename = args[0];
            startX = Integer.parseInt(args[1]);
            startY = Integer.parseInt(args[2]);
            endX = Integer.parseInt(args[3]);
            endY = Integer.parseInt(args[4]);
            bufferedImage = GetImage(filename);
            }
            catch(Exception e)
            {
                e.printStackTrace();
                System.out.print("Please enter filename start.x start.y" + 
                    " end.x end.y in that order");
            }

            Node startingPoint = new Node(0.0, null, startX, startY);
            Node endingPoint = new Node();
            endingPoint.x = endX;
            endingPoint.y = endY;

            uCostSearch(startingPoint, endingPoint, bufferedImage);

            //createPath(endingPoint, bufferedImage)
            //saveImage
            //PrintCost
            //Output Image

    }

}

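UniformCostSearch 用无参构造函数创建了 PriorityQueue,没有传入 Comparator,因此队列依赖元素的自然顺序,会把每个元素强制转换为 Comparable(即堆栈中的 siftUpComparable)。你的 Node 类实现的是 Comparator 接口而不是 Comparable 接口,所以 add 时抛出 ClassCastException。解决方法是让 Node 实现 Comparable<Node> 接口;或者把一个 Comparator<Node> 实例传给 PriorityQueue 的构造函数。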

下面是一个 compareTo 的示例实现:

    /**
     * Orders nodes by their own accumulated path cost, cheapest first, as
     * uniform cost search requires.
     */
    @Override
    public int compareTo(Node that)
    {
        // Comparing parents instead of this node's cost would not give
        // cheapest-first order. `this.cost - that.cost` is a double and does not
        // compile as an int return; Double.compare returns the int sign directly.
        return Double.compare(this.cost, that.cost);
    }

不过这只是一个示例,不保证完全满足你的需求,但它不会因为 parent 为 null 而抛出 NullPointerException。