最近邻搜索 kd树



我有一个包含 N 个点的列表 [(x_1,y_1), (x_2,y_2), ... ],想根据距离为每个点找到最近的邻居。我的数据集太大,无法使用暴力方法,因此 KD 树似乎是最好的选择。

我看到sklearn.neighbors.KDTree可以找到最近的邻居,而不是从头开始实施一个。这可以用来找到每个粒子的最近邻居,即返回一个dim(N)列表吗?

这个问题非常宽泛,缺少细节。目前还不清楚你尝试了什么、你的数据是什么样子,以及“最近的邻居”指的是什么(是否包含点自身?)。

假设您对点自身(距离为 0)不感兴趣,则可以查询两个最近的邻居并丢弃第一列。这可能是最简单的方法。

代码:

import numpy as np
from sklearn.neighbors import KDTree
np.random.seed(0)
X = np.random.random((5, 2))  # 5 points in 2 dimensions
tree = KDTree(X)
# The tree is queried with the very points it was built from, so (as the text
# above explains) each point's closest hit is itself at distance 0. Asking for
# k=2 makes the second hit the neighbour we actually want.
nearest_dist, nearest_ind = tree.query(X, k=2)  # column 0 = the point itself, column 1 = nearest other point
print(X)
# NOTE(review): taking column 1 relies on each row being ordered by distance;
# presumably that is the `sort_results` argument of `query` -- confirm its
# default in the scikit-learn documentation.
print(nearest_dist[:, 1])    # drop the self-match column; assumes rows are sorted by distance
print(nearest_ind[:, 1])     # drop the self-match column

输出

[[ 0.5488135   0.71518937]
[ 0.60276338  0.54488318]
[ 0.4236548   0.64589411]
[ 0.43758721  0.891773  ]
[ 0.96366276  0.38344152]]
[ 0.14306129  0.1786471   0.14306129  0.20869372  0.39536284]
[2 0 0 0 1]

你可以使用 sklearn.neighbors.KDTree 的 query_radius() 方法,该方法返回某个半径内所有邻居的索引列表(而不是返回 k 个最近邻)。

from sklearn.neighbors import KDTree
points = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
tree = KDTree(points, leaf_size=2)
# For every query point: indices of all points within distance 1.5 of it.
# The query point itself is part of each result, as the output below shows.
all_nn_indices = tree.query_radius(points, r=1.5)
# Translate each index array back into the coordinate tuples.
all_nns = [[points[idx] for idx in nn_indices] for nn_indices in all_nn_indices]
for nns in all_nns:
    print(nns)

输出:

[(1, 1), (2, 2)]
[(1, 1), (2, 2), (3, 3)]
[(2, 2), (3, 3), (4, 4)]
[(3, 3), (4, 4), (5, 5)]
[(4, 4), (5, 5)]

请注意,每个点自身也包含在其给定半径内的邻居列表中。如果要去掉这些点自身,可以将计算 all_nns 的那一行改为:

# Same neighbour lists as before, but each query point's own index is skipped
# so a point never appears in its own list.
all_nns = [
    [points[j] for j in neighbour_ids if j != own]
    for own, neighbour_ids in enumerate(all_nn_indices)
]

结果是:

[(2, 2)]
[(1, 1), (3, 3)]
[(2, 2), (4, 4)]
[(3, 3), (5, 5)]
[(4, 4)]

2023 更新

我不得不重新审视这一点,发现我的实现虽然非常快,但并不准确。sklearn 应该是最好的选择。下面的代码是我一段时间前写的,当时我需要自定义距离。(sklearn 的 KDTree 不支持自定义距离函数,但 BallTree 支持。)

这是一个Jupyter笔记本,其中包含sklearn BallTree和KDTree以及我的自定义代码的计时。 https://colab.research.google.com/drive/1ymx2r3J7oUMAuPlsZxDnESM7aTfwbLWV?usp=sharing

BallTree是准确的,但很慢,KDTree获得前 5 名,但顺序不同,不是那么准确,我的代码缺少前 5 名中的一个节点,因此不准确。

我相信这是由于先将纬度、经度投影到直角坐标系(KDTree 在 x、y 等轴之间轮流切分),然后又使用自定义距离函数造成的。另请参阅类似的讨论。我曾尝试将其转换为笛卡尔坐标,结果并没有好多少。原帖内容保留如下,以备参考。


改编自我的 gist(用于二维情形): https://gist.github.com/alexcpn/1f187f2114976e748f4d3ad38dea17e8

# From https://gist.github.com/alexcpn/1f187f2114976e748f4d3ad38dea17e8
# Author alex punnen
from collections import namedtuple
from operator import itemgetter
import numpy as np

