我怎样才能加快我写的 python 代码:使用空间搜索的球体接触检测(碰撞)

How could I speed up my written python code: spheres contact detection (collision) using spatial searching

我正在研究一个球体的空间搜索案例,我想在其中找到相连的球体。为此,我在每个球体周围搜索了中心与搜索球体中心距离为(最大球体直径)的球体。起初,我尝试使用 scipy 相关方法来这样做,但是 scipy 方法与等效的 numpy 方法相比需要更长的时间。对于scipy,我先确定了K-nearest spheres的个数,然后通过cKDTree.query找到它们,这导致了更多的时间消耗。但是,即使省略具有常量值的第一步,它也比 numpy 方法慢(在这种情况下省略第一步是不好的)。 这与我对 scipy 空间搜索速度的预期相反。 所以,我尝试使用一些列表循环而不是一些 numpy 行用于加速使用 numba prange。 Numba 运行 代码快一点,但我相信可以优化此代码以获得更好的性能,也许通过矢量化,使用其他替代 numpy 模块或以其他方式使用 numba。由于防止可能的内存泄漏和……,我在所有球体上都使用了迭代,其中球体数量很多。

import numpy as np
import numba as nb
from scipy.spatial import cKDTree, distance

# ---------------------------- input data ----------------------------
""" For testing by prepared files:
radii = np.load('a.npy')     # shape: (n-spheres, )     must be loaded by np.load('a.npy') or np.loadtxt('radii_large.csv')
poss = np.load('b.npy')      # shape: (n-spheres, 3)    must be loaded by np.load('b.npy') or np.loadtxt('pos_large.csv', delimiter=',')
"""

rnd = np.random.RandomState(70)
data_volume = 200000

radii = rnd.uniform(0.0005, 0.122, data_volume)
dia_max = 2 * radii.max()

x = rnd.uniform(-1.02, 1.02, (data_volume, 1))
y = rnd.uniform(-3.52, 3.52, (data_volume, 1))
z = rnd.uniform(-1.02, -0.575, (data_volume, 1))
poss = np.hstack((x, y, z))
# --------------------------------------------------------------------

