当坐标保存在对象中时,使用 python 中的 kd-tree 查找 k 个最近的邻居

Find k nearest neighbors using kd-tree in python when coordinates are held in objects

我需要为集合中的每个对象找到 k 个最近的邻居。每个对象都有其坐标作为属性。 为了解决这个任务,我正在尝试使用 scipy 中的 spatial.KDTree。如果我使用列表或元组来表示一个点,它工作正常,但它不适用于对象。 我在 class 中实现了 __getitem____len__ 方法,但是 KDTree 实现要求我的对象提供不存在的坐标轴(比如二维点的第 3 坐标) ).

这是重现问题的简单脚本:

from scipy import spatial

class Unit:

    def __init__(self, x,y):
        self.x = x
        self.y = y


    def __getitem__(self, index):        
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        else:          
            raise Exception('Unit coordinates are 2 dimensional')


    def __len__(self):        
        return 2



#points = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
#points = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
points = [Unit(1,1), Unit(2,2), Unit(3,3), Unit(4,4), Unit(5,5)]

tree = spatial.KDTree(points)

#result = tree.query((6,6), 3)
result = tree.query(Unit(6,6), 3)

print(result)

我没有必要使用这个具体的实现或库甚至算法,但要求是处理对象。

P.S。我可以将 id 字段添加到每个对象并将所有坐标移动到单独的数组中,其中索引是对象 id。但我还是想尽可能避免这种做法。

class 可能需要访问对象的切片。但是根据您的定义,不可能使用切片(尝试Unit(6, 6)[:],它会抛出相同的错误)。

处理此问题的一种方法是将 x 和 y 变量保存在列表中:

class Unit:
    def __init__(self, x,y):
        self.x = x
        self.y = y
        self.data = [x, y]

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):        
        return 2

points = [Unit(1,1), Unit(2,2), Unit(3,3), Unit(4,4), Unit(5,5)]
tree = spatial.KDTree(points)
result = tree.query(Unit(6,6), 3)

print(result)
(array([1.41421356, 2.82842712, 4.24264069]), array([4, 3, 2]))

docs for scipy.spatial.KDTree state that the data parameter should be array_like which generally means "convertible to a numpy array." And indeed, the first line of initialization tries to convert the data to a numpy array, as you can see in the source code:

class KDTree(object):
    """ ... """
    def __init__(self, data, leafsize=10):
        self.data = np.asarray(data)

所以你想要实现的是一个对象,这样它们的列表就可以很好地转换为一个 numpy 数组。这是 ,因为 numpy 尝试了很多方法来将你的对象变成一个数组。但是,包含许多相同长度序列的可迭代绝对符合条件。

你的 Unit 对象基本上是一个序列,因为它实现了 __len____getitem__ 以及从 0 开始的连续整数索引。 Python 知道你的序列何时结束从它击中 IndexError。但是您的 __getitem__ 却在错误的索引上引发了 Exception。因此,从这两种方法提供顺序迭代的正常机制中断了。相反,提高一个 IndexError,你会很好地转换:

class Unit:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):        
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        raise IndexError('Unit coordinates are 2 dimensional')

    def __len__(self):        
        return 2

现在我们可以毫无问题地检查这些转换成 numpy 数组的列表:

In [5]: np.array([Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)])
Out[5]:
array([[1, 1],
       [2, 2],
       [3, 3],
       [4, 4],
       [5, 5]])

所以,我们现在初始化 KDTree 应该没有问题。这就是为什么如果您将坐标存储在一个内部列表中并且只是将 __getitem__ 推迟到该列表,或者只是将您的坐标视为一些简单的序列(如列表或元组),那么您会没事的。

像这样使用简单 类 的更简单的方法是使用 namedtuples 或类似的方法,但对于更复杂的对象,将它们变成序列是一个很好的方法。