从 3D 矩阵堆栈构建块对角矩阵 3D 堆栈的有效方法



我正在尝试从给定的矩阵堆栈(nXmXm(中以numpy/scipy的形式构造一堆nXMXM形式的块对角矩阵,其中M = k * m,k是矩阵堆栈的数量。目前,我正在使用 for 循环中的 scipy.linalg.block_diag 函数来执行此任务:

import numpy as np
import scipy.linalg as linalg
a = np.ones((5,2,2))
b = np.ones((5,2,2))
c = np.ones((5,2,2))
result = np.zeros((5,6,6))
for k in range(0,5):
result[k,:,:] = linalg.block_diag(a[k,:,:],b[k,:,:],c[k,:,:])

但是,由于 n 在我的情况下变得非常大,我正在寻找一种比 for 循环更有效的方法。我发现 3D numpy 数组进入块对角矩阵,但这并不能真正解决我的问题。我能想象到的任何事情都是将每堆矩阵转换为块对角线

import numpy as np
import scipy.linalg as linalg
a = np.ones((5,2,2))
b = np.ones((5,2,2))
c = np.ones((5,2,2))
a = linalg.block_diag(*a)
b = linalg.block_diag(*b)
c = linalg.block_diag(*c)

并通过重塑从中构建生成的矩阵

result = linalg.block_diag(a,b,c)
result = result.reshape((5,6,6))

不会重塑。我什至不知道这种方法是否会更有效,所以我问我是否走在正确的轨道上,或者是否有人知道构建这个块对角线 3D 矩阵的更好方法,或者我是否必须坚持使用 for 循环解决方案。

编辑:由于我是这个平台的新手,我不知道该把这个留在哪里(编辑或回答?(,但我想分享我的最终解决方案panadestein 的突出显示解决方案工作得非常好且简单,但我现在使用更高维数组,我的矩阵驻留在最后两个维度。此外,我的矩阵不再具有相同的维度(主要是 1x1、2x2、3x3 的混合物(,所以我采用了 V. Ahrat 的解决方案,但进行了细微的更改:

def nd_block_diag(arrs):
shapes = np.array([i.shape for i in arrs])
out = np.zeros(np.append(np.amax(shapes[:,:-2],axis=0), [shapes[:,-2].sum(), shapes[:,-1].sum()]))
r, c = 0, 0
for i, (rr, cc) in enumerate(shapes[:,-2:]):
out[..., r:r + rr, c:c + cc] = arrs[i]
r += rr
c += cc
return out

这也适用于数组广播,如果输入数组的形状正确(即,要广播的维度不会自动添加(。感谢 pandestein 和 V. Ayrat 的善意和快速帮助,我学到了很多关于列表推导和数组索引/切片的可能性!

我不认为你可以逃脱所有可能的循环来解决你的问题。我发现比您的for循环更方便且可能更有效的一种方法是使用列表理解:

import numpy as np
from scipy.linalg import block_diag
# Define input matrices
a = np.ones((5, 2, 2))
b = np.ones((5, 2, 2))
c = np.ones((5, 2, 2))
# Generate block diagonal matrices
mats = np.array([a, b, c]).reshape(5, 3, 2, 2)
result = [block_diag(*bmats) for bmats in mats]

也许这可以给你一些想法来改进你的实现。

block_diag也只是遍历形状。几乎所有时间都花在复制数据上,因此您可以按照自己的任何方式进行操作,例如,只需很少更改block_diag的源代码

arrs = a, b, c
shapes = np.array([i.shape for i in arrs])
out = np.zeros([shapes[0, 0], shapes[:, 1].sum(), shapes[:, 2].sum()])
r, c = 0, 0
for i, (_, rr, cc) in enumerate(shapes):
out[:, r:r + rr, c:c + cc] = arrs[i]
r += rr
c += cc
print(np.allclose(result, out))
# True

最新更新