# @nb.jit('float64[:,::1](float64[:,::1], float64[::1])', forceobj=True, parallel=True)
def ends_gap(poss, dia_max):
    particle_corsp_overlaps = np.array([], dtype=np.float64)
    ends_ind = np.empty([1, 2], dtype=np.int64)
    """ using list looping """
    # particle_corsp_overlaps = []
    # ends_ind = []

    # for particle_idx in nb.prange(len(poss)):  # by list looping
    for particle_idx in range(len(poss)):
        unshared_idx = np.delete(np.arange(len(poss)), particle_idx)                                                    # <--- relatively high time consumer
        poss_without = poss[unshared_idx]

        """ # SCIPY method ---------------------------------------------------------------------------------------------
        nears_i_ind = cKDTree(poss_without).query_ball_point(poss[particle_idx], r=dia_max)         # <--- high time consumer
        if len(nears_i_ind) > 0:
            dist_i, dist_i_ind = cKDTree(poss_without[nears_i_ind]).query(poss[particle_idx], k=len(nears_i_ind))       # <--- high time consumer
            if not isinstance(dist_i, float):
                dist_i[dist_i_ind] = dist_i.copy()
        """  # NUMPY method --------------------------------------------------------------------------------------------
        lx_limit_idx = poss_without[:, 0] <= poss[particle_idx][0] + dia_max
        ux_limit_idx = poss_without[:, 0] >= poss[particle_idx][0] - dia_max
        ly_limit_idx = poss_without[:, 1] <= poss[particle_idx][1] + dia_max
        uy_limit_idx = poss_without[:, 1] >= poss[particle_idx][1] - dia_max
        lz_limit_idx = poss_without[:, 2] <= poss[particle_idx][2] + dia_max
        uz_limit_idx = poss_without[:, 2] >= poss[particle_idx][2] - dia_max

        nears_i_ind = np.where(lx_limit_idx & ux_limit_idx & ly_limit_idx & uy_limit_idx & lz_limit_idx & uz_limit_idx)[0]
        if len(nears_i_ind) > 0:
            dist_i = distance.cdist(poss_without[nears_i_ind], poss[particle_idx][None, :]).squeeze()                   # <--- relatively high time consumer
        # """  # -------------------------------------------------------------------------------------------------------
            contact_check = dist_i - (radii[unshared_idx][nears_i_ind] + radii[particle_idx])
            connected = contact_check[contact_check <= 0]

            particle_corsp_overlaps = np.concatenate((particle_corsp_overlaps, connected))
            """ using list looping """
            # if len(connected) > 0:
            #    for value_ in connected:
            #        particle_corsp_overlaps.append(value_)

            contacts_ind = np.where([contact_check <= 0])[1]
            contacts_sec_ind = np.array(nears_i_ind)[contacts_ind]
            sphere_olps_ind = np.where((poss[:, None] == poss_without[contacts_sec_ind][None, :]).all(axis=2))[0]       # <--- high time consumer

            ends_ind_mod_temp = np.array([np.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T
            if particle_idx > 0:
                ends_ind = np.concatenate((ends_ind, ends_ind_mod_temp))
            else:
                ends_ind[0, 0], ends_ind[0, 1] = ends_ind_mod_temp[0, 0], ends_ind_mod_temp[0, 1]
            """ using list looping """
            # for contacted_idx in sphere_olps_ind:
            #    ends_ind.append([particle_idx, contacted_idx])

    # ends_ind_org = np.array(ends_ind)  # using lists
    ends_ind_org = ends_ind
    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)                                # <--- relatively high time consumer
    gap = np.array(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org

在我对 23000 个球体进行的一项测试中,scipy、numpy 和 numba 辅助方法分别使用 Colab TPU 在大约 400、200 和 180 秒内完成了循环; 500.000 个球体需要 3.5 小时。这些执行时间对我的项目来说一点都不令人满意,在中等数据量中,球体的数量可能高达 1.000.000。我将在我的主代码中多次调用此代码,并寻找可以在 毫秒 内执行此代码的方法(尽可能快)。可能吗?? 如果有人能根据需要加快代码速度,我将不胜感激。

备注:


如有任何建议或解释,我将不胜感激:

  1. 在这个问题上哪种方法更快?
  2. 在这种情况下,为什么 scipy 并不比其他方法快?它对这个主题有什么帮助?
  3. 在迭代器方法和矩阵形式方法之间进行选择对我来说是一件令人困惑的事情。迭代方法使用更少的内存,并且可以由 numba 和……使用和调整,但是,我认为,对于像 numpy 和……这样的矩阵方法(这取决于内存限制),它没有用,也不能与矩阵方法相提并论……对于巨大的球体数。对于这种情况,也许我可以省略 numpy 的迭代,但我强烈猜测,由于巨大的矩阵大小操作和内存泄漏,它无法处理。

准备样本测试数据:

持有数据: 23000, 500000
半径数据: 23000, 500000
逐行速度测试日志:两个测试用例scipy method and numpy耗时

你试过了吗FLANN

此代码不能完全解决您的问题。它只是在您的 500000 点数据集中找到与每个点最近的 50 个邻居:

from pyflann import FLANN

p = np.loadtxt("pos_large.csv", delimiter=",")
flann = FLANN()
flann.build_index(pts=p)
idx, dist = flann.nn_index(qpts=p, num_neighbors=50)

最后一行在我的笔记本电脑上用了不到一秒,没有任何调整或并行化。

更新: 这个 post 的回答现在被 取代 (考虑到问题的更新)基于不同的方法提供更快的代码。


第 1 步:更好的算法

首先,构建 k-d 树需要 O(n log n) 时间,查询需要 O(log n) 时间,其中 n 是点数。所以乍一看,使用 k-d 树似乎是个好主意。但是,您的代码 为每个点 构建一个 k-d 树,导致 O(n² log n) 时间。这就是 Scipy 解决方案比其他解决方案慢的原因。问题是 Scipy 没有提供更新 k-d 树的方法。原来是updating efficiently a k-d tree appears not to be possible。希望这对您的情况不是问题:您可以 构建一棵包含所有点的 k-d 树,然后丢弃您不想出现的当前点 每个查询的结果。

此外,sphere_olps_ind 的计算在 O(n² m) 时间内运行,其中 n 是点的总数,m 是邻居的平均数(即。从 k-d 树查询中检索到的最近点)。假设没有重复点,那么 sphere_olps_ind 就等于 np.sort(contacts_sec_ind)。后者在 O(m log m) 中运行,这要好得多。

此外,在循环中使用 np.concatenate 在 Numpy 数组中附加值很慢,因为它会为每次迭代创建一个新的更大的数组。使用列表是个好主意,但是 在列表中直接附加 Numpy 数组然后调用 np.concatenate 一次要快得多.

这是结果代码:

def ends_gap(poss, dia_max):
    particle_corsp_overlaps = []
    ends_ind = [np.empty([1, 2], dtype=np.int64)]

    kdtree = cKDTree(poss)

    for particle_idx in range(len(poss)):
        # Find the nearest point including the current one and
        # then remove the current point from the output.
        # The distances can be computed directly without a new query.
        cur_point = poss[particle_idx]
        nears_i_ind = np.array(kdtree.query_ball_point(cur_point, r=dia_max), dtype=np.int64)
        assert len(nears_i_ind) > 0

        if len(nears_i_ind) <= 1:
            continue

        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]
        dist_i = distance.cdist(poss[nears_i_ind], cur_point[None, :]).squeeze()

        contact_check = dist_i - (radii[nears_i_ind] + radii[particle_idx])
        connected = contact_check[contact_check <= 0]

        particle_corsp_overlaps.append(connected)

        contacts_ind = np.where([contact_check <= 0])[1]
        contacts_sec_ind = nears_i_ind[contacts_ind]
        sphere_olps_ind = np.sort(contacts_sec_ind)

        ends_ind_mod_temp = np.array([np.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T
        if particle_idx > 0:
            ends_ind.append(ends_ind_mod_temp)
        else:
            ends_ind[0][:] = ends_ind_mod_temp[0, 0], ends_ind_mod_temp[0, 1]

    ends_ind_org = np.concatenate(ends_ind)
    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)                                # <--- relatively high time consumer
    gap = np.concatenate(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org

第 2 步:优化

首先,query_ball_point 调用可以在 并行 中通过向 [=117= 提供 poss 同时在所有点上完成] 方法并指定参数 workers=-1。但是,请注意,这需要更多内存。

此外,Numba 可用于显着加快计算速度。主要可以改进的部分是距离的计算和许多不必要的临时数组的创建以及[=76的使用=]Numpy 数组直接索引 而不是列表的追加(因为在 query_ball_point 调用后可以知道输出数组的有界大小)。

这是一个使用 Numba 优化代码的简单示例:

@nb.jit('(float64[:, ::1], int64[::1], int64[::1], float64)')
def compute(poss, all_neighbours, all_neighbours_sizes, dia_max):
    particle_corsp_overlaps = []
    ends_ind_lst = [np.empty((1, 2), dtype=np.int64)]
    an_offset = 0

    for particle_idx in range(len(poss)):
        cur_point = poss[particle_idx]
        cur_len = all_neighbours_sizes[particle_idx]
        nears_i_ind = all_neighbours[an_offset:an_offset+cur_len]
        an_offset += cur_len
        assert len(nears_i_ind) > 0

        if len(nears_i_ind) <= 1:
            continue

        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]
        dist_i = np.empty(len(nears_i_ind), dtype=np.float64)

        # Compute the distances
        x1, y1, z1 = poss[particle_idx]
        for i in range(len(nears_i_ind)):
            x2, y2, z2 = poss[nears_i_ind[i]]
            dist_i[i] = np.sqrt((x2-x1)**2 + (y2-y1)**2 + (z2-z1)**2)

        contact_check = dist_i - (radii[nears_i_ind] + radii[particle_idx])
        connected = contact_check[contact_check <= 0]

        particle_corsp_overlaps.append(connected)

        contacts_ind = np.where(contact_check <= 0)
        contacts_sec_ind = nears_i_ind[contacts_ind]
        sphere_olps_ind = np.sort(contacts_sec_ind)

        ends_ind_mod_temp = np.empty((len(sphere_olps_ind), 2), dtype=np.int64)
        for i in range(len(sphere_olps_ind)):
            ends_ind_mod_temp[i, 0] = particle_idx
            ends_ind_mod_temp[i, 1] = sphere_olps_ind[i]

        if particle_idx > 0:
            ends_ind_lst.append(ends_ind_mod_temp)
        else:
            tmp = ends_ind_lst[0]
            tmp[:] = ends_ind_mod_temp[0, :]

    return particle_corsp_overlaps, ends_ind_lst

def ends_gap(poss, dia_max):
    kdtree = cKDTree(poss)
    tmp = kdtree.query_ball_point(poss, r=dia_max, workers=-1)
    all_neighbours = np.concatenate(tmp, dtype=np.int64)
    all_neighbours_sizes = np.array([len(e) for e in tmp], dtype=np.int64)
    particle_corsp_overlaps, ends_ind_lst = compute(poss, all_neighbours, all_neighbours_sizes, dia_max)
    ends_ind_org = np.concatenate(ends_ind_lst)
    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)
    gap = np.concatenate(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org

ends_gap(poss, dia_max)

性能分析

以下是我的 6 核机器(配备 i5-9600KF 处理器)在小型数据集上的性能结果:

Initial code with Scipy:             259 s
Initial default code with Numpy:     112 s
Optimized algorithm:                   1.37 s
Final optimized code:                  0.22 s

不幸的是,Scipy k-d 树 太大而无法放入内存 我机器上的大数据集。

因此,采用高效算法的 Numba 实现比初始 Numpy 实现快 ~510 倍,比初始 Scipy 实现快约 1200 倍。

Numba 代码可以进一步优化,但请注意,Numba compute 调用在我的机器上花费的时间不到 25%。 np.unique 调用是最昂贵的,但要使其更快并不容易。很大一部分时间花在了Scipy-to-Numba数据转换上,但只要使用Scipy,这段代码是必须的。因此,可以通过高级 Numba 优化改进代码(例如肯定快 2 倍),但如果您需要更快的代码,则需要 使用像 C++ 这样的本地语言 和highly-optimized 并行 k-d 树实现。我希望 very-optimized 本机代码的速度快一个数量级,但不会更多。无论实现如何,我都不敢相信在我的机器上可以在不到 10 毫秒的时间内计算出大数据集。


备注

注意gap与提供的函数不同(其他值保持不变)。然而,同样的事情发生在最初的 Scipy 方法和 Numpy 的方法之间。这似乎来自 nears_i_inddist_i 等变量的排序,这些变量未被 Scipy 定义,并以 non-trivial 方式更改 gap 结果(不仅仅是gap 的顺序)。我不确定这是初始实施的问题。因此,比较不同实现的正确性要困难得多。

forceobj 不应在生产中使用,因为文档指出这仅用于测试目的。

作为与这些不同性能相关的 and to overcome probable memory leaks, I post this answer. During my testing executions, memory usage grows up and limits the execution to some smaller data volumes (maximum 200000 by my machine and 100000 on COLAB). This problem leads to much longer runtimes than resulted runtimes by Richard. So, I opened a SciPy issue 的更新,并将一些记忆结果放在那里并进行比较。
但是,到目前为止我还没有得到任何答案,我还不清楚性能之间这些显着差异的来源!!??

Fezzani referred to another SciPy issue 使用 chunk 并准备好比较以显示 chunk 值的影响 在运行时。 St运行gely,尽管 Fezzani 的机器(Intel® Core™ i7-10700K CPU @ 3.80GHz × 16;32GiB RAM)似乎比Richard 的机器(6 核机器,i5-9600KF 处理器,16 GiB RAM 2 通道 DDR4 @ 3200MHz 达到 36~40 GiB/s),他在大数据上的执行chunk 方法至少(大约)33 秒(以避免内存泄漏)。
我无法弄清楚为什么以及哪些硬件可以帮助机器通过内存泄漏并导致像理查德那样令人满意的快速执行(也许它与 KF理查德的类型 CPU) !!??


通过在一些 related memory issues, I could guess cKDTree methods are facing this inevitable problem when data volume is huge or … and scikit-learn, perhaps, be a better choice. In this regard, based on my understanding from JaminSore answer and the referred Martelli answer 中寻找,我试图从 scikit-learn 中评估 BallTreeKDTreeBallTree 在我的案例中比 KDTree 有更好的性能(大约 1.5 到 2 倍),所以我使用它。大数据没有内存泄漏,但需要 2 分钟 (Richard 结果和我的结果现在只是时间单位不同 ;))。当数据量增加时,它 运行 比 scipy 快。在我的测试中,scipy 在较小的数据量(低内存消耗)上更快,并且随着数据量的增长,scipy 性能由于其实现行为或相关错误而落后(我还不清楚);对于我准备的 100000 个数据量,scikit-learn 执行速度快 1.5 到 2 倍。

我想使用数组是 scikit-learn 与 scipy 方法的列表相比的一大优势,后者可以从上述 Martelli answer 中导出.可能是性能不同的原因。

scikit-learn 方法 return 一个 object 类型 ndarray 里面有不同长度的数组需要进行排序以获得与主代码相同的结果。我通过将nears_i_indcode-line修改为nears_i_ind = np.sort(all_neighbours[an_offset:an_offset+cur_len]),在compute函数中应用了循环中每个元素的相关排序行为。使用 BallTreetmpall_neighbours消耗内存差不多。 注意:如果两者同名,内存消耗会减少(几乎减半)。因此,BallTree 修改后的 Richard 的 ends_gap 函数将是:

def ends_gap(poss, dia_max):
    balltree = BallTree(poss, metric='euclidean')

    # tmp = balltree.query_radius(poss, r=dia_max)
    # all_neighbours = np.concatenate(tmp, dtype=np.int64)
    all_neighbours = balltree.query_radius(poss, r=dia_max)
    all_neighbours_sizes = np.array([len(e) for e in all_neighbours], dtype=np.int64)
    all_neighbours = np.concatenate(all_neighbours, dtype=np.int64)

    particle_corsp_overlaps, ends_ind_lst = compute(poss, all_neighbours, all_neighbours_sizes)
    ends_ind_org = np.concatenate(ends_ind_lst)
    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)
    gap = np.concatenate(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org

不是multi-threaded,可以提高速度;我会努力multi-thread.
在我的机器上(i5 第一代 cpu intel core 760 @ 2.8GHz,16gb ram cl9 双通道 DDR3 ripjaws,x64 windows 系统)200000 数据量:


我提出的两种方法存在一些错误,导致不同的间隙值,Richard 在注释部分中提到了这一点。为了产生相同的结果,必须为 中的 nears_i_ind 添加 return_sorted=True 优化算法 并且 ends_indends_ind_lst 更改为 list 除了删除两个代码中的 if-else 语句:

优化算法:

def ends_gap(poss, dia_max):
    particle_corsp_overlaps = []
    ends_ind = []                                                                       # <------- this line is modified

    kdtree = cKDTree(poss)

    for particle_idx in range(len(poss)):
        cur_point = poss[particle_idx]
        nears_i_ind = np.array(kdtree.query_ball_point(cur_point, r=dia_max, return_sorted=True), dtype=np.int64)       # <------- this line is modified
        assert len(nears_i_ind) > 0

        if len(nears_i_ind) <= 1:
            continue

        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]
        dist_i = distance.cdist(poss[nears_i_ind], cur_point[None, :]).squeeze()

        contact_check = dist_i - (radii[nears_i_ind] + radii[particle_idx])
        connected = contact_check[contact_check <= 0]

        particle_corsp_overlaps.append(connected)

        contacts_ind = np.where([contact_check <= 0])[1]
        contacts_sec_ind = nears_i_ind[contacts_ind]
        sphere_olps_ind = np.sort(contacts_sec_ind)

        ends_ind_mod_temp = np.array([np.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T
        ends_ind.append(ends_ind_mod_temp)                                              # <------- this line is modified

    ends_ind_org = np.concatenate(ends_ind)
    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)
    gap = np.concatenate(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org

Numba 最终优化代码:

@nb.jit('(float64[:, ::1], int64[::1], int64[::1])')
def compute(poss, all_neighbours, all_neighbours_sizes):
    particle_corsp_overlaps = []
    ends_ind_lst = []                                                                   # <------- this line is modified
    an_offset = 0

    for particle_idx in range(len(poss)):
        cur_len = all_neighbours_sizes[particle_idx]
        nears_i_ind = np.sort(all_neighbours[an_offset:an_offset+cur_len])              # <------- this line is modified
        an_offset += cur_len
        assert len(nears_i_ind) > 0

        if len(nears_i_ind) <= 1:
            continue

        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]
        dist_i = np.empty(len(nears_i_ind), dtype=np.float64)

        x1, y1, z1 = poss[particle_idx]
        for i in range(len(nears_i_ind)):
            x2, y2, z2 = poss[nears_i_ind[i]]
            dist_i[i] = np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2 + (z2 - z1) ** 2)

        contact_check = dist_i - (radii[nears_i_ind] + radii[particle_idx])
        connected = contact_check[contact_check <= 0]
        particle_corsp_overlaps.append(connected)

        contacts_ind = np.where(contact_check <= 0)
        contacts_sec_ind = nears_i_ind[contacts_ind]
        sphere_olps_ind = np.sort(contacts_sec_ind)

        ends_ind_mod_temp = np.empty((len(sphere_olps_ind), 2), dtype=np.int64)
        for i in range(len(sphere_olps_ind)):
            ends_ind_mod_temp[i, 0] = particle_idx
            ends_ind_mod_temp[i, 1] = sphere_olps_ind[i]
        ends_ind_lst.append(ends_ind_mod_temp)                                          # <------- this line is modified

    return particle_corsp_overlaps, ends_ind_lst


