高效的粒子对相互作用计算

Efficient Particle-Pair Interactions Calculation

我有一个 N 体模拟,它为模拟中的多个时间步长生成粒子位置列表。对于给定的帧,我想生成一个粒子索引对列表 (i, j),使得 dist(p[i], p[j]) < masking_radius。本质上,我正在创建一个 "interaction" 对列表,其中这些对彼此之间的距离在一定范围内。我当前的实现看起来像这样:

interaction_pairs = []

# going through each unique pair (order doesn't matter)
for i in range(num_particles):
    for j in range(i + 1, num_particles):
        if dist(p[i], p[j]) < masking_radius:
            interaction_pairs.append((i,j))

由于粒子数量众多,此过程需要很长时间(每次测试 > 1 小时),并且严重限制了我需要对数据执行的操作。我想知道是否有更有效的方法来构建数据,这样计算这些对会更有效,而不是比较每个可能的粒子组合。我正在研究 KDTrees,但我想不出一种方法来利用它们更有效地进行计算。感谢任何帮助,谢谢!

由于您使用的是 python,sklearn 有多个用于查找最近邻的实现: http://scikit-learn.org/stable/modules/neighbors.html

提供了KDTree和Balltree

关于KDTree的要点就是把你所有的粒子都push到KDTree中,然后对每个粒子进行查询:"give me all particles in range X"。 KDtree 通常比暴力搜索更快。 您可以在此处阅读更多示例:https://www.cs.cmu.edu/~ckingsf/bioinfo-lectures/kdtrees.pdf

如果您使用的是 2D 或 3D space,那么另一种选择是将 space 切割成大网格(掩蔽半径的单元格大小)并将每个粒子分配到一个网格单元格中.然后你可以通过检查相邻的细胞来找到可能的相互作用候选者(但你也必须做距离检查,但粒子对要少得多)。

这里有一个相当简单的技术,使用简单 Python 可以减少所需的比较次数。

我们首先沿 X、Y 或 Z 轴(由以下代码中的 axis 选择)对点进行排序。假设我们选择 X 轴。然后我们像您的代码一样遍历点对,但是当我们找到距离大于 masking_radius 的点对时,我们测试它们的 X 坐标的差异是否也大于 masking_radius。如果是,那么我们可以退出内部 j 循环,因为所有具有更大 j 的点都具有更大的 X 坐标。

我的 dist2 函数计算平方距离。这比计算实际距离更快,因为计算平方根相对较慢。

我还包含了与您的代码行为相似的代码,即它会测试每一对点,以进行速度比较;它还用于检查快速代码是否正确。 ;)

from random import seed, uniform
from operator import itemgetter

seed(42)

# Make some fake data
def make_point(hi=10.0):
    return [uniform(-hi, hi) for _ in range(3)]

psize = 1000
points = [make_point() for _ in range(psize)]

masking_radius = 4.0
masking_radius2 = masking_radius ** 2

def dist2(p, q):
    return (p[0] - q[0])**2 + (p[1] - q[1])**2 + (p[2] - q[2])**2

pair_count = 0
test_count = 0

do_fast = 1
if do_fast:
    # Sort the points on one axis
    axis = 0
    points.sort(key=itemgetter(axis))

    # Fast
    for i, p in enumerate(points):
        left, right = i - 1, i + 1
        for j in range(i + 1, psize):
            test_count += 1
            q = points[j]
            if dist2(p, q) < masking_radius2:
                #interaction_pairs.append((i, j))
                pair_count += 1
            elif q[axis] - p[axis] >= masking_radius:
                break

        if i % 100 == 0:
            print('\r {:3} '.format(i), flush=True, end='')

    total_pairs = psize * (psize - 1) // 2
    print('\r {} / {} tests'.format(test_count, total_pairs))

else:
    # Slow
    for i, p in enumerate(points):
        for j in range(i+1, psize):
            q = points[j]
            if dist2(p, q) < masking_radius2:
                #interaction_pairs.append((i, j))
                pair_count += 1

        if i % 100 == 0:
            print('\r {:3} '.format(i), flush=True, end='')

print('\n', pair_count, 'pairs')

输出do_fast = 1

 181937 / 499500 tests

 13295 pairs

输出do_fast = 0

 13295 pairs

当然,如果大多数点对都在彼此的 masking_radius 范围内,则使用此技术不会有太大好处。对点进行排序会增加一点时间,但是 Python 的 TimSort 相当高效,尤其是在数据已经部分排序的情况下,因此如果 masking_radius 足够小,您应该会看到明显的改进在速度。