如何将np.linal.solve映射到矩阵数组并保持速度



我有一个线性问题要解决很多次:Ax=BA是dim n的平方矩阵,B是维数为n的向量。我需要找到x:

import numpy as np
A = np.random.rand(2,2)
B = np.random.rand(2)
x = np.linalg.solve(A,B)

这是非常基本的。问题是我想解决这个问题很多次。当前的实现是这样的:

import numpy as np
k = 50 # the number of systems to solve
A_list = np.random.rand(k,2,2)
B_list = np.random.rand(k,2)
x = np.array([np.linalg.solve(A,B) for A, B in zip(A_list, B_list)])

但速度很慢。在这个网站的人的帮助下,我可以使用np.newaxis来做智能广播,从而消除代码中的一个巨大瓶颈。我想知道是否有类似的技巧可以用于这类函数(np.linalg.solvenp.linalg.det等)

我的np.vectorize测试失败了。

编辑:

输出

>>> import numpy as np
>>> k = 50 # the number of systems to solve
>>> A_list = np.random.rand(k,2,2)
>>> B_list = np.random.rand(k,2)
>>> x = np.array([np.linalg.solve(A,B) for A, B in zip(A_list, B_list)])
---------------------------------------------------------------------------
LinAlgError                               Traceback (most recent call last)
<ipython-input-53-fecc7a7edaf9> in <module>()
----> 1 solution = np.linalg.solve(A_list,B_list)
/usr/lib/python3.3/site-packages/numpy/linalg/linalg.py in solve(a, b)
    309     if one_eq:
    310         b = b[:, newaxis]
--> 311     _assertRank2(a, b)
    312     _assertSquareness(a)
    313     n_eq = a.shape[0]
/usr/lib/python3.3/site-packages/numpy/linalg/linalg.py in _assertRank2(*arrays)
    153         if len(a.shape) != 2:
    154             raise LinAlgError('%d-dimensional array given. Array must be '
--> 155                     'two-dimensional' % len(a.shape))
    156 
    157 def _assertSquareness(*arrays):
LinAlgError: 3-dimensional array given. Array must be two-dimensional

Numpy 1.8

您不必从numpy版本1.8:开始做任何事情

np.linalg.solve(A_list, B_list)

演示:

>>> import numpy as np
>>> np.random.seed(11)
>>> k = 10
>>> A_list = np.random.rand(k,2,2)
>>> B_list = np.random.rand(k,2)
>>> solution = np.linalg.solve(A_list,B_list)
>>> all(np.allclose(np.dot(A_list[i, :], solution[i,:]), B_list[i, :])
...         for i in range(A_list.shape[0]))
True

Numpy 1.7及更早版本

在旧版本上,可以尝试使用scipy.linalg.block_diag,但它会引入一些开销,包括内存和速度,对于更大的k,它将输给zip方法:

import scipy.linalg
A = scipy.linalg.block_diag(*A_list)
B = B_list.reshape(-1)
solution = np.linalg.solve(A,B)
solution.reshape(-1,2)

1.8的速度测试

对于k=2000; seed=11:

>>> timeit('from __main__ import np, A_list, B_list; np.linalg.solve(A_list, B_list)', number = 100)
0.2786309433182055
>>> timeit('from __main__ import np, A_list, B_list; np.array([np.linalg.solve(A,B) for A, B in zip(A_list, B_list)])', number = 100)
8.431871369126554
>>> timeit('from __main__ import np, A, B; np.linalg.solve(A,B)', number = 100)
43.4851636171674712

相关内容

  • 没有找到相关文章

最新更新