def ends_gap(poss, dia_max):
    balltree = BallTree(poss, metric='euclidean')                                       # <------- new code
    all_neighbours = balltree.query_radius(poss, r=dia_max)                             # <------- new code and modified
    all_neighbours_sizes = np.array([len(e) for e in all_neighbours], dtype=np.int64)   # <------- this line is modified
    all_neighbours = np.concatenate(all_neighbours, dtype=np.int64)                     # <------- this line is modified
    particle_corsp_overlaps, ends_ind_lst = compute(poss, all_neighbours, all_neighbours_sizes)
    ends_ind_org = np.concatenate(ends_ind_lst)
    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)
    gap = np.concatenate(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org

在我的机器上大约有 550000 个数据量:

通过将查询半径固定为最大球体半径的两倍,您将创建大量需要过滤掉的虚假“碰撞”。

下面的 Python 通过使用第四个维度来提高 kd-tree 查询的选择性,相对于您的答案实现了显着的加速。每个半径为 r 的欧几里德球都是 over-approximated 由一个半径为 r√d 的 L1 球组成,其中 d 是维度(此处为 3)。 3d 中 L1 球碰撞的测试变成了 4d 中点在固定 L1 距离内的测试。

如果您切换到较低级别的语言,您可以通过更改 kd-tree 实现以使用 L2+L1 度量组合来避免单独的过滤步骤。

import numpy as np
from scipy import spatial
from timeit import default_timer


def load_data():
    centers = np.loadtxt("pos_large.csv", delimiter=",")
    radii = np.loadtxt("radii_large.csv")
    assert radii.shape + (3,) == centers.shape
    return centers, radii


def count_contacts(centers, radii):
    scaled_centers = centers / np.sqrt(centers.shape[1])
    max_radius = radii.max()
    tree = spatial.cKDTree(np.c_[scaled_centers, max_radius - radii])
    count = 0
    for i, x in enumerate(np.c_[scaled_centers, radii - max_radius]):
        for j in tree.query_ball_point(x, r=2 * max_radius, p=1):
            d = centers[i] - centers[j]
            r = radii[i] + radii[j]
            if i < j and np.inner(d, d) <= r * r:
                count += 1
    return count


def main():
    centers, radii = load_data()
    start = default_timer()
    print(count_contacts(centers, radii))
    end = default_timer()
    print(end - start)


if __name__ == "__main__":
    main()

根据之前的答案,我设计了一个高效的算法,与之前的算法相比更少的内存占用更快(尤其是在大数据集上)。话虽这么说,该算法非常复杂,并突破了 Python 和 Numba 的极限。

之前算法的关键问题是他们设置了一个dia_max的阈值,这个阈值比实际需要的大很多。实际上,dia_max 设置为最大可能的 redius,以确保不会错过任何重叠。问题是大数据集包含大小非常不同的球,其中一些球很大。这意味着以前的算法是在许多小球周围获取非常大的半径。 结果是每个球有数千个邻居要检查,而只有少数可以真正重叠

有效解决此问题的一种解决方案是 根据大小 将球分成不同的组。这个想法是首先根据 radii 对球进行排序,然后将排序后的球分成两组,然后独立查询每对可能的组之间的邻居,然后合并数据以应用之前的算法(有一些额外的优化) .更具体地说,查询适用于小球与大球、小球与其他小球、大球与其他大球以及大球与小球之间的查询。

另一个加快速度的关键点是使用 joblib 并行请求不同的邻居查询。这个解决方案远非完美,因为 BallTree 对象需要被复制,这是低效的,但这是强制性的,因为并行性目前在 CPython 中完成(即 GIL,pickling 等)。 ).使用支持并行请求的包可以绕过 CPython 的这种固有限制,但这样做的现有包似乎没有提供足够有用的接口来解决这个问题,或者没有优化到实际有用。

