从 scipy.spatial.Delauna 加速 "find_simplex"

Speed up "find_simplex" from scipy.spatial.Delauna

我使用 Scypi 的 Delaunay 三角剖分构建了一个应用程序。为了验证它,我想做一个 Hold-One-Out 测试,这意味着下面提到的代码片段被调用了很多次 (~1e9)。因此,我想尽快完成。

这是我想要加速的最小工作示例:

from scipy.spatial import Delaunay as Delaunay
import numpy as np
import time

n_pts = 100      # around 1e9 in the real application
pts = np.random.random((n_pts, 2))

t = time.time()
for i in range(n_pts):
    delaunay = Delaunay(pts[np.arange(n_pts)!=i])
    simplex = delaunay.find_simplex(pts[i])
print(time.time()-t)

大部分时间都被find_simplex方法用完了,在我的机器上大约是300ms中的200ms。有什么办法可以加快速度吗?我已经看过 Delaunay 构造函数中的 qhull_options,但是我在那里没有成功。

请注意,我无法更改整体结构,因为“真实”程序运行良好,并且此计算仅用于验证。非常感谢!

很难说 find_simplex 方法的幕后到底发生了什么,但我的猜测是,在第一次调用时它会构造一些搜索结构,因为你只使用它一次构造初始化时间是很多查询都没有摊销。

一个简单的解决方案是 运行 线性搜索,而不调用 find_simplex 方法。 由于您的代码每次迭代都会构造一个 Delaunay 三角剖分,因此 运行 时间将由三角剖分构造决定,线性搜索时间可以忽略不计。

这是一个矢量化的 numpy 函数。

 def is_in_triangle(pt, p0, p1, p2):
    """ Check in which of the triangles the point pt lies.
        p0, p1, p2 are arrays of shape (n, 2) representing vertices of triangles,
        pt is of shape (1, 2).
        Assumes p0, p1, p2 are oriented counter clockwise (as in scipy's Delaunay)
    """ 
    vp0 = pt - p0
    v10 = p1 - p0
    cross0 = vp0[:, 0] * v10[:, 1] - vp0[:, 1] * v10[:, 0]  # is pt to left/right of p0->p1

    vp1 = pt - p1
    v21 = p2 - p1
    cross1 = vp1[:, 0] * v21[:, 1] - vp1[:, 1] * v21[:, 0]  # is pt to left/right of p1->p2
    
    vp2 = pt - p2
    v02 = p0 - p2
    cross2 = vp2[:, 0] * v02[:, 1] - vp2[:, 1] * v02[:, 0]  # is pt to left/right of p2->p0

    return (cross0 < 0) & (cross1 < 0) & (cross2 < 0)  # pt should be to the left of all triangle edges

我使用此功能将您的代码修改为 运行:

t = time.time()
for i in range(n_pts):
    delaunay = Delaunay(pts[np.arange(n_pts)!=i])

    p0 = delaunay.points[delaunay.simplices[:, 0], :]
    p1 = delaunay.points[delaunay.simplices[:, 1], :]
    p2 = delaunay.points[delaunay.simplices[:, 2], :]
    pt = pts[i].reshape((1, 2))
    pt_in_simps_mask = is_in_triangle(pt, p0, p1, p2)
    simp_ind_lst = np.where(pt_in_simps_mask)[0]
    if len(simp_ind_lst) == 0:
        simp_ind = -1
    else:
        simp_ind = simp_ind_lst[0]

print("new time: {}".format(time.time()-t))

在我的机器上,当 运行 100 点此代码时 运行 大约需要 0.036 秒,而原始代码为 0.13 秒(完全没有查询的代码为 0.030 秒,只是Delaunay 结构)。