最近邻计算中的错误条件检查?

Faulty conditional check in nearest neighbor calculation?

我正在尝试编写一个函数,以使用 nearest neighbor algorithm 从列出的第一个城市开始计算通过城市列表的近似旅行推销员路线。但是,每次我 运行 我得到 IndexError: list index out of range.

在调试错误时,我发现 index 的值从一个循环迭代到下一个循环迭代保持不变,而不是改变。当需要追加时,代码会检查 if not in 条件;因为它是 False,所以它将 1 添加到 i 并移动到循环的下一次迭代。一旦它达到比数组中存在的数字更高的数字,它就会给我错误。

所以我的问题是,为什么执行不进入第一个if not in块?代码似乎忽略了它。

对于我的实际问题,我正在读取一个包含 317 个城市的文件,每个城市都有一个索引和两个坐标。这是一个较短的测试城市示例列表:

Nodelist = [
    (1, 63, 71),
    (2, 94, 71),
    (3, 142, 370),
    (4, 173, 1276),
    (5, 205, 1213),
    (6, 213, 69),
    (7, 244, 69),
    (8, 276, 630),
    (9, 283, 732),
    (10, 362, 69),
]

函数代码如下:

def Nearest(Nodelist,Distance,index):
    Time_Calculation = time.time()
    Newarray=[]
    Newarray.append(Nodelist[0])
    for i in range(0,len(Nodelist)):
        for j in range(1,len(Nodelist)):
            if (Nodelist[j] not in Newarray):
                DisEquation = math.sqrt(pow(Nodelist[j][1]-Newarray[i][1],2)+pow(Nodelist[j][2]-Newarray[i][2],2))
                if Distance==0:
                    Distance=DisEquation
                if Distance > DisEquation:
                    index=j
                    Distance=DisEquation
        if(Nodelist[index] not in Newarray):
            Newarray.append(Nodelist[index])
        Distance=0
    print (time.time() - Time_Calculation)
    return Newarray

调用它的代码:

NearestMethodArr=Nearest(Cities,b,index)
print(NearestMethodArr)
print(len(NearestMethodArr))

print 语句应产生:

[(1, 63, 71), (2, 94, 71), (6, 213, 69), (7, 244, 69), (10, 362, 69), (3, 142, 370), (8, 276, 630), (9, 283, 732), (5, 205, 1213), (4, 173, 1276)]
10

我发现我的代码有什么问题,当我将距离重新分配给 x 时,我忘记了我需要用它重新分配索引,因为第一个被测试的城市的距离最短,而在我的第一个代码我仅在 x 小于 Distance

时才重新分配变量索引

新代码:

