生成具有每行重置的最低 N 值位置的掩码数组
Generate mask array with lowest N valued positions reset per row
给定一个二维距离数组,使用 argsort 生成一个索引数组,其中第一个元素是行中最小值的索引。仅对前 K 列使用索引 select,例如 K = 3。
position = np.random.randint(100, size=(5, 5))
array([[36, 63, 3, 78, 98],
[75, 86, 63, 61, 79],
[21, 12, 72, 27, 23],
[38, 16, 17, 88, 29],
[93, 37, 48, 88, 10]])
idx = position.argsort()
array([[2, 0, 1, 3, 4],
[3, 2, 0, 4, 1],
[1, 0, 4, 3, 2],
[1, 2, 4, 0, 3],
[4, 1, 2, 3, 0]])
idx[:,0:3]
array([[2, 0, 1],
[3, 2, 0],
[1, 0, 4],
[1, 2, 4],
[4, 1, 2]])
然后我想做的是创建一个掩码数组,当应用于原始位置数组时 returns 只有产生 k 个最短距离的索引。
我将此方法基于我发现的一些适用于一维数组的代码。
# https://glowingpython.blogspot.co.uk/2012/04/k-nearest-neighbor-search.html
from numpy import random, argsort, sqrt
from matplotlib import pyplot as plt
def knn_search(x, D, K):
""" find K nearest neighbours of data among D """
ndata = D.shape[1]
K = K if K < ndata else ndata
# euclidean distances from the other points
sqd = sqrt(((D - x[:, :ndata]) ** 2).sum(axis=0))
idx = argsort(sqd) # sorting
# return the indexes of K nearest neighbours
return idx[:K]
# knn_search test
data = random.rand(2, 5) # random dataset
x = random.rand(2, 1) # query point
# performing the search
neig_idx = knn_search(x, data, 2)
figure = plt.figure()
plt.scatter(data[0,:], data[1,:])
plt.scatter(x[0], x[1], c='g')
plt.scatter(data[0, neig_idx], data[1, neig_idx], c='r', marker = 'o')
plt.show()
这是一种方法 -
N = 3 # number of points to be set as False per row
# Slice out the first N cols per row
k_idx = idx[:,:N]
# Initialize output array
out = np.ones(position.shape, dtype=bool)
# Index into output with k_idx as col indices to reset
out[np.arange(k_idx.shape[0])[:,None], k_idx] = 0
最后一步涉及 advanced-indexing
,如果您是 NumPy 的新手,这可能是一大步,但基本上我们在这里使用 k_idx
对列进行索引,并且我们正在形成索引元组索引到范围数组为 np.arange(k_idx.shape[0])[:,None]
的行。有关 advanced-indexing
.
的更多信息
我们可以通过使用 np.argpartition
而不是 argsort
来提高性能,就像这样 -
k_idx = np.argpartition(position, N)[:,:N]
将每行最低 3
个元素设置为 False -
的案例的示例输入、输出
In [227]: position
Out[227]:
array([[36, 63, 3, 78, 98],
[75, 86, 63, 61, 79],
[21, 12, 72, 27, 23],
[38, 16, 17, 88, 29],
[93, 37, 48, 88, 10]])
In [228]: out
Out[228]:
array([[False, False, False, True, True],
[False, True, False, False, True],
[False, False, True, True, False],
[ True, False, False, True, False],
[ True, False, False, True, False]], dtype=bool)
给定一个二维距离数组,使用 argsort 生成一个索引数组,其中第一个元素是行中最小值的索引。仅对前 K 列使用索引 select,例如 K = 3。
position = np.random.randint(100, size=(5, 5))
array([[36, 63, 3, 78, 98],
[75, 86, 63, 61, 79],
[21, 12, 72, 27, 23],
[38, 16, 17, 88, 29],
[93, 37, 48, 88, 10]])
idx = position.argsort()
array([[2, 0, 1, 3, 4],
[3, 2, 0, 4, 1],
[1, 0, 4, 3, 2],
[1, 2, 4, 0, 3],
[4, 1, 2, 3, 0]])
idx[:,0:3]
array([[2, 0, 1],
[3, 2, 0],
[1, 0, 4],
[1, 2, 4],
[4, 1, 2]])
然后我想做的是创建一个掩码数组,当应用于原始位置数组时 returns 只有产生 k 个最短距离的索引。
我将此方法基于我发现的一些适用于一维数组的代码。
# https://glowingpython.blogspot.co.uk/2012/04/k-nearest-neighbor-search.html
from numpy import random, argsort, sqrt
from matplotlib import pyplot as plt
def knn_search(x, D, K):
""" find K nearest neighbours of data among D """
ndata = D.shape[1]
K = K if K < ndata else ndata
# euclidean distances from the other points
sqd = sqrt(((D - x[:, :ndata]) ** 2).sum(axis=0))
idx = argsort(sqd) # sorting
# return the indexes of K nearest neighbours
return idx[:K]
# knn_search test
data = random.rand(2, 5) # random dataset
x = random.rand(2, 1) # query point
# performing the search
neig_idx = knn_search(x, data, 2)
figure = plt.figure()
plt.scatter(data[0,:], data[1,:])
plt.scatter(x[0], x[1], c='g')
plt.scatter(data[0, neig_idx], data[1, neig_idx], c='r', marker = 'o')
plt.show()
这是一种方法 -
N = 3 # number of points to be set as False per row
# Slice out the first N cols per row
k_idx = idx[:,:N]
# Initialize output array
out = np.ones(position.shape, dtype=bool)
# Index into output with k_idx as col indices to reset
out[np.arange(k_idx.shape[0])[:,None], k_idx] = 0
最后一步涉及 advanced-indexing
,如果您是 NumPy 的新手,这可能是一大步,但基本上我们在这里使用 k_idx
对列进行索引,并且我们正在形成索引元组索引到范围数组为 np.arange(k_idx.shape[0])[:,None]
的行。有关 advanced-indexing
.
我们可以通过使用 np.argpartition
而不是 argsort
来提高性能,就像这样 -
k_idx = np.argpartition(position, N)[:,:N]
将每行最低 3
个元素设置为 False -
In [227]: position
Out[227]:
array([[36, 63, 3, 78, 98],
[75, 86, 63, 61, 79],
[21, 12, 72, 27, 23],
[38, 16, 17, 88, 29],
[93, 37, 48, 88, 10]])
In [228]: out
Out[228]:
array([[False, False, False, True, True],
[False, True, False, False, True],
[False, False, True, True, False],
[ True, False, False, True, False],
[ True, False, False, True, False]], dtype=bool)