最后,可以通过删除几乎所有 非常昂贵的(隐式)数组分配 来强烈优化 Numba 代码。使用针对小数组优化的 in-place 排序算法 也可以显着缩短执行时间(主要是因为 Numba 的默认实现执行多个昂贵的分配并且没有针对小数组进行优化)。此外,最终的 np.unique 操作可以用一个基本循环完全重写,因为主循环迭代 ID 递增的球(因此已经排序)。

这是结果代码:

import numpy as np
import numba as nb
from sklearn.neighbors import BallTree
from joblib import Parallel, delayed

def flatten_neighbours(arr):
    sizes = np.fromiter(map(len, arr), count=len(arr), dtype=np.int64)
    values = np.concatenate(arr, dtype=np.int64)
    return sizes, values

@delayed
def find_neighbours(searched_pts, ref_pts, max_dist):
    balltree = BallTree(ref_pts, leaf_size=16, metric='euclidean')
    res = balltree.query_radius(searched_pts, r=max_dist)
    return flatten_neighbours(res)

def vstack_neighbours(top_infos, bottom_infos):
    top_sizes, top_values = top_infos
    bottom_sizes, bottom_values = bottom_infos
    return np.concatenate([top_sizes, bottom_sizes]), np.concatenate([top_values, bottom_values])

