我需要从集合中找到每个对象的k
最近邻。每个对象都有其坐标作为属性。 为了解决任务,我正在尝试使用scipy
的spatial.KDTree
.如果我使用列表或元组来表示一个点,它可以正常工作,但它不适用于对象。 我在类中实现了__getitem__
和__len__
方法,但KDTree
实现要求我的对象提供不存在的坐标轴(例如二维点的第 3 个坐标(。
下面是重现该问题的简单脚本:
from scipy import spatial
class Unit:
def __init__(self, x,y):
self.x = x
self.y = y
def __getitem__(self, index):
if index == 0:
return self.x
elif index == 1:
return self.y
else:
raise Exception('Unit coordinates are 2 dimensional')
def __len__(self):
return 2
#points = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
#points = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
points = [Unit(1,1), Unit(2,2), Unit(3,3), Unit(4,4), Unit(5,5)]
tree = spatial.KDTree(points)
#result = tree.query((6,6), 3)
result = tree.query(Unit(6,6), 3)
print(result)
我没有必要使用这个特定的实现或库甚至算法,但要求是处理对象。
附言我可以向每个对象添加id
字段,并将所有坐标移动到单独的数组中,其中索引是对象id
。但如果可能的话,我仍然想避免这种做法。
scipy.spatial.KDTree
的文档指出data
参数应该array_like
,这通常意味着"可转换为 numpy 数组"。事实上,初始化的第一行尝试将数据转换为 numpy 数组,如源代码所示:
class KDTree(object):
""" ... """
def __init__(self, data, leafsize=10):
self.data = np.asarray(data)
因此,您要实现的是一个对象,以便它们的列表可以很好地转换为numpy数组。这很难准确定义,因为 numpy 尝试了许多方法将对象转换为数组。但是,包含许多相同长度的序列的可迭代对象绝对符合条件。
您的Unit
对象基本上是一个序列,因为它实现了__len__
和__getitem__
,并使用从 0 开始的连续整数进行索引。Python 知道你的序列何时从它击中IndexError
结束。但是您的__getitem__
反而对不良指数提出了Exception
。因此,从这两种方法提供顺序迭代的正常机制中断了。相反,提出一个IndexError
,你会很好地转换:
class Unit:
def __init__(self, x, y):
self.x = x
self.y = y
def __getitem__(self, index):
if index == 0:
return self.x
elif index == 1:
return self.y
raise IndexError('Unit coordinates are 2 dimensional')
def __len__(self):
return 2
现在我们可以检查这些转换为 numpy 数组的列表,没有问题:
In [5]: np.array([Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)])
Out[5]:
array([[1, 1],
[2, 2],
[3, 3],
[4, 4],
[5, 5]])
因此,我们现在初始化KDTree
应该没有问题。这就是为什么如果您将坐标存储在内部列表中并仅__getitem__
推迟到该列表,或者只是将坐标视为一些简单的序列(如列表或元组(,那么您会没问题。
对于像这样的简单类,一种更简单的方法是使用 namedtuples
或类似对象,但对于更复杂的对象,将它们转换为序列是一种不错的方法。
该类可能需要访问对象的切片。但是使用您的定义,不可能使用切片(尝试Unit(6, 6)[:]
,它会抛出相同的错误(。
处理此问题的一种方法是将 x 和 y 变量保存在列表中:
class Unit:
def __init__(self, x,y):
self.x = x
self.y = y
self.data = [x, y]
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return 2
points = [Unit(1,1), Unit(2,2), Unit(3,3), Unit(4,4), Unit(5,5)]
tree = spatial.KDTree(points)
result = tree.query(Unit(6,6), 3)
print(result)
(array([1.41421356, 2.82842712, 4.24264069]), array([4, 3, 2]))