熊猫数据帧中所有对的高效 k 最近邻



我有一个 pandas 数据帧,有 20K 行和 50 列。我想根据列的欧几里得距离找到此数据框中每行的 5 个最近邻。因此,结果是一个 20K * 5 的矩阵,其中列是数据帧中最近邻的 ID。

我正在寻找一种解决方案来尽可能高效地做到这一点,最好使用 pandas、并行操作或矢量化操作提供的索引。Scipy kd-tree相当缓慢。

知道吗?

看起来Scipy 的 kd 树确实很慢; 查询一个点大约需要 80 毫秒,我猜这会导致整个数据集的总计算时间约为 0.08 * 20_000 = 1600 秒。

高维数据(例如具有 50 列的数据集(的另一个选项可能是 Ball Tree 数据结构。正如链接中的页面所说:

由于球树节点的球形几何形状,它可以在高维度上优于 KD 树,尽管实际性能高度依赖于训练数据的结构。

玩弄以下代码:

from sklearn.neighbors import NearestNeighbors
import numpy as np
arr = np.random.rand(20_000, 50) * 20
nbrs = NearestNeighbors(n_neighbors = 5, algorithm = 'ball_tree').fit(arr)
%timeit nbrs.kneighbors(arr[:10, :])
# 24.6 ms ± 2.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit nbrs.kneighbors(arr[:100, :])
# 209 ms ± 22.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit nbrs.kneighbors(arr[:1000, :])
# 2.02 s ± 226 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

查看这些%timeit结果,似乎算法大致呈线性扩展,因此对于 20k 行,您可能预计它大约需要 20_000/1_000 * 2 = ~40 秒。 40 秒比您很可能从 kd-tree 数据结构中期望的 ~1600 秒快得多。

最后,我绝对建议仔细阅读最近的邻居页面,以便您完全了解他们提供的算法的所有复杂性。

最新更新