@nb.njit('(Tuple([int64[::1],int64[::1]]), Tuple([int64[::1],int64[::1]]), int64)')
def hstack_neighbours(left_infos, right_infos, offset):
    left_sizes, left_values = left_infos
    right_sizes, right_values = right_infos
    n = left_sizes.size
    out_sizes = np.empty(n, dtype=np.int64)
    out_values = np.empty(left_values.size + right_values.size, dtype=np.int64)
    left_cur, right_cur, out_cur = 0, 0, 0
    right_values += offset
    for i in range(n):
        left, right = left_sizes[i], right_sizes[i]
        full = left + right
        out_values[out_cur:out_cur+left] = left_values[left_cur:left_cur+left]
        out_values[out_cur+left:out_cur+full] = right_values[right_cur:right_cur+right]
        out_sizes[i] = full
        left_cur += left
        right_cur += right
        out_cur += full
    return out_sizes, out_values

@nb.njit('(int64[::1], int64[::1], int64[::1], int64[::1])')
def reorder_neighbours(in_sizes, in_values, index, reverse_index):
    n = reverse_index.size
    out_sizes = np.empty_like(in_sizes)
    out_values = np.empty_like(in_values)
    in_offsets = np.empty_like(in_sizes)
    s, cur = 0, 0

    for i in range(n):
        in_offsets[i] = s
        s += in_sizes[i]

    for i in range(n):
        in_ind = reverse_index[i]
        size = in_sizes[in_ind]
        in_offset = in_offsets[in_ind]
        out_sizes[i] = size
        for j in range(size):
            out_values[cur+j] = index[in_values[in_offset+j]]
        cur += size

    return out_sizes, out_values

