我正在尝试编写一个函数,该函数将使用"最近邻居"或"最近匹配"类型的算法来过滤元组列表(模仿内存中的数据库)。
我想知道做这件事的最佳方式(即最Python的方式)。下面的示例代码有望说明我正在尝试做什么
datarows = [(10,2.0,3.4,100),
(11,2.0,5.4,120),
(17,12.9,42,123)]
filter_record = (9,1.9,2.9,99) # record that we are seeking to retrieve from 'database' (or nearest match)
weights = (1,1,1,1) # weights to approportion to each field in the filter
def get_nearest_neighbour(data, criteria, weights):
for each row in data:
# calculate 'distance metric' (e.g. simple differencing) and multiply by relevant weight
# determine the row which was either an exact match or was 'least dissimilar'
# return the match (or nearest match)
pass
if __name__ == '__main__':
result = get_nearest_neighbour(datarow, filter_record, weights)
print result
对于上面的代码段,输出应该是:
(10,2.0,3.4100)
因为它与传递给函数get_nearest_neighbur()的样本数据"最接近"。
那么,我的问题是,实现get_nearest_neighbur()的最佳方法是什么?。为了简洁等目的,假设我们只处理数值,并且我们使用的"距离度量"只是从当前行中减去输入数据的算术运算。
简单的开箱即用解决方案:
import math
def distance(row_a, row_b, weights):
diffs = [math.fabs(a-b) for a,b in zip(row_a, row_b)]
return sum([v*w for v,w in zip(diffs, weights)])
def get_nearest_neighbour(data, criteria, weights):
def sort_func(row):
return distance(row, criteria, weights)
return min(data, key=sort_func)
如果您需要处理巨大的数据集,您应该考虑切换到Numpy并使用Numpy的KDTree
来查找最近的邻居。使用Numpy的优势在于,它不仅使用了更先进的算法,而且实现了高度优化的LAPACK(线性代数PACKage)。
关于naive NN:
这些其他答案中的许多都提出了"天真最近邻居",这是一种O(N*d)
-每查询算法(d是维度,在这种情况下看起来是恒定的,所以它是O(N)
-每查询)。
虽然O(N)
-每个查询的算法非常糟糕,但如果你的数量少于(例如)中的任何一个,你可能会逃脱惩罚
- 10次查询和100000分
- 100个查询和10000个积分
- 1000个查询和1000个积分
- 10000次查询和100分
- 100000次查询和10分
比天真的NN:做得更好
否则,您将希望使用中列出的技术之一(尤其是最近邻数据结构)
-
http://en.wikipedia.org/wiki/Nearest_neighbor_search(很可能是从该页面链接下来的),一些链接的例子:
- http://en.wikipedia.org/wiki/K-d_tree
- http://en.wikipedia.org/wiki/Locality_sensitive_hashing
- http://en.wikipedia.org/wiki/Cover_tree
尤其是如果你计划多次运行你的程序。很可能有可用的库。如果您有大量的#querys**点,那么不使用NN数据结构将花费太多时间。正如用户"dsign"在评论中指出的那样,您可以通过使用numpy库来挤出一个额外的速度常数。
然而,如果你可以使用简单易实现的朴素NN,你应该使用它。
在生成器上使用heapq.nlargest计算每条记录的距离*权重。
类似于:
heapq.nlargest(N, ((row, dist_function(row,criteria,weight)) for row in data), operator.itemgetter(1))