查找两个'connected'矩阵的最大最小值的最快方法



我想最大化以下功能:

f(i, j, k) = min(A(i, j), B(j, k))

其中AB为矩阵,ijk为矩阵各自维数范围内的指标。我想找到(i, j, k),这样f(i, j, k)最大化。我现在做的是:

import numpy as np
import itertools
shape_a = (100       , 150)
shape_b = (shape_a[1], 200)
A = np.random.rand(shape_a[0], shape_a[1])
B = np.random.rand(shape_b[0], shape_b[1])
# All the different i,j,k
combinations = itertools.product(np.arange(shape_a[0]), np.arange(shape_a[1]), np.arange(shape_b[1]))
combinations = np.asarray(list(combinations))
A_vals = A[combinations[:, 0], combinations[:, 1]]
B_vals = B[combinations[:, 1], combinations[:, 2]]
f = np.min([A_vals, B_vals], axis=0)
best_indices = combinations[np.argmax(f)]
print(best_indices)
[ 49  14 136]

这比遍历所有(i, j, k)要快,但是大量(和大部分)时间花在构造A_valsB_vals矩阵上。这是不幸的,因为它们包含许多重复的值,因为相同的i,jk出现多次。是否有一种方法可以做到这一点(1)numpy的矩阵计算速度可以保持,(2)我不必构建内存密集型A_valsB_vals数组。

在其他语言中,您可能可以构造矩阵,以便它们包含指向AB的指针,但我没有看到如何在Python中实现这一点。

也许您可以重新评估您如何在min和max实际做什么的上下文中看待问题。假设你有以下具体的例子:

>>> np.random.seed(1)
>>> print(A := np.random.randint(10, size=(4, 5)))
[[5 8 9 5 0]
[0 1 7 6 9]
[2 4 5 2 4]
[2 4 7 7 9]]
>>> print(B := np.random.randint(10, size=(5, 3)))
[[1 7 0]
[6 9 9]
[7 6 9]
[1 0 1]
[8 8 3]]

您正在寻找AB中的一对数字,使A中的列与B的行相同,并且您得到最大的较小数字。

对于任何一组数,最大的两两最小值出现在取两个最大的数时。因此,您要在A的每一列,B的每一行中寻找最大值,这些对的最小值,然后是最大值。下面是一个相对简单的解决方案:

candidate_i = A.argmax(axis=0)
candidate_k = B.argmax(axis=1)
j = np.minimum(A[candidate_i, np.arange(A.shape[1])], B[np.arange(B.shape[0]), candidate_k]).argmax()
i = candidate_i[j]
k = candidate_k[j]

事实上,你看到

>>> i, j, k
(0, 2, 2)
>>> A[i, j]
9
>>> B[j, k]
9

如果有碰撞,argmax总是会选择第一个选项。

您的值i,j,k由集合{A,B}的最大值的索引决定。您可以简单地使用np.argmax().

if np.max(A) < np.max(B):
ind = np.unravel_index(np.argmax(A),A.shape)
else:
ind = np.unravel_index(np.argmax(B),B.shape)

它将只返回两个值,如果max({A,B}) = max({A})i,j,如果max({A,B}) = max({B})j,k。但是,如果你得到i,j,那么k可以是适合数组B形状的任何值,所以随机选择其中一个值。

如果您还需要最大化其他值,则:

if np.max(A) < np.max(B):
ind = np.unravel_index(np.argmax(A),A.shape)
ind = ind + (np.argmax(B[ind[1],:]),)

else:
ind = np.unravel_index(np.argmax(B),B.shape)
ind = (np.argmax(A[:,ind[0]]),) + ind

相关内容

  • 没有找到相关文章

最新更新