加速 Python cKDTree

Speed up Python cKDTree

我目前有一个我创建的函数,它在 55 像素范围内将蓝点与其(最多)3 个最近的邻居连接起来。vertices_xy_list 是一个非常大的列表或点(嵌套列表) 大约 5000-10000 对。

vertices_xy_list 示例:

[[3673.3333333333335, 2483.3333333333335],
 [3718.6666666666665, 2489.0],
 [3797.6666666666665, 2463.0],
 [3750.3333333333335, 2456.6666666666665],...]

我目前已经编写了这个 calculate_draw_vertice_lines() 函数,它在 While 循环中使用 CKDTree 来查找 55 像素内的所有点,然后用绿线连接它们。

可以看出,随着列表变长,这会呈指数级变慢。有什么方法可以显着加快此功能?比如向量化操作?

def calculate_draw_vertice_lines():

    global vertices_xy_list
    global cell_wall_lengths
    global list_of_lines_references

    index = 0

    while True:

        if (len(vertices_xy_list) == 1):

            break

        point_tree = spatial.cKDTree(vertices_xy_list)

        index_of_closest_points = point_tree.query_ball_point(vertices_xy_list[index], 55)

        index_of_closest_points.remove(index)

        for stuff in index_of_closest_points:

            list_of_lines_references.append(plt.plot([vertices_xy_list[index][0],vertices_xy_list[stuff][0]] , [vertices_xy_list[index][1],vertices_xy_list[stuff][1]], color = 'green'))

            wall_length = math.sqrt( (vertices_xy_list[index][0] - vertices_xy_list[stuff][0])**2 + (vertices_xy_list[index][1] - vertices_xy_list[stuff][1])**2 )

            cell_wall_lengths.append(wall_length)

        del vertices_xy_list[index]

    fig.canvas.draw()

如果我理解正确选择绿线的逻辑,就不需要在每次迭代时都创建一个KDTree。对于每对 (p1, p2) 蓝点,当且仅当满足以下条件时才应该画线:

  1. p1 是 p2 的 3 个最近邻之一。
  2. p2 是 p1 的 3 个最近邻之一。
  3. 距离(p1, p2) < 55.

您可以创建一次KDTree,并高效地创建一个绿线列表。这是实现的一部分,returns 需要绘制绿线的点的索引对列表。 10,000 点在我的机器上运行时间约为 0.5 秒。

import numpy as np
from scipy import spatial


data = np.random.randint(0, 1000, size=(10_000, 2))

def get_green_lines(data):
    tree = spatial.cKDTree(data)
    # each key in g points to indices of 3 nearest blue points
    g = {i: set(tree.query(data[i,:], 4)[-1][1:]) for i in range(data.shape[0])}

    green_lines = list()
    for node, candidates in g.items():
        for node2 in candidates:
            if node2 < node:
                # avoid double-counting
                continue

            if node in g[node2] and spatial.distance.euclidean(data[node,:], data[node2,:]) < 55:
                green_lines.append((node, node2))

    return green_lines

您可以按如下方式继续绘制绿线:

green_lines = get_green_lines(data)
fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], s=1)
from matplotlib import collections as mc
lines = [[data[i], data[j]] for i, j in green_lines]
line_collection = mc.LineCollection(lines, color='green')
ax.add_collection(line_collection)

示例输出: