当坐标保存在对象中时,使用 python 中的 kd-tree 查找 k 个最近的邻居



我需要从集合中找到每个对象的k最近邻。每个对象都有其坐标作为属性。 为了解决任务,我正在尝试使用scipyspatial.KDTree.如果我使用列表或元组来表示一个点,它可以正常工作,但它不适用于对象。 我在类中实现了__getitem____len__方法,但KDTree实现要求我的对象提供不存在的坐标轴(例如二维点的第 3 个坐标(。

下面是重现该问题的简单脚本:

from scipy import spatial
class Unit:
    def __init__(self, x,y):
        self.x = x
        self.y = y

    def __getitem__(self, index):        
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        else:          
            raise Exception('Unit coordinates are 2 dimensional')

    def __len__(self):        
        return 2

#points = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
#points = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
points = [Unit(1,1), Unit(2,2), Unit(3,3), Unit(4,4), Unit(5,5)]
tree = spatial.KDTree(points)
#result = tree.query((6,6), 3)
result = tree.query(Unit(6,6), 3)
print(result)

我没有必要使用这个特定的实现或库甚至算法,但要求是处理对象。

附言我可以向每个对象添加id字段,并将所有坐标移动到单独的数组中,其中索引是对象id。但如果可能的话,我仍然想避免这种做法。

scipy.spatial.KDTree的文档指出data参数应该array_like,这通常意味着"可转换为 numpy 数组"。事实上,初始化的第一行尝试将数据转换为 numpy 数组,如源代码所示:

class KDTree(object):
    """ ... """
    def __init__(self, data, leafsize=10):
        self.data = np.asarray(data)

因此,您要实现的是一个对象,以便它们的列表可以很好地转换为numpy数组。这很难准确定义,因为 numpy 尝试了许多方法将对象转换为数组。但是,包含许多相同长度的序列的可迭代对象绝对符合条件。

您的Unit对象基本上是一个序列,因为它实现了__len____getitem__,并使用从 0 开始的连续整数进行索引。Python 知道你的序列何时从它击中IndexError结束。但是您的__getitem__反而对不良指数提出了Exception。因此,从这两种方法提供顺序迭代的正常机制中断了。相反,提出一个IndexError,你会很好地转换:

class Unit:
    def __init__(self, x, y):
        self.x = x
        self.y = y
    def __getitem__(self, index):        
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        raise IndexError('Unit coordinates are 2 dimensional')
    def __len__(self):        
        return 2

现在我们可以检查这些转换为 numpy 数组的列表,没有问题:

In [5]: np.array([Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)])
Out[5]:
array([[1, 1],
       [2, 2],
       [3, 3],
       [4, 4],
       [5, 5]])

因此,我们现在初始化KDTree应该没有问题。这就是为什么如果您将坐标存储在内部列表中并仅__getitem__推迟到该列表,或者只是将坐标视为一些简单的序列(如列表或元组(,那么您会没问题。

对于像这样的简单类,一种更简单的方法是使用 namedtuples 或类似对象,但对于更复杂的对象,将它们转换为序列是一种不错的方法。

该类可能需要访问对象的切片。但是使用您的定义,不可能使用切片(尝试Unit(6, 6)[:],它会抛出相同的错误(。

处理此问题的一种方法是将 x 和 y 变量保存在列表中:

class Unit:
    def __init__(self, x,y):
        self.x = x
        self.y = y
        self.data = [x, y]
    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):        
        return 2
points = [Unit(1,1), Unit(2,2), Unit(3,3), Unit(4,4), Unit(5,5)]
tree = spatial.KDTree(points)
result = tree.query(Unit(6,6), 3)
print(result)
(array([1.41421356, 2.82842712, 4.24264069]), array([4, 3, 2]))

最新更新