我有两个浮点数列表,我想计算它们之间的集合差
使用numpy,我最初编写了以下代码:
aprows = allpoints.view([('',allpoints.dtype)]*allpoints.shape[1])
rprows = toberemovedpoints.view([('',toberemovedpoints.dtype)]*toberemovedpoints.shape[1])
diff = setdiff1d(aprows, rprows).view(allpoints.dtype).reshape(-1, 2)
对于像整数这样的东西很有效。对于具有浮点坐标的2d点,它们是一些几何计算的结果,存在有限精度和舍入误差的问题,导致集合差错过一些等式。现在我使用了非常非常慢的
diff = []
for a in allpoints:
remove = False
for p in toberemovedpoints:
if norm(p-a) < 0.1:
remove = True
if not remove:
diff.append(a)
return array(diff)
但是有没有一种方法可以用numpy写这个并恢复速度?
请注意,我希望剩余的点仍然具有完整的精度,因此首先将数字四舍五入,然后进行一组差值可能不是前进的方向(或者是吗?:))
编辑添加基于scipy的解决方案。KDTree似乎可以工作:
def remove_points_fast(allpoints, toberemovedpoints):
diff = []
removed = 0
# prepare a KDTree
from scipy.spatial import KDTree
tree = KDTree(toberemovedpoints, leafsize=allpoints.shape[0]+1)
for p in allpoints:
distance, ndx = tree.query([p], k=1)
if distance < 0.1:
removed += 1
else:
diff.append(p)
return array(diff), removed
如果您希望使用矩阵形式执行此操作,那么使用较大的数组会消耗大量内存。如果这无关紧要,那么您可以通过以下方式得到差矩阵:
diff_array = allpoints[:,None] - toberemovedpoints[None,:]
结果数组的行数与所有点中的点数相等,列数与待移动点中的点数相等。然后你可以以任何你想要的方式操作它(例如计算绝对值),它会给你一个布尔数组。要查找哪些行有任何命中(绝对差<1),使用numpy.any
:
hits = numpy.any(numpy.abs(diff_array) < .1, axis=1)
现在你有了一个向量,它的项数与difference数组中的行数相同。您可以使用该向量对所有点进行索引(为负值,因为我们想要非匹配点):
return allpoints[-hits]
这是一种愚蠢的做法。但是,正如我上面所说的,它需要大量的内存。
如果你有更大的数据,那么你最好一点一点地做。像这样:
return allpoints[-numpy.array([numpy.any(numpy.abs(a-toberemoved) < .1) for a in allpoints ])]
这应该在大多数情况下表现良好,并且内存使用比矩阵解决方案低得多。(出于文体上的原因,你可能想使用numpy。而不是麻木。任意并将比较反过来以摆脱否定。
(注意,代码中可能有打印错误)