scipy pdist getting only two closest neighbors
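I've been using scipy to compute pairwise distances, and I'm trying to get the distances to the two nearest neighbors. My current working solution is: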
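import numpy as np
from scipy.spatial.distance import pdist, squareform
dists = squareform(pdist(xs.todense()))  # full n x n distance matrix
dists = np.sort(dists, axis=1)[:, 1:3]   # drop the self-distance (0), keep the 2 nearest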
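However, in my case the squareform approach is very expensive in memory and somewhat redundant: I only need the two smallest distances for each point, not all of them. Is there a simple workaround?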
Thanks!
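As a different technique from the pdist-based answer, here is a sketch that targets what the question's own code computes: for every point, the two smallest distances to other points. For Euclidean distances (pdist's default metric), `scipy.spatial.cKDTree` can return these without materializing the n x n matrix. `two_nearest_distances` is a hypothetical helper name, and the sketch assumes `xs` can be converted to a dense (n, d) array (e.g. `xs.toarray()`), since the tree expects dense input, and that there are no duplicate points.

```python
import numpy as np
from scipy.spatial import cKDTree


def two_nearest_distances(points):
    """Distances (and indices) from each point to its two nearest other points.

    Hypothetical helper. Assumes `points` is a dense (n, d) array with n >= 3,
    no duplicate points, and the Euclidean metric (pdist's default).
    """
    tree = cKDTree(points)
    # k=3: the first hit for each query is the point itself at distance 0,
    # mirroring the [:, 1:3] slice in the question's code
    dists, idx = tree.query(points, k=3)
    return dists[:, 1:3], idx[:, 1:3]


# usage on random data; with sparse input, densify first, e.g. xs.toarray()
rng = np.random.default_rng(0)
pts = rng.random((1000, 3))
nearest_d, nearest_i = two_nearest_distances(pts)
print(nearest_d.shape)  # (1000, 2)
```

Memory stays roughly proportional to n rather than n². KD-trees tend to lose their advantage in high dimensions, so for very wide data a chunked distance computation may be the better trade-off.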
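The relationship between the linear index of the condensed distance vector and the (i, j) position in the upper-triangular distance matrix is not straightforward to invert directly (see note 2 in the squareform documentation).
However, the inverse mapping can be obtained by iterating over all indices in order. The example below uses it to find the two closest pairs of points overall: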
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist

def inverse_condensed_indices(idx, n):
    """Map a condensed (pdist) index back to the pair (i, j) with i < j."""
    k = 0
    # walk the upper triangle in the same row-by-row order pdist uses
    for i in range(n):
        for j in range(i + 1, n):
            if k == idx:
                return (i, j)
            k += 1
    return None  # idx is out of range

# test: 8 random points in the unit square
points = np.random.rand(8, 2)
distances = pdist(points)
sorted_idx = np.argsort(distances)
n = points.shape[0]
# (i, j) pairs for the two smallest pairwise distances
ij = [inverse_condensed_indices(idx, n)
      for idx in sorted_idx[:2]]

# graph: draw the two closest pairs in red
plt.figure(figsize=(5, 5))
for i, j in ij:
    x = [points[i, 0], points[j, 0]]
    y = [points[i, 1], points[j, 1]]
    plt.plot(x, y, '-', color='red');
plt.plot(points[:, 0], points[:, 1], '.', color='black');
plt.xlim(0, 1); plt.ylim(0, 1);
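Instead of calling the loop once per index, a lookup table can be built in a single call: `np.triu_indices(n, k=1)` lists the strict upper triangle row by row, which is the same order pdist uses for its condensed output. The sketch below (hypothetical helper `closest_pairs` with parameter `m`) still finds the globally closest pairs, like the code above, and uses `np.argpartition` to avoid sorting every distance. The index table costs two integer arrays of length n(n-1)/2, the same order of memory as the condensed vector itself.

```python
import numpy as np
from scipy.spatial.distance import pdist


def closest_pairs(points, m=2):
    """Return the m globally closest pairs (i, j), i < j, and their distances.

    Hypothetical helper; assumes 1 <= m <= n * (n - 1) / 2.
    """
    n = points.shape[0]
    distances = pdist(points)
    # Lookup table: np.triu_indices(n, k=1) lists the strict upper triangle
    # row by row, the same order pdist uses for its condensed output.
    rows, cols = np.triu_indices(n, k=1)
    # argpartition picks the m smallest without a full sort ...
    part = np.argpartition(distances, m - 1)[:m]
    # ... and only those m are then sorted, nearest first
    part = part[np.argsort(distances[part])]
    pairs = [(int(i), int(j)) for i, j in zip(rows[part], cols[part])]
    return pairs, distances[part]


points = np.random.default_rng(1).random((8, 2))
pairs, d = closest_pairs(points)
print(pairs, d)
```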
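It seems to be somewhat faster than calling squareform, at least in this small case (n = 8, so 28 distances; index 27 is the last pair and therefore the slowest case for the loop):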
from scipy.spatial.distance import squareform
%timeit squareform(range(28))
# 9.23 µs ± 63 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit inverse_condensed_indices(27, 8)
# 2.38 µs ± 25 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
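The loop in `inverse_condensed_indices` may walk up to n(n-1)/2 entries per lookup. The mapping can also be inverted in closed form: the pdist documentation gives the forward formula `n * i + j - ((i + 2) * (i + 1)) // 2` for i < j, and solving the resulting quadratic for the row i gives the inverse. The sketch below uses hypothetical names `pair_to_condensed` and `condensed_to_pair`, needs Python 3.8+ for `math.isqrt`, and stays in integer arithmetic so there is no floating-point rounding at row boundaries.

```python
from math import isqrt


def pair_to_condensed(i, j, n):
    """Condensed (pdist) index of pair (i, j), 0 <= i < j < n (formula from the pdist docs)."""
    return n * i + j - ((i + 2) * (i + 1)) // 2


def condensed_to_pair(k, n):
    """Inverse of pair_to_condensed using integer arithmetic only."""
    if not 0 <= k < n * (n - 1) // 2:
        return None  # out of range, like the loop version
    a = 2 * n - 1
    disc = a * a - 8 * k
    # ceil(sqrt(disc)) computed exactly
    s = isqrt(disc)
    if s * s < disc:
        s += 1
    i = (a - s) // 2                         # largest row i with rows_before <= k
    rows_before = i * (2 * n - i - 1) // 2   # entries in rows 0 .. i-1
    j = k - rows_before + i + 1
    return (i, j)


print(condensed_to_pair(27, 8))  # (6, 7)
```

Each lookup is constant time, so it avoids both the per-index loop and building the full square matrix.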