@nb.njit
def small_inplace_sort(arr):
    if len(arr) < 80:
        # Basic insertion sort
        i = 1
        while i < len(arr):
            x = arr[i]
            j = i - 1
            while j >= 0 and arr[j] > x:
                arr[j+1] = arr[j]
                j = j - 1
            arr[j+1] = x
            i += 1
    else:
        arr.sort()

@nb.jit('(float64[:, ::1], float64[::1], int64[::1], int64[::1])')
def compute(poss, radii, neighbours_sizes, neighbours_values):
    n, m = neighbours_sizes.size, np.max(neighbours_sizes)

    # Big buffers allocated with the maximum size.
    # Thank to virtual memory, it does not take more memory can actually needed.
    particle_corsp_overlaps = np.empty(neighbours_values.size, dtype=np.float64)
    ends_ind_org = np.empty((neighbours_values.size, 2), dtype=np.float64)

    in_offset = 0
    out_offset = 0

    buff1 = np.empty(m, dtype=np.int64)
    buff2 = np.empty(m, dtype=np.float64)
    buff3 = np.empty(m, dtype=np.float64)

    for particle_idx in range(n):
        size = neighbours_sizes[particle_idx]
        cur = 0

        for i in range(size):
            value = neighbours_values[in_offset+i]
            if value != particle_idx:
                buff1[cur] = value
                cur += 1

        nears_i_ind = buff1[0:cur]
        small_inplace_sort(nears_i_ind)  # Note: bottleneck of this function
        in_offset += size

        if len(nears_i_ind) == 0:
            continue

        x1, y1, z1 = poss[particle_idx]
        cur = 0

        for i in range(len(nears_i_ind)):
            index = nears_i_ind[i]
            x2, y2, z2 = poss[index]
            dist = np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2 + (z2 - z1) ** 2)
            contact_check = dist - (radii[index] + radii[particle_idx])
            if contact_check <= 0.0:
                buff2[cur] = contact_check
                buff3[cur] = index
                cur += 1

        particle_corsp_overlaps[out_offset:out_offset+cur] = buff2[0:cur]

        contacts_sec_ind = buff3[0:cur]
        small_inplace_sort(contacts_sec_ind)
        sphere_olps_ind = contacts_sec_ind

        for i in range(cur):
            ends_ind_org[out_offset+i, 0] = particle_idx
            ends_ind_org[out_offset+i, 1] = sphere_olps_ind[i]

        out_offset += cur

    # Truncate the views to their real size
    particle_corsp_overlaps = particle_corsp_overlaps[:out_offset]
    ends_ind_org = ends_ind_org[:out_offset]

    assert len(ends_ind_org) % 2 == 0
    size = len(ends_ind_org)//2
    ends_ind = np.empty((size,2), dtype=np.int64)
    ends_ind_idx = np.empty(size, dtype=np.int64)
    gap = np.empty(size, dtype=np.float64)
    cur = 0

    # Find efficiently duplicates (replace np.unique+np.sort)
    for i in range(len(ends_ind_org)):
        left, right = ends_ind_org[i]
        if left < right:
            ends_ind[cur, 0] = left
            ends_ind[cur, 1] = right
            ends_ind_idx[cur] = i
            gap[cur] = particle_corsp_overlaps[i]
            cur += 1

    return gap, ends_ind, ends_ind_idx, ends_ind_org

