蟒蛇推土机2D阵列的距离



我想计算两个2D阵列之间的推土机距离(这些不是图像(。

现在我浏览两个图书馆:scipy(https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html(和pyemd(https://pypi.org/project/pyemd/(。

#define a sampeling method
def sampeling2D(n, mu1, std1, mu2, std2):
#sample from N(0, 1) in the 2D hyperspace
x = np.random.randn(n, 2)
#scale N(0, 1) -> N(mu, std)
x[:,0] = (x[:,0]*std1) + mu1
x[:,1] = (x[:,1]*std2) + mu2
return x
#generate two sets
Y1 = sampeling2D(1000, 0, 1, 0, 1)
Y2 = sampeling2D(1000, -1, 1, -1, 1)
#compute the distance
distance = pyemd.emd_samples(Y1, Y2)

虽然scipy版本不接受 2D 数组并返回错误,但pyemd方法返回一个值。如果您从文档中看到,它说它只接受 1D 数组,所以我认为输出是错误的。在这种情况下,我如何计算这个距离?

因此,如果我理解正确,您正在尝试传输采样分布,即计算所有集群权重为 1 的设置的距离。通常,您可以将 EMD 的计算视为最小成本流的实例,在您的情况下,这归结为线性分配问题:您的两个数组是二分图中的分区,两个顶点之间的权重是您选择的距离。假设你想使用欧几里得范数作为你的度量,边的权重,即地面距离,可以使用scipy.spatial.distance.cdist获得,事实上,SciPy也为线性和赋值问题提供了一个求解器,scipy.optimize.linear_sum_assignment(最近在SciPy 1.4中提供了巨大的性能改进。如果您遇到性能问题,您可能会对此感兴趣;对于 1000x1000 输入,1.3 实现有点慢(。

换句话说,你想做什么归结

from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
d = cdist(Y1, Y2)
assignment = linear_sum_assignment(d)
print(d[assignment].sum() / n)

也可以使用scipy.sparse.csgraph.min_weight_bipartite_full_matching作为linear_sum_assignment的直接替代品;虽然是为稀疏输入(你的当然不是(,但它可能会在某些情况下提供性能改进。

验证此计算的结果是否与您从最小成本流求解器获得的结果相匹配可能会有所启发;NetworkX 中有一个这样的求解器,我们可以在其中手动构建图形:

import networkx as nx
G = nx.DiGraph()
# Represent elements in Y1 by 0, ..., 999, and elements in
# Y2 by 1000, ..., 1999.
for i in range(n):
G.add_node(i, demand=-1)
G.add_node(n + i, demand=1)
for i in range(n):
for j in range(n):
G.add_edge(i, n + j, capacity=1, weight=d[i, j])

此时,我们可以验证上述方法是否与最小成本流一致:

In [16]: d[assignment].sum() == nx.algorithms.min_cost_flow_cost(G)
Out[16]: True

同样,看到结果与一维输入的scipy.stats.wasserstein_distance一致也是有启发性的:

from scipy.stats import wasserstein_distance
np.random.seed(0)
n = 100
Y1 = np.random.randn(n)
Y2 = np.random.randn(n) - 2
d =  np.abs(Y1 - Y2.reshape((n, 1)))
assignment = linear_sum_assignment(d)
print(d[assignment].sum() / n)       # 1.9777950447866477
print(wasserstein_distance(Y1, Y2))  # 1.977795044786648

最新更新