def find_nearest_neighbour(node, point, distance_fn, current_axis):
    """Return ``(closest_node, distance)`` for ``point`` in a 2-D KD tree.

    The tree is expected to have the shape ``create_kdtree`` produces: every
    node holds a ``cell`` and was chosen as the median along one axis, with
    the smaller coordinates in ``left_branch`` and the larger ones in
    ``right_branch``; the axis alternates between 0 (x) and 1 (y) per level.

    The search first descends into the side of the split that contains
    ``point``. The other side is visited only when the split line is closer
    to ``point`` than the best match found so far, because only then can it
    hold a closer cell.

    :param node: current (sub)tree root, or ``None`` for an empty subtree
    :param point: ``(x, y)`` to find the nearest neighbour for
    :param distance_fn: callable ``(x1, y1, x2, y2) -> distance``
    :param current_axis: axis this node splits on, 0 or 1
    :return: ``(node, distance)``, or ``(None, None)`` for an empty subtree

    Progress is printed to stdout, as in the original gist.
    """
    if node is None:
        return None, None

    current_closest_node = node
    closest_known_distance = distance_fn(node.cell[0], node.cell[1], point[0], point[1])
    print(closest_known_distance, node.cell)

    next_axis = (current_axis + 1) % 2
    if node.cell[current_axis] > point[current_axis]:
        near_branch, far_branch = node.left_branch, node.right_branch
    else:
        near_branch, far_branch = node.right_branch, node.left_branch

    new_node, new_closest_distance = find_nearest_neighbour(
        near_branch, point, distance_fn, next_axis)
    # Compare against None explicitly: a distance of 0 (exact match) is falsy
    # but is the best possible answer and must not be discarded.
    if new_closest_distance is not None and new_closest_distance < closest_known_distance:
        print('Reset closest node to ', new_node.cell)
        closest_known_distance = new_closest_distance
        current_closest_node = new_node

    if far_branch is not None:
        # Distance from the query to the split line, measured with the
        # caller's metric by projecting the query onto that line.
        # NOTE(review): this bound is exact for Euclidean/taxicab metrics;
        # confirm it is a valid lower bound for any other distance_fn used.
        projected = [point[0], point[1]]
        projected[current_axis] = node.cell[current_axis]
        plane_distance = distance_fn(projected[0], projected[1], point[0], point[1])
        if plane_distance < closest_known_distance:
            new_node, new_closest_distance = find_nearest_neighbour(
                far_branch, point, distance_fn, next_axis)
            if new_closest_distance is not None and new_closest_distance < closest_known_distance:
                print('Reset closest node to ', new_node.cell)
                closest_known_distance = new_closest_distance
                current_closest_node = new_node

    return current_closest_node, closest_known_distance


class Node(namedtuple('Node', 'cell, left_branch, right_branch')):
    """KD-tree node: a ``cell`` plus its ``left_branch``/``right_branch`` subtrees.

    Taken from the Wikipedia code snippet for KD trees.
    """

def create_kdtree(cell_list, current_axis, no_of_axis):
    """Recursively build a KD tree of ``Node`` objects from ``cell_list``.

    Follows the Wikipedia KD-tree snippet: sort the cells on the current
    axis, make the median cell the node, and build the left/right subtrees
    from the cells before/after it using the next axis.

    :param cell_list: sequence of indexable cells (e.g. ``[x, y]`` lists)
    :param current_axis: axis to split on at this level
    :param no_of_axis: number of axes to cycle through
    :return: the root ``Node``, or ``None`` when ``cell_list`` is empty
    """
    if not cell_list:
        return None
    # Store each cell as a tuple. The first two coordinates are always kept
    # (the original 2-D behaviour); more are kept when no_of_axis asks for
    # them, so sorting on axis 2 and above no longer raises IndexError.
    dims = max(no_of_axis, 2)
    cells = sorted(
        (tuple(cell[i] for i in range(dims)) for cell in cell_list),
        key=itemgetter(current_axis),
    )
    median = len(cells) // 2
    next_axis = (current_axis + 1) % no_of_axis  # cycle the axis
    return Node(
        cells[median],
        create_kdtree(cells[:median], next_axis, no_of_axis),
        create_kdtree(cells[median + 1:], next_axis, no_of_axis),
    )
def eucleaden_dist(x1, y1, x2, y2):
    """Return the Euclidean distance between ``(x1, y1)`` and ``(x2, y2)``."""
    return np.linalg.norm(np.array([x1, y1]) - np.array([x2, y2]))

# Demo: build a tree from three fixed cells and look up two query points.
np.random.seed(0)
# To try random input instead of the fixed cells:
#cell_list = np.random.random((2, 2))
#cell_list = cell_list.tolist()
cell_list = [[2, 2], [4, 8], [10, 2]]
print(cell_list)
tree = create_kdtree(cell_list, 0, 2)
node, distance = find_nearest_neighbour(tree, (1, 1), eucleaden_dist, 0)
print('Nearest Neighbour=', node.cell, distance)
node, distance = find_nearest_neighbour(tree, (8, 1), eucleaden_dist, 0)
print('Nearest Neighbour=', node.cell, distance)

我实现了这个问题的解决方案,我认为这可能会有所帮助。

from collections import namedtuple
from operator import itemgetter
from pprint import pformat
from math import inf

def nested_getter(idx1, idx2):
    """Return a callable that fetches ``obj[idx1][idx2]`` from its argument.

    Used as a sort key to order ``(id, coordinates)`` pairs by one coordinate.
    """
    def g(obj):
        return obj[idx1][idx2]
    return g

