What's wrong with my implementation of the nearest neighbour algorithm (for the TSP)?

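My task is to implement the nearest neighbour algorithm for the travelling salesman problem. The method is supposed to try starting from each city and return the best tour found. According to the auto-marking program, my implementation works for the most basic case but only partially for all the more advanced cases.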

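I don't understand where I went wrong and am asking for a review of my code's correctness. I would like to know what the problem is and what the correct approach would be.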

My Java code is as follows:

/*
 * Returns the shortest tour found by exercising the NN algorithm 
 * from each possible starting city in table.
 * table[i][j] == table[j][i] gives the cost of travel between City i and City j.
 */
 public static int[] tspnn(double[][] table) {
     
     // number of vertices 
     int numberOfVertices = table.length;
     // the Hamiltonian cycle built starting from vertex i
     int[] currentHamiltonianCycle = new int[numberOfVertices];
     // the lowest total cost Hamiltonian cycle
     double lowestTotalCost = Double.POSITIVE_INFINITY;
     //  the shortest Hamiltonian cycle
     int[] shortestHamiltonianCycle = new int[numberOfVertices];
     
     // consider each vertex i as a starting point
     for (int i = 0; i < numberOfVertices; i++) {
         /* 
          * Consider all vertices that are reachable from the starting point i,
          * thereby creating a new current Hamiltonian cycle.
          */
         for (int j = 0; j < numberOfVertices; j++) {
             /* 
              * The modulo of the sum of i and j allows us to account for the fact 
              * that Java indexes arrays from 0.
              */
             currentHamiltonianCycle[j] = (i + j) % numberOfVertices;   
         }
         for (int j = 1; j < numberOfVertices - 1; j++) {
             int nextVertex = j;
             for (int p = j + 1; p < numberOfVertices; p++) {
                 if (table[currentHamiltonianCycle[j - 1]][currentHamiltonianCycle[p]] < table[currentHamiltonianCycle[j - 1]][currentHamiltonianCycle[nextVertex]]) {
                           nextVertex = p;
                 }
             }
             
             int a = currentHamiltonianCycle[nextVertex];
             currentHamiltonianCycle[nextVertex] = currentHamiltonianCycle[j];
             currentHamiltonianCycle[j] = a;
         }
         
         /*
          * Find the total cost of the current Hamiltonian cycle.
          */
         double currentTotalCost = table[currentHamiltonianCycle[0]][currentHamiltonianCycle[numberOfVertices - 1]];
         for (int z = 0; z < numberOfVertices - 1; z++) {
             currentTotalCost += table[currentHamiltonianCycle[z]][currentHamiltonianCycle[z + 1]];
         }
         
         if (currentTotalCost < lowestTotalCost) {
             lowestTotalCost = currentTotalCost;
             shortestHamiltonianCycle = currentHamiltonianCycle;
         }
     }
     return shortestHamiltonianCycle;
 }

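Edit

I have traced this code with pen and paper on a simple example and found nothing wrong with the implementation of the algorithm. Based on that, it seems to me that it should work in the general case.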


Edit 2

I have tested my implementation with the following mock example:

double[][] table = {{0, 2.3, 1.8, 4.5}, {2.3, 0, 0.4, 0.1}, 
                {1.8, 0.4, 0, 1.3}, {4.5, 0.1, 1.3, 0}}; 

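It seems to produce the expected output of the nearest neighbour algorithm, namely 3 -> 1 -> 2 -> 0

I am now wondering whether the auto-marking program is incorrect, or whether my implementation simply does not work in the general case.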

As I stated in the comments, I see a fundamental problem with the algorithm itself:

  • It does not permute the towns correctly, but always works through them in sequence (A-B-C-D-A-B-C-D, starting anywhere and taking 4)

To demonstrate the problem, I wrote the following code for testing, with both a simple and an advanced example set up.

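  • Please configure it via the static public final constants first, before changing the code itself.
  • Focus on the simple example: if the algorithm worked properly, the result would always be A-B-C-D or D-C-B-A.
  • But as you can observe from the output, the algorithm does not select the (globally) best tour, because its permutation of the tested towns is wrong.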

I have also added my own object-oriented implementation to show:

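  • that the selection is where the problem lies, and that it is really hard to get right when everything is done in a single method
  • what advantages the OO style has
  • that proper testing/developing is easy to set up and perform (I am not even using unit tests here; that would be the next step to verify/validate the algorithm)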
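The "next step" from the last bullet could look roughly like the sketch below. It is separate from the listing that follows: `OrderIndependenceCheck`, `bestTourCost` and `permute` are hypothetical names, and the check assumes that all distances are distinct, so that ties cannot change which neighbour is nearest. It relabels the cities of the mock table from the question in random order and checks that the cost of the best nearest-neighbour tour stays the same.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class OrderIndependenceCheck {

    /** Cost of the cheapest closed nearest-neighbour tour over all starting cities. */
    public static double bestTourCost(double[][] table) {
        int n = table.length;
        double best = Double.POSITIVE_INFINITY;
        for (int start = 0; start < n; start++) {
            boolean[] visited = new boolean[n];
            int current = start;
            visited[current] = true;
            double cost = 0;
            for (int step = 1; step < n; step++) {
                // pick the closest city that has not been visited yet
                int nearest = -1;
                for (int candidate = 0; candidate < n; candidate++) {
                    if (!visited[candidate]
                            && (nearest == -1 || table[current][candidate] < table[current][nearest])) {
                        nearest = candidate;
                    }
                }
                cost += table[current][nearest];
                visited[nearest] = true;
                current = nearest;
            }
            cost += table[current][start]; // return leg closes the tour
            best = Math.min(best, cost);
        }
        return best;
    }

    /** Relabels cities: city i of the result is city perm.get(i) of the input. */
    public static double[][] permute(double[][] table, List<Integer> perm) {
        int n = table.length;
        double[][] result = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                result[i][j] = table[perm.get(i)][perm.get(j)];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] table = {{0, 2.3, 1.8, 4.5}, {2.3, 0, 0.4, 0.1},
                            {1.8, 0.4, 0, 1.3}, {4.5, 0.1, 1.3, 0}};
        double expected = bestTourCost(table);
        List<Integer> perm = new ArrayList<>(Arrays.asList(0, 1, 2, 3));
        for (int run = 0; run < 20; run++) {
            Collections.shuffle(perm);
            double actual = bestTourCost(permute(table, perm));
            if (Math.abs(actual - expected) > 1e-9) {
                throw new AssertionError("order-dependent result for " + perm);
            }
        }
        System.out.println("best cost " + expected + " for every ordering tried");
    }
}
```

Comparing costs rather than the tours themselves keeps the check simple, because the city indices change with every relabelling.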

Code:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;

public class TSP_NearestNeighbour {



    static public final int NUMBER_OF_TEST_RUNS = 4;

    static public final boolean GENERATE_SIMPLE_TOWNS = true;

    static public final int NUMBER_OF_COMPLEX_TOWNS         = 10;
    static public final int DISTANCE_RANGE_OF_COMPLEX_TOWNS = 100;



    static private class Town {
        public final String Name;
        public final int    X;
        public final int    Y;
        public Town(final String pName, final int pX, final int pY) {
            Name = pName;
            X = pX;
            Y = pY;
        }
        public double getDistanceTo(final Town pOther) {
            final int dx = pOther.X - X;
            final int dy = pOther.Y - Y;
            return Math.sqrt(Math.abs(dx * dx + dy * dy));
        }
        @Override public int hashCode() { // not really needed here
            final int prime = 31;
            int result = 1;
            result = prime * result + X;
            result = prime * result + Y;
            return result;
        }
        @Override public boolean equals(final Object obj) {
            if (this == obj) return true;
            if (obj == null) return false;
            if (getClass() != obj.getClass()) return false;
            final Town other = (Town) obj;
            if (X != other.X) return false;
            if (Y != other.Y) return false;
            return true;
        }
        @Override public String toString() {
            return Name + " (" + X + "/" + Y + ")";
        }
    }

    static private double[][] generateDistanceTable(final ArrayList<Town> pTowns) {
        final double[][] ret = new double[pTowns.size()][pTowns.size()];
        for (int outerIndex = 0; outerIndex < pTowns.size(); outerIndex++) {
            final Town outerTown = pTowns.get(outerIndex);

            for (int innerIndex = 0; innerIndex < pTowns.size(); innerIndex++) {
                final Town innerTown = pTowns.get(innerIndex);

                final double distance = outerTown.getDistanceTo(innerTown);
                ret[outerIndex][innerIndex] = distance;
            }
        }
        return ret;
    }



    static private ArrayList<Town> generateTowns_simple() {
        final Town a = new Town("A", 0, 0);
        final Town b = new Town("B", 1, 0);
        final Town c = new Town("C", 2, 0);
        final Town d = new Town("D", 3, 0);
        return new ArrayList<>(Arrays.asList(a, b, c, d));
    }
    static private ArrayList<Town> generateTowns_complex() {
        final ArrayList<Town> allTowns = new ArrayList<>();
        for (int i = 0; i < NUMBER_OF_COMPLEX_TOWNS; i++) {
            final int randomX = (int) (Math.random() * DISTANCE_RANGE_OF_COMPLEX_TOWNS);
            final int randomY = (int) (Math.random() * DISTANCE_RANGE_OF_COMPLEX_TOWNS);
            final Town t = new Town("Town-" + (i + 1), randomX, randomY);
            if (allTowns.contains(t)) { // do not allow different towns at same location!
                System.out.println("Towns colliding at " + t);
                --i;
            } else {
                allTowns.add(t);
            }
        }
        return allTowns;
    }
    static private ArrayList<Town> generateTowns() {
        if (GENERATE_SIMPLE_TOWNS) return generateTowns_simple();
        else return generateTowns_complex();
    }



    static private void printTowns(final ArrayList<Town> pTowns, final double[][] pDistances) {
        System.out.println("Towns:");
        for (final Town town : pTowns) {
            System.out.println("\t" + town);
        }

        System.out.println("Distance Matrix:");
        for (int y = 0; y < pDistances.length; y++) {
            System.out.print("\t");
            for (int x = 0; x < pDistances.length; x++) {
                System.out.print(pDistances[y][x] + " (" + pTowns.get(y).Name + "-" + pTowns.get(x).Name + ")" + "\t");
            }
            System.out.println();
        }
    }



    private static void testAlgorithm() {
        final ArrayList<Town> towns = generateTowns();

        for (int i = 0; i < NUMBER_OF_TEST_RUNS; i++) {
            final double[][] distances = generateDistanceTable(towns);
            printTowns(towns, distances);

            {
                final int[] path = tspnn(distances);
                System.out.println("tspnn Path:");
                for (int pathIndex = 0; pathIndex < path.length; pathIndex++) {
                    final Town t = towns.get(pathIndex);
                    System.out.println("\t" + t);
                }
            }
            {
                final ArrayList<Town> path = tspnn_simpleNN(towns);
                System.out.println("tspnn_simpleNN Path:");
                for (final Town t : path) {
                    System.out.println("\t" + t);
                }
                System.out.println("\n");
            }

            // prepare for the next run; shuffling at the end of the loop keeps the first run on the original order
            Collections.shuffle(towns);
        }

    }

    public static void main(final String[] args) {
        testAlgorithm();
    }



    /*
     * Returns the shortest tour found by exercising the NN algorithm
     * from each possible starting city in table.
     * table[i][j] == table[j][i] gives the cost of travel between City i and City j.
     */
    public static int[] tspnn(final double[][] table) {

        // number of vertices
        final int numberOfVertices = table.length;
        // the Hamiltonian cycle built starting from vertex i
        final int[] currentHamiltonianCycle = new int[numberOfVertices];
        // the lowest total cost Hamiltonian cycle
        double lowestTotalCost = Double.POSITIVE_INFINITY;
        //  the shortest Hamiltonian cycle
        int[] shortestHamiltonianCycle = new int[numberOfVertices];

        // consider each vertex i as a starting point
        for (int i = 0; i < numberOfVertices; i++) {
            /*
             * Consider all vertices that are reachable from the starting point i,
             * thereby creating a new current Hamiltonian cycle.
             */
            for (int j = 0; j < numberOfVertices; j++) {
                /*
                 * The modulo of the sum of i and j allows us to account for the fact
                 * that Java indexes arrays from 0.
                 */
                currentHamiltonianCycle[j] = (i + j) % numberOfVertices;
            }
            for (int j = 1; j < numberOfVertices - 1; j++) {
                int nextVertex = j;
                for (int p = j + 1; p < numberOfVertices; p++) {
                    if (table[currentHamiltonianCycle[j - 1]][currentHamiltonianCycle[p]] < table[currentHamiltonianCycle[j - 1]][currentHamiltonianCycle[nextVertex]]) {
                        nextVertex = p;
                    }
                }

                final int a = currentHamiltonianCycle[nextVertex];
                currentHamiltonianCycle[nextVertex] = currentHamiltonianCycle[j];
                currentHamiltonianCycle[j] = a;
            }

            /*
             * Find the total cost of the current Hamiltonian cycle.
             */
            double currentTotalCost = table[currentHamiltonianCycle[0]][currentHamiltonianCycle[numberOfVertices - 1]];
            for (int z = 0; z < numberOfVertices - 1; z++) {
                currentTotalCost += table[currentHamiltonianCycle[z]][currentHamiltonianCycle[z + 1]];
            }

            if (currentTotalCost < lowestTotalCost) {
                lowestTotalCost = currentTotalCost;
                shortestHamiltonianCycle = currentHamiltonianCycle;
            }
        }
        return shortestHamiltonianCycle;
    }



    /**
     * Here come my basic implementations.
     * They can be heavily (heavily!) improved, but are verbose and direct to show the logic behind them
     */



    /**
     * <p>Example of how to implement the NN solution the OO way.</p>
     * we could also implement
     * <ul>
     * <li>a recursive function</li>
     * <li>or one with running counters</li>
     * <li>or one with a real map/route objects, where further optimizations can take place</li>
     * </ul>
     */
    public static ArrayList<Town> tspnn_simpleNN(final ArrayList<Town> pTowns) {
        ArrayList<Town> bestRoute = null;
        double bestCosts = Double.MAX_VALUE;

        for (final Town startingTown : pTowns) {
            //setup
            final ArrayList<Town> visitedTowns = new ArrayList<>(); // ArrayList because we need a stable index
            final HashSet<Town> unvisitedTowns = new HashSet<>(pTowns); // all towns are available at start; we use HashSet because we need fast search; indexing plays no role here

            // step 1
            Town currentTown = startingTown;
            visitedTowns.add(currentTown);
            unvisitedTowns.remove(currentTown);

            // steps 2-n
            while (unvisitedTowns.size() > 0) {
                // find nearest town
                final Town nearestTown = findNearestTown(currentTown, unvisitedTowns);
                if (nearestTown == null) throw new IllegalStateException("Something in the code is wrong...");

                currentTown = nearestTown;
                visitedTowns.add(currentTown);
                unvisitedTowns.remove(currentTown);
            }

            // selection
            final double cost = getCostsOfRoute(visitedTowns);
            if (cost < bestCosts) {
                bestCosts = cost;
                bestRoute = visitedTowns;
            }
        }
        return bestRoute;
    }



    static private Town findNearestTown(final Town pCurrentTown, final HashSet<Town> pSelectableTowns) {
        double minDist = Double.MAX_VALUE;
        Town minTown = null;

        for (final Town checkTown : pSelectableTowns) {
            final double dist = pCurrentTown.getDistanceTo(checkTown);
            if (dist < minDist) {
                minDist = dist;
                minTown = checkTown;
            }
        }

        return minTown;
    }
    static private double getCostsOfRoute(final ArrayList<Town> pTowns) {
        double costs = 0;
        for (int i = 1; i < pTowns.size(); i++) { // use pre-index
            final Town t1 = pTowns.get(i - 1);
            final Town t2 = pTowns.get(i);
            final double cost = t1.getDistanceTo(t2);
            costs += cost;
        }
        return costs;
    }



}

In its unchanged state, this gives us output similar to the following:

Towns:
    A (0/0)
    B (1/0)
    C (2/0)
    D (3/0)
Distance Matrix:
    0.0 (A-A)   1.0 (A-B)   2.0 (A-C)   3.0 (A-D)   
    1.0 (B-A)   0.0 (B-B)   1.0 (B-C)   2.0 (B-D)   
    2.0 (C-A)   1.0 (C-B)   0.0 (C-C)   1.0 (C-D)   
    3.0 (D-A)   2.0 (D-B)   1.0 (D-C)   0.0 (D-D)   
tspnn Path:
    A (0/0)
    B (1/0)
    C (2/0)
    D (3/0)
tspnn_simpleNN Path:
    A (0/0)
    B (1/0)
    C (2/0)
    D (3/0)


Towns:
    C (2/0)
    D (3/0)
    B (1/0)
    A (0/0)
Distance Matrix:
    0.0 (C-C)   1.0 (C-D)   1.0 (C-B)   2.0 (C-A)   
    1.0 (D-C)   0.0 (D-D)   2.0 (D-B)   3.0 (D-A)   
    1.0 (B-C)   2.0 (B-D)   0.0 (B-B)   1.0 (B-A)   
    2.0 (A-C)   3.0 (A-D)   1.0 (A-B)   0.0 (A-A)   
tspnn Path:
    C (2/0)
    D (3/0)
    B (1/0)
    A (0/0)
tspnn_simpleNN Path:
    D (3/0)
    C (2/0)
    B (1/0)
    A (0/0)


Towns:
    D (3/0)
    B (1/0)
    C (2/0)
    A (0/0)
Distance Matrix:
    0.0 (D-D)   2.0 (D-B)   1.0 (D-C)   3.0 (D-A)   
    2.0 (B-D)   0.0 (B-B)   1.0 (B-C)   1.0 (B-A)   
    1.0 (C-D)   1.0 (C-B)   0.0 (C-C)   2.0 (C-A)   
    3.0 (A-D)   1.0 (A-B)   2.0 (A-C)   0.0 (A-A)   
tspnn Path:
    D (3/0)
    B (1/0)
    C (2/0)
    A (0/0)
tspnn_simpleNN Path:
    D (3/0)
    C (2/0)
    B (1/0)
    A (0/0)


Towns:
    A (0/0)
    B (1/0)
    C (2/0)
    D (3/0)
Distance Matrix:
    0.0 (A-A)   1.0 (A-B)   2.0 (A-C)   3.0 (A-D)   
    1.0 (B-A)   0.0 (B-B)   1.0 (B-C)   2.0 (B-D)   
    2.0 (C-A)   1.0 (C-B)   0.0 (C-C)   1.0 (C-D)   
    3.0 (D-A)   2.0 (D-B)   1.0 (D-C)   0.0 (D-D)   
tspnn Path:
    A (0/0)
    B (1/0)
    C (2/0)
    D (3/0)
tspnn_simpleNN Path:
    A (0/0)
    B (1/0)
    C (2/0)
    D (3/0)

As you can see, your algorithm depends heavily on the order of the input towns. If the algorithm were correct, the result would always be A-B-C-D or D-C-B-A.
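One detail to keep in mind when comparing the two listings: they do not score a route the same way. tspnn() adds the leg from the last city back to the first, while getCostsOfRoute() sums only the consecutive legs. A minimal sketch of both definitions on the distance table of the simple example (`RouteCost`, `openCost` and `closedCost` are hypothetical names, not part of the code above):

```java
public class RouteCost {

    /** Sum of consecutive legs only, as getCostsOfRoute() does. */
    public static double openCost(double[][] table, int[] route) {
        double cost = 0;
        for (int i = 1; i < route.length; i++) {
            cost += table[route[i - 1]][route[i]];
        }
        return cost;
    }

    /** Open cost plus the leg from the last city back to the first, as tspnn() does. */
    public static double closedCost(double[][] table, int[] route) {
        if (route.length == 0) {
            return 0;
        }
        return openCost(table, route) + table[route[route.length - 1]][route[0]];
    }

    public static void main(String[] args) {
        // towns A, B, C, D at x = 0, 1, 2, 3 on a line
        double[][] table = {{0, 1, 2, 3}, {1, 0, 1, 2}, {2, 1, 0, 1}, {3, 2, 1, 0}};
        int[] abcd = {0, 1, 2, 3};
        int[] bacd = {1, 0, 2, 3};
        System.out.println(openCost(table, abcd) + " / " + closedCost(table, abcd)); // 3.0 / 6.0
        System.out.println(openCost(table, bacd) + " / " + closedCost(table, bacd)); // 4.0 / 6.0
    }
}
```

With the open definition A-B-C-D and D-C-B-A are the only cheapest routes for the four towns on a line; with the closed definition B-A-C-D costs the same as A-B-C-D, so the "always A-B-C-D or D-C-B-A" expectation applies to the open cost.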

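So please use this 'testing' framework to improve your code. The tspnn() method you provided does not depend on any other code, so once you have improved it you can comment out all of my stuff. Or put all of this into another class and call your actual implementation across classes. Since it is static public, you can easily call it via YourClassName.tspnn(distances).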
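While reworking tspnn(), two details may be worth checking first. Both come from reading the listing and tracing the mock table from the question by hand, not from anything the marking program reported. First, the line `shortestHamiltonianCycle = currentHamiltonianCycle` stores a reference to the working array, which later starting cities keep overwriting, so the method ends up returning the tour of the last starting city rather than the cheapest one. Second, the harness prints `towns.get(pathIndex)`, which lists the towns in input order; `towns.get(path[pathIndex])` would map the returned indices back to towns. A minimal sketch of the first point, keeping the original selection loop and only copying the array (`TspnnCopyFix` is a hypothetical class name):

```java
import java.util.Arrays;

public class TspnnCopyFix {

    public static int[] tspnn(double[][] table) {
        int n = table.length;
        int[] current = new int[n];
        int[] best = new int[n];
        double lowestTotalCost = Double.POSITIVE_INFINITY;

        for (int i = 0; i < n; i++) {
            // initial order: i, i+1, ..., wrapping around to i-1
            for (int j = 0; j < n; j++) {
                current[j] = (i + j) % n;
            }
            // selection step: swap the city nearest to position j-1 into position j
            for (int j = 1; j < n - 1; j++) {
                int next = j;
                for (int p = j + 1; p < n; p++) {
                    if (table[current[j - 1]][current[p]] < table[current[j - 1]][current[next]]) {
                        next = p;
                    }
                }
                int tmp = current[next];
                current[next] = current[j];
                current[j] = tmp;
            }
            // cost of the closed tour, including the leg back to the start
            double cost = (n == 0) ? 0 : table[current[n - 1]][current[0]];
            for (int z = 0; z < n - 1; z++) {
                cost += table[current[z]][current[z + 1]];
            }
            if (cost < lowestTotalCost) {
                lowestTotalCost = cost;
                best = current.clone(); // copy, so later iterations cannot overwrite it
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] table = {{0, 2.3, 1.8, 4.5}, {2.3, 0, 0.4, 0.1},
                            {1.8, 0.4, 0, 1.3}, {4.5, 0.1, 1.3, 0}};
        System.out.println(Arrays.toString(tspnn(table))); // [1, 3, 2, 0]
    }
}
```

For the mock table from the question this returns 1 -> 3 -> 2 -> 0 (cost about 5.5). The tour 3 -> 1 -> 2 -> 0 reported in the question costs about 6.8 and is simply the tour built from the last starting city.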

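On the other hand, it may be worth seeing whether the auto-marking program can be improved, so that you can get through your Java work without problems.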