def Nearest(Nodelist,Distance,index):
    Time_Calculation = time.time()
    Newarray=[]
    Newarray.append(Nodelist[0])
    for i in range(0,len(Nodelist)):
        for j in range(1,len(Nodelist)):
            if (Nodelist[j] not in Newarray):
                DisEquation = math.sqrt(pow(Nodelist[j][1]-Newarray[i][1],2)+pow(Nodelist[j][2]-Newarray[i][2],2))
                if Distance==0:
                    Distance=DisEquation
                    index=j
                if Distance > DisEquation:
                    index=j
                    Distance=DisEquation
        if(Nodelist[index] not in Newarray):
            Newarray.append(Nodelist[index])
        Distance=0
    print (time.time() - Time_Calculation)
    return Newarray
    ```
Newarray.append(Nodelist[0])#adding Nodelist[0] to NewArray    
for i in range(0,len(Nodelist)):
        for j in range(1,len(Nodelist)):
            if (Nodelist[j] not in Newarray):
                DisEquation = math.sqrt(pow(Nodelist[j][1]-Newarray[i [1],2)+pow(Nodelist[j][2]-Newarray[i][2],2)) #you access Newarray at i
                if Distance==0:
                    Distance=DisEquation
                if Distance > DisEquation:
                    index=j
                    Distance=DisEquation
        if(Nodelist[index] not in Newarray):
            Newarray.append(Nodelist[index])#you conditionally add new elements to newarray

如果您看到我添加到您的代码中的注释,那么问题应该很清楚了。您遍历 Nodelist 的所有元素并调用索引 i 您已经向 NewArray 添加了一个元素,因此第一次索引 0 存在。然后你点击不在 Newarray 中的 Nodelist[index],如果它是真的 NewArray 变大 1 然后 NewArray[1] 工作,如果由于任何原因这不是真的那么 NewArray 保持相同的大小并且下一个 NewArray[i]将是索引超出范围错误。

编辑:感谢 CrazyChucky 在评论中让我直截了当。我已经调整如下

我对失败的评论是正确的,尽管我没有确定没有像作者所指出的那样设置索引的根本原因。我没有在脑海中正确解析代码。更新版本中的代码可以工作,但如果您执行以下操作,它会更快更容易阅读:

def new_nearest(Nodelist):
    start_time = time.time()
    shortest_path = [Nodelist[0]]
    Nodelist = Nodelist[1:]
    while len(Nodelist) > 0:
        shortest_dist_sqr = -1
        next_node = None
        for potential_dest in Nodelist:
            dist_sqr = (shortest_path[-1][1] - potential_dest[1])**2 + (shortest_path[-1][2] - potential_dest[2])**2 #you don't keep the distances so there is no need to sqrt as if a > b then a**2 > b**2
            if shortest_dist_sqr < 0 or dist_sqr < shortest_dist_sqr:
                next_node = potential_dest
                shortest_dist_sqr = dist_sqr
        shortest_path.append(next_node)
        Nodelist.remove(next_node)
    print(time.time() - start_time)
    return shortest_path

这 returns 结果相同,但执行速度更快。更改为从内部循环中删除节点的方法可以更清楚地了解正在发生的事情,它可能会使代码变慢一点(在 C 中会这样,但 python 在不同的地方有很多开销,这可能会使这个一个净收益,)并且因为不需要计算实际距离,因为你不存储它你可以比较距离的平方而不做任何平方根。如果你确实想要距离,你可以在确定最近的节点后对其进行平方根。

编辑:我忍不住检查了一下。从 Nodelist 中删除节点的举动实际上代表了大部分时间的节省,而缺少 sqrt 确实确实加快了速度(我使用了 timeit 并改变了代码。)在较低级别的语言中做小事情是超快的所以它很可能更快地单独留下数组并跳过已经使用的元素(这实际上可能不是真的,因为它会扰乱分支预测性能分析真的很难并且取决于您使用的处理器架构......)在python 虽然即使是小东西也非常昂贵(添加两个变量:弄清楚它们是什么类型,解码可变字节长度整数,添加,为结果创建新对象......)所以即使从list 可能比跳过值和单独留下 list 更昂贵,这将导致更多的小操作,这些操作在 Python 中非常慢。如果正在使用低级语言,您还可以认识到节点的顺序是任意的(除了第一个,)所以您可以只拥有一个包含所有节点的数组,而不是创建一个新的小数组,您可以跟踪数组中使用的值的长度,并将数组中的最后一个值复制到为路径中的下一个节点选择的值上。

再次编辑 :P : 我又忍不住好奇了。覆盖节点列表的一部分而不是删除条目的方法让我想知道它在 python 中是否会更快,因为它确实创建了更多在 python 中很慢的工作但减少了涉及的工作量在删除节点元素。事实证明,即使在 python 中,这种方法也很明显(虽然不是很明显,略低于 2%,但在微基准测试中是一致的),所以下面的代码甚至更快:

def newer_nearest(Nodelist):
    shortest_path = [Nodelist[0]]
    Nodelist = Nodelist[1:]
    while len(Nodelist) > 0:
        shortest_dist_sqr = -1
        next_node = None
        for index, potential_dest in enumerate(Nodelist):
            dist_sqr = (shortest_path[-1][1] - potential_dest[1])**2 + (shortest_path[-1][2] - potential_dest[2])**2 #you don't keep the distances so there is no need to sqrt as if a > b then a**2 > b**2
            if shortest_dist_sqr < 0 or dist_sqr < shortest_dist_sqr:
                next_node = index
                shortest_dist_sqr = dist_sqr
        shortest_path.append(Nodelist[next_node])
        Nodelist[next_node] = Nodelist[-1]
        Nodelist = Nodelist[:-1]
    return shortest_path

David Oldford's 在大多数功能改进方面抢先一步,但我想谈谈一些可以使您的代码更清晰、更 Pythonic 的具体内容。可读性很重要。

主要改进如下:

  • rarely, if ever,需要在Python中使用for i in range(len(sequence))。这样做违背了 Python for 循环的设计精神。如果需要索引,请使用 for i, element in enumerate(sequence)。 (当您不需要索引时,只需使用for element in sequence。)
  • 给你的变量命名要清晰(他们说的是什么)并且格式一致(最好是 snake_case,但如果你喜欢 camelCase 或其他格式,选择一个并坚持使用).这不仅会帮助其他人阅读您的代码;它会帮助 当你一个月后回来并且不记得你为什么写你所做的事情时。
  • 打破长线;在赋值(=)和比较(例如==)和逗号之后放置空格;使用空行(有节制地)分隔代码块以提高可读性。
  • 在 Python 中,您可以使用 创建无穷大(或使用 -float('inf') 创建负无穷大。这是进行比较以找到最小或最大某物的惯用方法。
  • 您的函数对 nodes/cities 的列表进行操作。提供距离或索引作为函数的参数没有意义。
  • Multiple assignment 通常是通过为索引分配描述性名称来使代码更清晰的好方法。与 pow(Nodelist[j][1]-Newarray[i][1],2)+pow(Nodelist[j][2]-Newarray[i][2],2).
  • 相比,阅读 (x2 - x1)**2 + (y2 - y1)**2 并理解它在做什么要容易得多
  • 如果不是为了 post 这样的代码,我不会在自己的代码中包含这么多注释。但是 一些 解释控制流的评论可能非常有帮助。
import time

def nearest_neighbor_path(cities):
    """Find traveling salesman path via nearest neighbor algorithm."""
    start_time = time.time()

    # Start at the first city, and maintain a list of unvisited cities.
    path = [cities[0]]
    remaining_cities = cities[1:]

    # Loop until every city has been visited. (An empty list evaluates
    # to False.)
    while remaining_cities:
        # In each loop, set starting coordinates to those of the current
        # city, and initialize shortest distance to infinity.
        _, x1, y1 = path[-1]
        shortest_so_far = float('inf')
        nearest_city = None

        # Investigate each possible city to visit next.
        for index, other_city in enumerate(remaining_cities):
            # Since we don't need the *actual* distance, only a
            # comparison, there's no need to take the square root.
            # (Credit to David Oldford, good catch!)
            _, x2, y2 = other_city
            distance_squared = (x2 - x1)**2 + (y2 - y1)**2

            # If it's the closest one we've seen so far, record it.
            if distance_squared < shortest_so_far:
                shortest_so_far = distance_squared
                index_of_nearest, nearest_city = index, other_city

        # After checking all possible destinations, add the nearest one
        # to the path...
        path.append(nearest_city)
        # ...and remove it from the list of remaining cities. This could
        # simply be remaining_cities.remove(nearest_city), in which case
        # we wouldn't need the index or enumerate() at all. But doing it
        # this way saves an extra iteration to find the city again in
        # the list.
        remaining_cities.pop(index_of_nearest)
    
    print(f'Elapsed time: {time.time() - start_time :f} seconds')
    return path

如果您不介意导入 NumPy for its argmin 函数,则可以进一步简化 while 循环,如下所示:

import numpy as np
import time

def nearest_neighbor_path(cities):
    """Find traveling salesman path via nearest neighbor algorithm."""
    start_time = time.time()
    path = [cities[0]]
    remaining_cities = cities[1:]

    while remaining_cities:
        _, x1, y1 = path[-1]
        distances = [(x2 - x1)**2 + (y2 - y1)**2
                     for _, x2, y2 in remaining_cities]
        index_of_nearest = np.argmin(distances)
        
        nearest_city = remaining_cities.pop(index_of_nearest)
        path.append(nearest_city)
    
    print(f'Elapsed time: {time.time() - start_time :f} seconds')
    return path

我还建议您查看官方 Python 风格指南 PEP 8。这是保持 Python 代码清晰易读的一套非常好的指南。您的代码越容易阅读,您就越容易发现问题和解决方案。