class Node(namedtuple('Node', 'location left_child right_child')):
    """KD-tree node: a ``location`` plus its ``left_child``/``right_child`` subtrees."""

    def __repr__(self):
        # Pretty-print the node as a plain tuple so nested trees stay readable.
        return pformat(tuple(self))

def kdtree(point_list, depth: int = 0):
    """Recursively build a KD tree of ``Node`` objects.

    :param point_list: list of items whose element ``[1]`` is the coordinate
        sequence (the sort key reads ``item[1][axis]``), e.g. ``(id, (x, y))``
    :param depth: recursion depth, used to pick the split axis
    :return: the root ``Node``, or ``None`` when ``point_list`` is empty

    The caller's list is left untouched; sorting works on a copy.
    """
    if not point_list:
        return None
    # NOTE(review): this is the length of the first *item*, not of its
    # coordinate sequence. For (id, (x, y)) pairs both happen to be 2.
    # nns derives its axis with the same rule, so the two must be changed
    # together if this is ever generalised.
    k = len(point_list[0])
    # Select axis based on depth so that axis cycles through all valid values
    axis = depth % k
    # Sort a copy by the chosen axis and take the median as the pivot element
    ordered = sorted(point_list, key=nested_getter(1, axis))
    median = len(ordered) // 2
    # Create node and construct subtrees
    return Node(
        location=ordered[median],
        left_child=kdtree(ordered[:median], depth + 1),
        right_child=kdtree(ordered[median + 1:], depth + 1),
    )

def nns(q, n, p, w, depth: int = 0):
    """Nearest Neighbor Search below node ``n`` for the query ``q``.

    :param q: query item; its coordinates are read from ``q[1]``
    :param n: node to search from (must not be ``None``)
    :param p: best reference point found so far
    :param w: distance from ``q`` to ``p`` (pass ``inf`` to start)
    :param depth: depth of ``n`` in the tree, used to pick the split axis
    :return: ``(new_p, new_w)`` -- the best point and its distance to ``q``
    """
    here = distance(q[1], n.location[1])
    # A distance of 0 is rejected because we don't want to allow p = q.
    if 0 < here < w:
        p, w = n.location, here

    # NOTE(review): len(p) is the length of the stored item, mirroring how
    # kdtree picks its axis; keep the two in sync.
    axis = depth % len(p)
    split = n.location[1][axis]
    coord = q[1][axis]

    # Visit the side containing q first, then the other one. Each entry is
    # (child, sign): -1 for the left subtree, +1 for the right subtree.
    left = (n.left_child, -1)
    right = (n.right_child, +1)
    order = (left, right) if coord <= split else (right, left)

    for child, sign in order:
        if not child:
            continue
        # Skip a subtree when even the point on q's axis at distance w does
        # not reach over the split value into that subtree.
        if sign < 0:
            reachable = coord - w <= split
        else:
            reachable = coord + w >= split
        if reachable:
            cand_p, cand_w = nns(q, child, p, w, depth + 1)
            if cand_w < w:
                p, w = cand_p, cand_w
    return p, w

def main():
    """Example usage of kdtree: build and print a tree for six 2-D points."""
    point_list = [(7, 2), (5, 4), (9, 6), (4, 7), (8, 1), (2, 3)]
    # kdtree sorts on item[1][axis], so it needs (id, coordinates) pairs;
    # passing the bare coordinate tuples raises TypeError.
    tree = kdtree(list(enumerate(point_list)))
    print(tree)

def city_houses():
    """Print, for each city, the distance to its nearest other city.

    Input (stdin):
        The first line contains N, the number of cities.
        Each of the next N lines contains two integers x and y, separated by
        whitespace, locating a city at (x, y).
        It's guaranteed that a spot (x, y) does not contain more than one city.

    Output (stdout):
        N lines; line i is the distance from the i-th input city to its
        nearest other city. Nothing is printed when N is 0.
    """
    n = int(input())
    # Keep the input index with every city so results can be reported in
    # input order regardless of how the tree orders the points.
    cities = [(i, tuple(map(int, input().split()))) for i in range(n)]
    tree = kdtree(cities)
    if tree is None:
        # No cities: there is no root to start the search from.
        return
    ans = [(target[0], nns(target, tree, tree.location, inf)[1]) for target in cities]
    ans.sort(key=itemgetter(0))
    # One distance per line (the separator must be a newline, not the letter 'n').
    print('\n'.join(str(item[1]) for item in ans))

def distance(a, b):
    """Return the taxicab (L1) distance between points ``a`` and ``b``.

    Only the first ``len(b)`` coordinates of ``a`` are read.

    NOTE(review): nns compares this value directly against per-axis
    coordinate offsets, so swapping in a *squared* Euclidean distance would
    presumably require adjusting that pruning test as well.
    """
    return sum(abs(b[i] - a[i]) for i in range(len(b)))

# Run the stdin-driven nearest-city example only when executed as a script.
if __name__ == '__main__':
    city_houses()

最新更新