Python中用于非平方代价矩阵的匈牙利算法



我想在非平方numpy数组上使用python中的匈牙利赋值算法。

我的输入矩阵X如下所示:

X = np.array([[0.26, 0.64, 0.16, 0.46, 0.5 , 0.63, 0.29],
[0.49, 0.12, 0.61, 0.28, 0.74, 0.54, 0.25],
[0.22, 0.44, 0.25, 0.76, 0.28, 0.49, 0.89],
[0.56, 0.13, 0.45, 0.6 , 0.53, 0.56, 0.05],
[0.66, 0.24, 0.61, 0.21, 0.47, 0.31, 0.35],
[0.4 , 0.85, 0.45, 0.14, 0.26, 0.29, 0.24]])

期望的结果是矩阵有序,例如X变成X_desired_output:

X_desired_output = np.array([[0.63, 0.5 , 0.29, 0.46, 0.26, 0.64, 0.16], 
[0.54, 0.74, 0.25, 0.28, 0.49, 0.12, 0.61], 
[[0.49, 0.28, 0.89, 0.76, 0.22, 0.44, 0.25], 
[[0.56, 0.53, 0.05, 0.6 , 0.56, 0.13, 0.45], 
[[0.31, 0.47, 0.35, 0.21, 0.66, 0.24, 0.61], 
[[0.29, 0.26, 0.24, 0.14, 0.4 , 0.85, 0.45]])

在这里,我想最大化成本,而不是最小化,因此算法的输入在理论上是1-X或简单地是X

我发现https://software.clapper.org/munkres/导致:

from munkres import Munkres
m = Munkres()
indices = m.compute(-X)
indices
[(0, 5), (1, 4), (2, 6), (3, 3), (4, 0), (5, 1)]
# getting the indices in list format
ii = [i for (i,j) in indices]
jj = [j for (i,j) in indices]

如何使用这些对X进行排序?CCD_ 7仅包含6个元素,而CCD_。

我正在寻找真正的矩阵排序。

花了几个小时研究之后,我找到了一个解决方案。问题是由于X.shape[1] > X.shape[0],某些列根本没有分配,这导致了问题的出现。

文件指出

"Munkres算法假设成本矩阵是平方的。但是,如果您首先填充矩形矩阵,则可以使用它使用0值使其为正方形。此模块自动填充矩形成本矩阵,使其成为正方形">

from munkres import Munkres
m = Munkres()
indices = m.compute(-X)
indices
[(0, 5), (1, 4), (2, 6), (3, 3), (4, 0), (5, 1)]
# getting the indices in list format
ii = [i for (i,j) in indices]
jj = [j for (i,j) in indices]
# re-order matrix
X_=X[:,jj]  # re-order columns
X_=X_[ii,:] # re-order rows
# HERE IS THE TRICK: since the X is not diagonal, some columns are not assigned to the rows !
not_assigned_columns = X[:, [not_assigned for not_assigned in np.arange(X.shape[1]).tolist() if not_assigned not in jj]].reshape(-1,1)
X_desired = np.concatenate((X_, not_assigned_columns), axis=1)
print(X_desired)
array([[0.63, 0.5 , 0.29, 0.46, 0.26, 0.64, 0.16],
[0.54, 0.74, 0.25, 0.28, 0.49, 0.12, 0.61],
[0.49, 0.28, 0.89, 0.76, 0.22, 0.44, 0.25],
[0.56, 0.53, 0.05, 0.6 , 0.56, 0.13, 0.45],
[0.31, 0.47, 0.35, 0.21, 0.66, 0.24, 0.61],
[0.29, 0.26, 0.24, 0.14, 0.4 , 0.85, 0.45]])

最新更新