使用 pairwise_distances_chunked 计算最近邻搜索

Using pairwise_distances_chunked to compute nearest neighbor search

我有一个细长的数据矩阵(大小:250,000 x 10),我将其表示为 X。我还有一个矢量 p 来测量我的数据点的质量。我的目标是为数据矩阵 X 中的每一行 x 计算以下函数:

r(x) = 分钟{ ||x-y|| | p[y]>p[x], X 中的 y }

在较小的数据集上,我将使用 sklearn.metrics.pairwise_distances 来预先计算距离,如下所示:

from sklearn import metrics
n = len(X);

D_full = metrics.pairwise_distances(X);
r = np.zeros((n,1));
for i in range(n):
    r[i] = (D_full[i,p>p[i]]).min();

但是,上述方法占用大量内存,因为我需要存储 D_full:一个完整​​的 n x n 矩阵。看起来 sklearn.metrics.pairwise_distances_chunked 可能是解决这类问题的好工具,因为距离矩阵一次只存储一个块。我希望在如何使用它方面得到一些帮助,因为我目前不熟悉生成器对象。假设我调用以下内容:

from sklearn import metrics
D = metrics.pairwise_distances_chunked(X);
D_chunk = next(D)

以上代码生成 D(一个生成器对象)和 D_chunk(一个 536 x n 数组)。 D_chunk 是否对应于我之前方法中矩阵 D_full 的前 536 行?如果是,next(D_chunk) 是否对应接下来的 536 行?

非常感谢您的帮助。

这是一个可能的解决方案的概要,但缺少详细信息。简而言之,我会执行以下操作:

创建一个 BallTree 来查询,并初始化大小为 250000 的 min_quality_distance,比如用零。

对于k=2

  1. 对于每个向量,找到最近的 k 个邻居(包括它自己)。
  2. 如果找到最远距离在 k 范围内的矢量,并且质量足够,则更新该点的 min_quality_distance
  3. 剩下的,重复k=k+1

在每次迭代中,我们必须查询更少的向量。这个想法是,在每次迭代中,您都会蚕食一些具有正确条件的最近邻居,并且每一步都会更容易。 (50% 更容易?)我将展示如何进行第一次迭代,这样应该可以构建循环。

你可以做到;

import numpy as np
size = 250000

X = np.random.random( size=(size,10))
p = np.random.random( size=size)

并用

创建一个 BallTree
from sklearn.neighbors import BallTree

tree = BallTree(X, leaf_size=10, metric='minkowski')

并使用(这大约需要 5 分钟。)

对其进行第一次迭代查询
k_nearest = 2

distances, indici = tree.query(X, k=k_nearest, return_distance=True, dualtree=False, sort_results=True)

最近k内最远点的标记为

most_far_away_indici = indici[:,-1:]

及其品质

p[most_far_away_indici]

所以我们可以

quality_closeby = p[most_far_away_indici]

并检查它是否足够

indici_sufficient_quality = quality_closeby > np.expand_dims(p, axis=1)

我们有

found_closeby = np.all( indici_sufficient_quality, axis=1 )

这是真的,我们在附近找到了足够的质量。

我们可以用

更新向量
distances_nearby = distances[:,-1:]

rx = np.zeros(size)
rx[found_closeby] = distances_nearby[found_closeby][:,0]

我们现在需要注意剩下的那些我们不走运的地方,这些是

~found_closeby

所以

indici_not_found = ~found_closeby

distances, indici = tree.query(X[indici_not_found], k=3, return_distance=True, dualtree=False, sort_results=True)

等..

我确信前几个循环需要几分钟,但经过几次迭代后,速度将很快达到秒级。

这是 np.argwhere() 等的一个小练习,以确保正确的标记得到更新。

它可能不是最快的,但它是一种可行的方法。

由于无法知道某个块的尺寸,我建议使用 np.ones_like 而不是 np.zeros。