def ends_gap(poss, radii):
    assert poss.size >= 1

    # Sort the balls
    index = np.argsort(radii)
    reverse_index = np.empty(index.size, np.int64)
    reverse_index[index] = np.arange(index.size, dtype=np.int64)
    sorted_poss = poss[index]
    sorted_radii = radii[index]

    # Split them in two groups: the small and the big ones
    split_ind = len(radii) * 3 // 4
    small_poss, big_poss = np.split(sorted_poss, [split_ind])
    small_radii, big_radii = np.split(sorted_radii, [split_ind])
    max_small_radii = sorted_radii[max(split_ind, 0)]
    max_big_radii = sorted_radii[-1]

    # Find the neighbours in parallel
    result = Parallel(n_jobs=4, backend='threading')([
        find_neighbours(small_poss, small_poss, small_radii+max_small_radii),
        find_neighbours(small_poss, big_poss,   small_radii+max_big_radii  ),
        find_neighbours(big_poss,   small_poss, big_radii+max_small_radii  ),
        find_neighbours(big_poss,   big_poss,   big_radii+max_big_radii    )
    ])
    small_small_neighbours = result[0]
    small_big_neighbours = result[1]
    big_small_neighbours = result[2]
    big_big_neighbours = result[3]

    # Merge the (segmented) arrays in a big one
    neighbours_sizes, neighbours_values = vstack_neighbours(
        hstack_neighbours(small_small_neighbours, small_big_neighbours, split_ind),
        hstack_neighbours(big_small_neighbours, big_big_neighbours, split_ind)
    )

    # Reverse the indices.
    # Note that the results in `neighbours_values` associated to 
    # `neighbours_sizes[i]` are subsets of `query_radius([poss[i]], r=dia_max)`
    # on a `BallTree(poss)`.
    res = reorder_neighbours(neighbours_sizes, neighbours_values, index, reverse_index)
    neighbours_sizes, neighbours_values = res

    # Finally compute the neighbours with a method similar to the 
    # previous one, but using a much faster optimized code.
    return compute(poss, radii, neighbours_sizes, neighbours_values)

