加速 Python cKDTree
Speed up Python cKDTree
我目前有一个我创建的函数,它在 55 像素范围内将蓝点与其(最多)3 个最近的邻居连接起来。vertices_xy_list 是一个非常大的列表或点(嵌套列表) 大约 5000-10000 对。
vertices_xy_list
示例:
[[3673.3333333333335, 2483.3333333333335],
[3718.6666666666665, 2489.0],
[3797.6666666666665, 2463.0],
[3750.3333333333335, 2456.6666666666665],...]
我目前已经编写了这个 calculate_draw_vertice_lines()
函数,它在 While 循环中使用 CKDTree 来查找 55 像素内的所有点,然后用绿线连接它们。
可以看出,随着列表变长,这会呈指数级变慢。有什么方法可以显着加快此功能?比如向量化操作?
def calculate_draw_vertice_lines():
global vertices_xy_list
global cell_wall_lengths
global list_of_lines_references
index = 0
while True:
if (len(vertices_xy_list) == 1):
break
point_tree = spatial.cKDTree(vertices_xy_list)
index_of_closest_points = point_tree.query_ball_point(vertices_xy_list[index], 55)
index_of_closest_points.remove(index)
for stuff in index_of_closest_points:
list_of_lines_references.append(plt.plot([vertices_xy_list[index][0],vertices_xy_list[stuff][0]] , [vertices_xy_list[index][1],vertices_xy_list[stuff][1]], color = 'green'))
wall_length = math.sqrt( (vertices_xy_list[index][0] - vertices_xy_list[stuff][0])**2 + (vertices_xy_list[index][1] - vertices_xy_list[stuff][1])**2 )
cell_wall_lengths.append(wall_length)
del vertices_xy_list[index]
fig.canvas.draw()
如果我理解正确选择绿线的逻辑,就不需要在每次迭代时都创建一个KDTree。对于每对 (p1, p2) 蓝点,当且仅当满足以下条件时才应该画线:
- p1 是 p2 的 3 个最近邻之一。
- p2 是 p1 的 3 个最近邻之一。
- 距离(p1, p2) < 55.
您可以创建一次KDTree,并高效地创建一个绿线列表。这是实现的一部分,returns 需要绘制绿线的点的索引对列表。 10,000 点在我的机器上运行时间约为 0.5 秒。
import numpy as np
from scipy import spatial
data = np.random.randint(0, 1000, size=(10_000, 2))
def get_green_lines(data):
tree = spatial.cKDTree(data)
# each key in g points to indices of 3 nearest blue points
g = {i: set(tree.query(data[i,:], 4)[-1][1:]) for i in range(data.shape[0])}
green_lines = list()
for node, candidates in g.items():
for node2 in candidates:
if node2 < node:
# avoid double-counting
continue
if node in g[node2] and spatial.distance.euclidean(data[node,:], data[node2,:]) < 55:
green_lines.append((node, node2))
return green_lines
您可以按如下方式继续绘制绿线:
green_lines = get_green_lines(data)
fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], s=1)
from matplotlib import collections as mc
lines = [[data[i], data[j]] for i, j in green_lines]
line_collection = mc.LineCollection(lines, color='green')
ax.add_collection(line_collection)
示例输出:
我目前有一个我创建的函数,它在 55 像素范围内将蓝点与其(最多)3 个最近的邻居连接起来。vertices_xy_list 是一个非常大的列表或点(嵌套列表) 大约 5000-10000 对。
vertices_xy_list
示例:
[[3673.3333333333335, 2483.3333333333335],
[3718.6666666666665, 2489.0],
[3797.6666666666665, 2463.0],
[3750.3333333333335, 2456.6666666666665],...]
我目前已经编写了这个 calculate_draw_vertice_lines()
函数,它在 While 循环中使用 CKDTree 来查找 55 像素内的所有点,然后用绿线连接它们。
可以看出,随着列表变长,这会呈指数级变慢。有什么方法可以显着加快此功能?比如向量化操作?
def calculate_draw_vertice_lines():
global vertices_xy_list
global cell_wall_lengths
global list_of_lines_references
index = 0
while True:
if (len(vertices_xy_list) == 1):
break
point_tree = spatial.cKDTree(vertices_xy_list)
index_of_closest_points = point_tree.query_ball_point(vertices_xy_list[index], 55)
index_of_closest_points.remove(index)
for stuff in index_of_closest_points:
list_of_lines_references.append(plt.plot([vertices_xy_list[index][0],vertices_xy_list[stuff][0]] , [vertices_xy_list[index][1],vertices_xy_list[stuff][1]], color = 'green'))
wall_length = math.sqrt( (vertices_xy_list[index][0] - vertices_xy_list[stuff][0])**2 + (vertices_xy_list[index][1] - vertices_xy_list[stuff][1])**2 )
cell_wall_lengths.append(wall_length)
del vertices_xy_list[index]
fig.canvas.draw()
如果我理解正确选择绿线的逻辑,就不需要在每次迭代时都创建一个KDTree。对于每对 (p1, p2) 蓝点,当且仅当满足以下条件时才应该画线:
- p1 是 p2 的 3 个最近邻之一。
- p2 是 p1 的 3 个最近邻之一。
- 距离(p1, p2) < 55.
您可以创建一次KDTree,并高效地创建一个绿线列表。这是实现的一部分,returns 需要绘制绿线的点的索引对列表。 10,000 点在我的机器上运行时间约为 0.5 秒。
import numpy as np
from scipy import spatial
data = np.random.randint(0, 1000, size=(10_000, 2))
def get_green_lines(data):
tree = spatial.cKDTree(data)
# each key in g points to indices of 3 nearest blue points
g = {i: set(tree.query(data[i,:], 4)[-1][1:]) for i in range(data.shape[0])}
green_lines = list()
for node, candidates in g.items():
for node2 in candidates:
if node2 < node:
# avoid double-counting
continue
if node in g[node2] and spatial.distance.euclidean(data[node,:], data[node2,:]) < 55:
green_lines.append((node, node2))
return green_lines
您可以按如下方式继续绘制绿线:
green_lines = get_green_lines(data)
fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], s=1)
from matplotlib import collections as mc
lines = [[data[i], data[j]] for i, j in green_lines]
line_collection = mc.LineCollection(lines, color='green')
ax.add_collection(line_collection)
示例输出: