邻接矩阵从Numpy数组使用欧几里得距离



有人可以帮助我请如何生成一个加权邻接矩阵从numpy数组基于所有行之间的欧几里德距离,即0和1,0和2,..1和2,…?

给定输入矩阵(5,4)的示例:

matrix = [[2,10,9,6],
[5,1,4,7],
[3,2,1,0], 
[10, 20, 1, 4], 
[17, 3, 5, 18]]

我想获得一个加权邻接矩阵(5,5),其中包含节点之间的最小距离,即

if dist(row0, row1)= 10,77 and dist(row0, row2)= 12,84, 
--> the output matrix will take the first distance as a column value. 

我已经用下面的代码解决了邻接矩阵生成的第一部分:

from scipy.spatial.distance import cdist
dist = cdist( matrix, matrix, metric='euclidean')

,得到如下结果:

array([[ 0.        , 10.77032961, 12.84523258, 15.23154621, 20.83266666],
[10.77032961,  0.        ,  7.93725393, 20.09975124, 16.43167673],
[12.84523258,  7.93725393,  0.        , 19.72308292, 23.17326045],
[15.23154621, 20.09975124, 19.72308292,  0.        , 23.4520788 ],
[20.83266666, 16.43167673, 23.17326045, 23.4520788 ,  0.        ]])

但是我还不知道如何指定我们为每个节点选择的邻居的数量,例如2个邻居。例如,我们定义邻居的数量N = 2,那么对于每一行,我们只选择两个距离最小的邻居,我们得到的结果是:

[[ 0.        , 10.77032961, 12.84523258, 0, 0],
[10.77032961,  0.        ,  7.93725393, 0, 0],
[12.84523258,  7.93725393,  0.        , 0, 0],
[15.23154621, 0, 19.72308292,  0.        , 0 ],
[20.83266666, 16.43167673, 0, 0 ,  0.        ]]

您可以使用这个更简洁的解决方案从矩阵中获得最小的n。试试下面的命令-

dist.argsort(1).argsort(1)在轴=1上创建一个排名顺序(最小为0,最大为4),而<= 2决定了您需要从排名顺序中获得的最小值的数量。np.where将其过滤或替换为0。

np.where(dist.argsort(1).argsort(1) <= 2, dist, 0)
array([[ 0.        , 10.77032961, 12.84523258,  0.        ,  0.        ],
[10.77032961,  0.        ,  7.93725393,  0.        ,  0.        ],
[12.84523258,  7.93725393,  0.        ,  0.        ,  0.        ],
[15.23154621,  0.        , 19.72308292,  0.        ,  0.        ],
[20.83266666, 16.43167673,  0.        ,  0.        ,  0.        ]])

这适用于任何轴,或者如果你想从矩阵中获得最大或最小值。

假设a是您的欧几里得距离矩阵,您可以使用np.argpartition来选择每行nmin/max值。请记住,对角线总是0,欧几里得距离是非负的,所以为了在每行中保持两个最近的点,您需要每行保持三个min(包括对角线上的0)。但是,如果你想设置为max,这就不成立了。

a[np.arange(a.shape[0])[:,None],np.argpartition(a, 3, axis=1)[:,3:]] = 0 

输出:

array([[ 0.        , 10.77032961, 12.84523258,  0.        ,  0.        ],
[10.77032961,  0.        ,  7.93725393,  0.        ,  0.        ],
[12.84523258,  7.93725393,  0.        ,  0.        ,  0.        ],
[15.23154621,  0.        , 19.72308292,  0.        ,  0.        ],
[20.83266666, 16.43167673,  0.        ,  0.        ,  0.        ]])

最新更新