result = ends_gap(poss, radii)

这是结果(仍然在同一台 i5-9600KF 机器上):

Small dataset:
 - Reference optimized Numba code:    256 ms
 - This highly-optimized Numba code:   82 ms

Big dataset:
 - Reference optimized Numba code:    42.7 s  (take about 7~8 GiB of RAM)
 - This highly-optimized Numba code:   4.2 s  (take about  1  GiB of RAM)

因此,新算法在小数据集上的速度提高了约 3.1 倍(加上之前的优化),在大数据集上的速度提高了约 10 倍!这比最初发布的算法快 3 个数量级。

请注意,80% 的时间花在 BallTree 查询上(这已经大部分是并行的)。主要的 Numba 计算功能仅占用 12% 的时间,超过 75% 的时间用于对输入索引进行排序。因此,邻域搜索显然是瓶颈。可以通过将当前查询拆分为多个较小的查询来稍微改进,但这会使代码更加复杂,以实现相对较小的改进(例如,快 1.5 倍)。请注意,更复杂的代码更难维护,修改 bug-prone。因此,我认为转向母语以克服 Python 的限制是提高性能的最佳解决方案。话虽如此,编写更快的本机代码来解决这个问题远非易事(除非你找到好的 k-d 树、八叉树或球树库)。尽管如此,它肯定比进一步优化这段代码要好。


分析

分析表明,在 scikit-learn 的 BallTree 中,至少有 50% 的时间花费在未优化的标量循环中,这些循环可以使用像 AVX-2(和循环展开)这样的 SIMD 指令,大约是 4 倍快点。此外,一些 multi-threading 问题也是可见的(顶部的 4 个线程是 joblib worker,light-green 部分是空闲时间):

这说明这个实现是sub-optimal。轻松缩短执行时间的一种可能方法是优化 scikit-learn BallTree 实现的热循环。另一种策略可能是尝试更有效地使用线程(可能通过在 scikit-learn 模块的某些部分释放 GIL)。

由于scikit-learn的BallTreeclass是written in CythonBallTree是基于DKTree本身基于BinaryTree)。您可以尝试在您的机器上重建包并 简单地调整编译器优化 。使用参数 -O3 -march=native -ffast-math 应该使编译器能够使用更快的 SIMD 指令和更积极的优化,从而显着加快速度。请注意,使用 -ffast-mathunsae 因为它假定 Scikit 的代码永远不会使用 NaNInf-0 值(否则结果完全未定义)并且 floating-point 数字操作是关联的(导致不同的结果)。也就是说,这样的选项对于改进数字代码的自动矢量化至关重要。

对于GIL,可以看出它是在query_radius函数中释放的,但似乎并不是BallTree的构造函数。也许,最简单的解决方案是像 Scipy 那样实现 query/query_radius 的并行版本。