使用改编自这个问题的答案的代码,我可以在考虑周期性边界条件的情况下对 3D 数组进行暴力 NN 搜索。然后,代码返回最近邻居的索引,并对所有邻居执行此操作。
import numpy as np
from multiprocessing.dummy import Pool as ThreadPool
N = 10000 # Num of objects
# create random point positions
coords = np.random.random((3, N)).transpose()
def NN(point):
dist = np.abs(np.subtract(coords, point)) # Distance between point and all neighbors xyz values
dist = np.where(dist > 0.5, dist - 1, dist) # checking if distance is closer if it wraps around
return (np.square(dist)).sum(axis=-1).argsort()[1] # Calc distance and find index of nearest neighbour
# multi threading for speed increase
pool = ThreadPool(12)
match = pool.map(NN, coords)
pool.close()
pool.join()
对于 N ~ 50000,它按预期变得非常慢。
我想知道我将如何使用诸如sklearn之类的树来实现这一点。BallTree 或 scipy.spacial.cKDTree,并希望这样做,而不会像这里建议的那样再重复空间 8 次。
sklearn.BallTree
和scipy.spatial.cKDTree
都不能轻易修改以处理周期性边界条件。周期边界需要一组与这些实现中使用的假设根本不同的假设。
您应该考虑使用替代实现,例如periodic_kdtree,但请注意,这是一个(有些陈旧的)Python实现,不会像您提到的Cython/C++实现那样快。