更快地计算两个三维点之间的距离



我有4个长度为160000的列表,分别为s、x、y、z。我做了一个x,y,z的三维数组的列表(点(。我需要找到所有点组合之间的距离作为标准,并将点的索引与列表s的索引相匹配,这样我就得到了满足它的2个点的s值。我正在使用下面的代码。有什么更快的方法吗?

import numpy as np
points = []
for i in range(len(xnew)):
a = np.array((xnew[i],ynew[i],znew[i]))
points.append(a)
for i in range(len(points)):
for j in range(len(points)):
d = np.sqrt(np.sum((points[i] - points[j]) ** 2))
if d <= 4 and d >=3:
print(s[i],s[j],d)

想法是使用cdist和np.where对处理进行矢量化

代码

import numpy as np
import scipy.spatial.distance
# Distance between all pairs of points
d = scipy.spatial.distance.cdist(points, points)
# Pairs within threshold
indexes = np.where(np.logical_and(d>=3, d<=4))
for i, j in indexes:
if i < j: # since distance is symmetric, not reporting j, i
print(s[i],s[j],d[i][j])

如果d矩阵太大而无法放入内存,则找出每个点到所有其他点的距离

for i in range(len(points)):
# Distance from point i to all other points
d = scipy.spatial.distance.cdist(points,[points[i]])
# Points within threshold
indexes = np.where(np.logical_and(d>=3, d<=4))

for ind in indexes:
if ind.size > 0:
for j in ind:
if i < j:   # since distance is symmetric, not reporting j, i
print(s[i], s[j], d[j][0])

测试

points = [
[1, 2, 3],
[1.1, 2.2, 3.3],
[4, 5, 6],
[2, 3, 4]
]
s = [0, 1, 2, 3]

输出(两种方法(

2 3 3.4641016151377544
points = np.array([x, y, z]).T                         
t1, t2 = np.triu_indices(len(points), k= 1)                  # triangular indices
p1 = points[t1]
p2 = points[t2]
d = p1 - p2                       # displacements from p1 to p2
d = np.linalg.norm(d, axis= -1)   # distances     from p1 to p2
mask = (3 <= d) & (d <= 4)
indx = np.where(mask)                     # indices where distance is between 4 and 3
ans = np.array([ s[t1[i]], s[t2[i]], d[i] ]).T

试运行:

n = 10
x = np.random.randint(10, size= [n])          # dummy data
y = np.random.randint(10, size= [n])
z = np.random.randint(10, size= [n])
s = np.random.randint(10, size= [n])

在运行上述代码之后

points
>>> array([
[9, 3, 5],
[7, 8, 1],
[0, 0, 2],
[6, 7, 2],
[4, 4, 3],
[8, 0, 9],
[5, 2, 6],
[0, 8, 9],
[2, 6, 9],
[4, 8, 4]])
s
>>> array([4, 2, 9, 9, 8, 2, 7, 6, 0, 5])
for e in ans:
print(*e)
>>> 9.0  8.0  3.7416573867739413
9.0  5.0  3.0
8.0  7.0  3.7416573867739413

最新更新