是否有快速 Python 算法来查找数据集中位于给定圆内的所有点?

Is there a fast Python algorithm to find all points in a dataset which lie in a given circle?

我有大量数据(超过 10^5 个点)。我正在搜索一种快速算法,它可以找到数据集中的所有点,这些点位于由其中心点和半径给出的圆中。

我考虑过使用 kd-tree 来计算例如离圆心最近的 10 个点,然后检查它们是否在圆内。但我不确定这是否是正确的方法。

要检查点(a, b)是否在圆心(x, y)和半径r的圆内,那么你可以简单地做一个计算:

within_circle = ((x-a)**2 + (y-b)**2) <= r*r)

这个方程使用了圆的属性,圆上可以得到到一个点的绝对距离(如果你注意到的话,它也用在距离公式中)。

像 k-d 树或四叉树这样的 space 分区数据结构可以准确地回答您的查询;不需要像“10 个最近的点”这样的启发式方法。

您可以递归搜索树:

  1. 如果当前节点的边界矩形根本不与圆相交,则忽略它。
  2. 否则,如果当前节点的边界矩形完全包含在圆内,return所有点都在圆内。
  3. 否则,如果当前节点的边界矩形与圆部分重叠:
    • 如果当前节点是叶子,则单独测试每个点以查看它是否在圆圈内,return 是否在圆圈内。
    • 否则,对当前节点的子节点进行递归,并return所有点return由递归调用编辑。

对于第 2 步,矩形完全包含在圆中当且仅当它的四个角是。对于步骤 1 和 3,您可以测试矩形是否与圆的边界框相交,并保守地递归,因为矩形 可能 与圆相交;或者您可以 an exact(但稍微复杂一些)在矩形和实际圆之间进行检查,从而节省一些不必要的递归调用。

因为这是 Python,而不是 return 收集集合中的点,然后 concatenating/unioning 递归调用的结果,您可以从递归生成器函数中生成点.这在一定程度上简化了实施。

如果您想先过滤大量数据集而不进行大量计算,您可以使用大小为 (2r x 2r) 的正方形,其中心与圆相同(其中 r 是圆的半径)。

看看这张照片:

如果你有中心的坐标 (a,b) 和 r 的半径,那么正方形内​​的点 (x,y) 验证:

in_square_points = [(x,y) for (x,y) in data if a-r < x < a+r and b-r < y < b+r]

最后,在这个过滤器之后,您可以应用圆方程:

in_circle_points = [(x,y) for (x,y) in in_square_points if (x-a)**2 + (y-b)**2 < r**2]

** 编辑 **

如果您的输入结构如下:

data = [
    [13, 45],
    [-1, 2],
    ...
    [60, -4]
]

那你可以试试,如果你更喜欢普通的for循环:

in_square_points = []
for i in range(len(data)):
    x = data[i][0]
    y = data[i][1]
    if a-r < x < a+r and b-r < y < b+r:
        in_square_points.append([x, y])
print(in_square_points)

我针对一个简单的 Numpy 实现对 numexpr 版本进行了基准测试,如下所示:

#!/usr/bin/env python3

import numpy as np
import numexpr as ne

# Ensure repeatable, deterministic randomness!
np.random.seed(42)

# Generate test arrays
N = 1000000
X = np.random.rand(N)
Y = np.random.rand(N)

# Define centre and radius
cx = cy = r = 0.5

def method1(X,Y,cx,cy,r):
    """Straight Numpy determination of points in circle"""
    d = (X-cx)**2 + (Y-cy)**2
    res = d < r**2
    return res

def method2(X,Y,cx,cy,r):
    """Numexpr determination of points in circle"""
    res = ne.evaluate('((X-cx)**2 + (Y-cy)**2)<r**2')
    return res

def method3(data,a,b,r): 
    """List based determination of points in circle, with pre-filtering using a square"""
    in_square_points = [(x,y) for (x,y) in data if a-r < x < a+r and b-r < y < b+r] 
    in_circle_points = [(x,y) for (x,y) in in_square_points if (x-a)**2 + (y-b)**2 < r**2] 
    return in_circle_points

# Timing
%timeit method1(X,Y,cx,cy,r)

%timeit method2(X,Y,cx,cy,r)

# Massage input data (before timing) to match agorithm
data=[(x,y) for x,y in zip(X,Y)]
%timeit method3(data,cx,cy,r)

然后我在 IPython 中计时如下:

%timeit method1(X,Y,cx,cy,r)                                                                                                                     
6.68 ms ± 246 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit method2(X,Y,cx,cy,r)                                                                                                                     
743 µs ± 17.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit method3(data,cx,cy,r)
1.11 s ± 9.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

因此 numexpr 版本的发布速度提高了 9 倍。由于点位于 [0..1] 范围内,该算法有效地计算 pi 并且两种方法的结果相同:

method1(X,Y,cx,cy,r).sum()                                                                                                                       
784973

method2(X,Y,cx,cy,r).sum()                                                                                                                       
784973

len(method3(data,cx,cy,r))                                                                                                                       
784973

4 * 784973 / N
3.139

注意:我应该指出 numexpr 自动为您跨多个 CPU 内核多线程处理您的代码。如果您想尝试使用线程数,您可以在调用 method2() 之前动态更改它,甚至可以在其中使用:

# Split calculations across 6 threads
ne.set_num_threads(6) 

欢迎任何其他希望测试其方法速度的人使用我的代码作为基准测试框架。

如果你只对圆圈中的点数感兴趣,你可以试试 Numba。

import numpy as np
import numba as nb
import numexpr as ne

def method2(X,Y,cx,cy,r):
    """Numexpr method"""
    res = ne.evaluate('((X-cx)**2 + (Y-cy)**2) < r**2')
    return res

@nb.njit(fastmath=True,parallel=True)
def method3(X,Y,cx,cy,r):
    acc=0
    for i in nb.prange(X.shape[0]):
        if ((X[i]-cx)**2 + (Y[i]-cy)**2) < r**2:
            acc+=1
    return acc

时间

# Ensure repeatable, deterministic randomness!
np.random.seed(42)

# Generate test arrays
N = 1000000
X = np.random.rand(N)
Y = np.random.rand(N)

# Define centre and radius
cx = cy = r = 0.5

#@Mark Setchell
%timeit method2(X,Y,cx,cy,r)
#825 µs ± 22.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit method3(X,Y,cx,cy,r)
#480 µs ± 94.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)