当坐标保存在对象中时,使用 python 中的 kd-tree 查找 k 个最近的邻居
Find k nearest neighbors using kd-tree in python when coordinates are held in objects
我需要为集合中的每个对象找到 k
个最近的邻居。每个对象都有其坐标作为属性。
为了解决这个任务,我正在尝试使用 scipy
中的 spatial.KDTree
。如果我使用列表或元组来表示一个点,它工作正常,但它不适用于对象。
我在 class 中实现了 __getitem__
和 __len__
方法,但是 KDTree
实现要求我的对象提供不存在的坐标轴(比如二维点的第 3 坐标) ).
这是重现问题的简单脚本:
from scipy import spatial
class Unit:
def __init__(self, x,y):
self.x = x
self.y = y
def __getitem__(self, index):
if index == 0:
return self.x
elif index == 1:
return self.y
else:
raise Exception('Unit coordinates are 2 dimensional')
def __len__(self):
return 2
#points = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
#points = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
points = [Unit(1,1), Unit(2,2), Unit(3,3), Unit(4,4), Unit(5,5)]
tree = spatial.KDTree(points)
#result = tree.query((6,6), 3)
result = tree.query(Unit(6,6), 3)
print(result)
我没有必要使用这个具体的实现或库甚至算法,但要求是处理对象。
P.S。我可以将 id
字段添加到每个对象并将所有坐标移动到单独的数组中,其中索引是对象 id
。但我还是想尽可能避免这种做法。
class 可能需要访问对象的切片。但是根据您的定义,不可能使用切片(尝试Unit(6, 6)[:]
,它会抛出相同的错误)。
处理此问题的一种方法是将 x 和 y 变量保存在列表中:
class Unit:
def __init__(self, x,y):
self.x = x
self.y = y
self.data = [x, y]
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return 2
points = [Unit(1,1), Unit(2,2), Unit(3,3), Unit(4,4), Unit(5,5)]
tree = spatial.KDTree(points)
result = tree.query(Unit(6,6), 3)
print(result)
(array([1.41421356, 2.82842712, 4.24264069]), array([4, 3, 2]))
docs for scipy.spatial.KDTree
state that the data
parameter should be array_like
which generally means "convertible to a numpy array." And indeed, the first line of initialization tries to convert the data to a numpy array, as you can see in the source code:
class KDTree(object):
""" ... """
def __init__(self, data, leafsize=10):
self.data = np.asarray(data)
所以你想要实现的是一个对象,这样它们的列表就可以很好地转换为一个 numpy 数组。这是 ,因为 numpy 尝试了很多方法来将你的对象变成一个数组。但是,包含许多相同长度序列的可迭代绝对符合条件。
你的 Unit
对象基本上是一个序列,因为它实现了 __len__
和 __getitem__
以及从 0 开始的连续整数索引。 Python 知道你的序列何时结束从它击中 IndexError
。但是您的 __getitem__
却在错误的索引上引发了 Exception
。因此,从这两种方法提供顺序迭代的正常机制中断了。相反,提高一个 IndexError
,你会很好地转换:
class Unit:
def __init__(self, x, y):
self.x = x
self.y = y
def __getitem__(self, index):
if index == 0:
return self.x
elif index == 1:
return self.y
raise IndexError('Unit coordinates are 2 dimensional')
def __len__(self):
return 2
现在我们可以毫无问题地检查这些转换成 numpy 数组的列表:
In [5]: np.array([Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)])
Out[5]:
array([[1, 1],
[2, 2],
[3, 3],
[4, 4],
[5, 5]])
所以,我们现在初始化 KDTree
应该没有问题。这就是为什么如果您将坐标存储在一个内部列表中并且只是将 __getitem__
推迟到该列表,或者只是将您的坐标视为一些简单的序列(如列表或元组),那么您会没事的。
像这样使用简单 类 的更简单的方法是使用 namedtuples
或类似的方法,但对于更复杂的对象,将它们变成序列是一个很好的方法。
我需要为集合中的每个对象找到 k
个最近的邻居。每个对象都有其坐标作为属性。
为了解决这个任务,我正在尝试使用 scipy
中的 spatial.KDTree
。如果我使用列表或元组来表示一个点,它工作正常,但它不适用于对象。
我在 class 中实现了 __getitem__
和 __len__
方法,但是 KDTree
实现要求我的对象提供不存在的坐标轴(比如二维点的第 3 坐标) ).
这是重现问题的简单脚本:
from scipy import spatial
class Unit:
def __init__(self, x,y):
self.x = x
self.y = y
def __getitem__(self, index):
if index == 0:
return self.x
elif index == 1:
return self.y
else:
raise Exception('Unit coordinates are 2 dimensional')
def __len__(self):
return 2
#points = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
#points = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
points = [Unit(1,1), Unit(2,2), Unit(3,3), Unit(4,4), Unit(5,5)]
tree = spatial.KDTree(points)
#result = tree.query((6,6), 3)
result = tree.query(Unit(6,6), 3)
print(result)
我没有必要使用这个具体的实现或库甚至算法,但要求是处理对象。
P.S。我可以将 id
字段添加到每个对象并将所有坐标移动到单独的数组中,其中索引是对象 id
。但我还是想尽可能避免这种做法。
class 可能需要访问对象的切片。但是根据您的定义,不可能使用切片(尝试Unit(6, 6)[:]
,它会抛出相同的错误)。
处理此问题的一种方法是将 x 和 y 变量保存在列表中:
class Unit:
def __init__(self, x,y):
self.x = x
self.y = y
self.data = [x, y]
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return 2
points = [Unit(1,1), Unit(2,2), Unit(3,3), Unit(4,4), Unit(5,5)]
tree = spatial.KDTree(points)
result = tree.query(Unit(6,6), 3)
print(result)
(array([1.41421356, 2.82842712, 4.24264069]), array([4, 3, 2]))
docs for scipy.spatial.KDTree
state that the data
parameter should be array_like
which generally means "convertible to a numpy array." And indeed, the first line of initialization tries to convert the data to a numpy array, as you can see in the source code:
class KDTree(object):
""" ... """
def __init__(self, data, leafsize=10):
self.data = np.asarray(data)
所以你想要实现的是一个对象,这样它们的列表就可以很好地转换为一个 numpy 数组。这是
你的 Unit
对象基本上是一个序列,因为它实现了 __len__
和 __getitem__
以及从 0 开始的连续整数索引。 Python 知道你的序列何时结束从它击中 IndexError
。但是您的 __getitem__
却在错误的索引上引发了 Exception
。因此,从这两种方法提供顺序迭代的正常机制中断了。相反,提高一个 IndexError
,你会很好地转换:
class Unit:
def __init__(self, x, y):
self.x = x
self.y = y
def __getitem__(self, index):
if index == 0:
return self.x
elif index == 1:
return self.y
raise IndexError('Unit coordinates are 2 dimensional')
def __len__(self):
return 2
现在我们可以毫无问题地检查这些转换成 numpy 数组的列表:
In [5]: np.array([Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)])
Out[5]:
array([[1, 1],
[2, 2],
[3, 3],
[4, 4],
[5, 5]])
所以,我们现在初始化 KDTree
应该没有问题。这就是为什么如果您将坐标存储在一个内部列表中并且只是将 __getitem__
推迟到该列表,或者只是将您的坐标视为一些简单的序列(如列表或元组),那么您会没事的。
像这样使用简单 类 的更简单的方法是使用 namedtuples
或类似的方法,但对于更复杂的对象,将它们变成序列是一个很好的方法。