Python:对数据记录(元组列表)进行最近邻居(或最匹配)过滤



我正在尝试编写一个函数,该函数将使用"最近邻居"或"最近匹配"类型的算法来过滤元组列表(模仿内存中的数据库)。

我想知道做这件事的最佳方式(即最Python的方式)。下面的示例代码有望说明我正在尝试做什么

datarows = [(10,2.0,3.4,100),
            (11,2.0,5.4,120),
            (17,12.9,42,123)]
filter_record = (9,1.9,2.9,99) # record that we are seeking to retrieve from 'database' (or nearest match)
weights = (1,1,1,1) # weights to approportion to each field in the filter
def get_nearest_neighbour(data, criteria, weights):
    for each row in data:
        # calculate 'distance metric' (e.g. simple differencing) and multiply by relevant weight
    # determine the row which was either an exact match or was 'least dissimilar'
    # return the match (or nearest match)
    pass
if __name__ == '__main__':
    result = get_nearest_neighbour(datarow, filter_record, weights)
    print result

对于上面的代码段,输出应该是:

(10,2.0,3.4100)

因为它与传递给函数get_nearest_neighbur()的样本数据"最接近"。

那么,我的问题是,实现get_nearest_neighbur()的最佳方法是什么?。为了简洁等目的,假设我们只处理数值,并且我们使用的"距离度量"只是从当前行中减去输入数据的算术运算。

简单的开箱即用解决方案:

import math
def distance(row_a, row_b, weights):
    diffs = [math.fabs(a-b) for a,b in zip(row_a, row_b)]
    return sum([v*w for v,w in zip(diffs, weights)])
def get_nearest_neighbour(data, criteria, weights):
    def sort_func(row):
        return distance(row, criteria, weights)
    return min(data, key=sort_func)
    

如果您需要处理巨大的数据集,您应该考虑切换到Numpy并使用Numpy的KDTree来查找最近的邻居。使用Numpy的优势在于,它不仅使用了更先进的算法,而且实现了高度优化的LAPACK(线性代数PACKage)。

关于naive NN:

这些其他答案中的许多都提出了"天真最近邻居",这是一种O(N*d)-每查询算法(d是维度,在这种情况下看起来是恒定的,所以它是O(N)-每查询)。

虽然O(N)-每个查询的算法非常糟糕,但如果你的数量少于(例如)中的任何一个,你可能会逃脱惩罚

  • 10次查询和100000分
  • 100个查询和10000个积分
  • 1000个查询和1000个积分
  • 10000次查询和100分
  • 100000次查询和10分

比天真的NN:做得更好

否则,您将希望使用中列出的技术之一(尤其是最近邻数据结构)

  • http://en.wikipedia.org/wiki/Nearest_neighbor_search(很可能是从该页面链接下来的),一些链接的例子:

    • http://en.wikipedia.org/wiki/K-d_tree
    • http://en.wikipedia.org/wiki/Locality_sensitive_hashing
    • http://en.wikipedia.org/wiki/Cover_tree

尤其是如果你计划多次运行你的程序。很可能有可用的库。如果您有大量的#querys**点,那么不使用NN数据结构将花费太多时间。正如用户"dsign"在评论中指出的那样,您可以通过使用numpy库来挤出一个额外的速度常数。

然而,如果你可以使用简单易实现的朴素NN,你应该使用它。

在生成器上使用heapq.nlargest计算每条记录的距离*权重。

类似于:

heapq.nlargest(N, ((row, dist_function(row,criteria,weight)) for row in data), operator.itemgetter(1))

最新更新