使用 scipy.integrate.quad 并行计算矩阵元素



我必须使用带有数值积分的函数(scipy.integrate.quad(来评估矩阵的每个元素。矩阵的元素是 5202x3465 灰度图像的像素。
我可以使用GPU,我想并行评估尽可能多的元素,因为现在,使用线性编程,整个计算时间超过24小时。
这是示例代码:

for i in range(0, rows):
for j in range(0, columns):
img[i, j] = myFun(constant_args, i, j)
def myFunc(constant_args, i, j):
new_pixel = quad(integrand, constant_args, i, j)
... other calculations ... 
return new_pixel

我尝试像这样使用多处理(作为 mp(:

arows = list(range(0, rows))
acolumns = list(range(0, columns))
with mp.Pool() as pool:
img = pool.map(myFunc, (constant_args, arows, acolumns))

或者使用 img = pool.map(myFunc(constant_args(, (arows, acolumns((但它给了我:TypeError:
myFunc((
缺少 2 个必需的位置参数:"j"和"i">

我不明白这在其他示例中是如何工作的,也不知道文档中使用的术语。
我只想将该嵌套循环划分为子线程,如果有人建议不同的方法,我会全力以赴。
附言。我尝试使用 numba,但在与某些 Scipy 库交互时会出现错误

提前感谢您的帮助!

首先,错误在于对map- 操作的调用。它必须是:

arows = list(range(0, rows))
acolumns = list(range(0, columns))
with mp.Pool() as pool:
img = pool.map(myFunc, constant_args, arows, acolumns)

但是,这可能不会产生您要查找的内容,因为这仅通过 3 个参数(必须是列表(运行。它不是通过它们的组合运行的,尤其是arowsacolumns.例如,如果constant_args有 3 个元素,Pool.map将在 3 次迭代后停止,而不会运行较长的列表arowsacolumns

首先,您应该进行行和列索引的所有组合

from itertools import product, repeat
comb = list(product(arows, acolumns))

这会产生类似的东西(所有可能的组合(

[(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3)]

接下来,我会用你的constant_args压缩这些组合

constant_args = [10, 11]
arguments = list(zip(comb , repeat(constant_args)))

它生成一个元组列表,每个元组包含两个元素。第一个是您的像素位置,第二个是您的constant_args

[((1, 1), [10, 11]),
((1, 2), [10, 11]),
((1, 3), [10, 11]),
((2, 1), [10, 11]),
((2, 2), [10, 11]),
((2, 3), [10, 11]),
((3, 1), [10, 11]),
((3, 2), [10, 11]),
((3, 3), [10, 11])]

现在我们必须稍微修改一下您的myFunc

def myFunc(pix_id, constant_args):
new_pixel = quad(integrand, constant_args, pix_id[0], pix_id[1])
... other calculations ... 
return new_pixel

最后,我们使用Pool.starmap来施展魔法(见这里:starmap用法(:

with mp.Pool() as pool:
img = pool.starmap(myFunc, arguments )

发生的情况是,starmap获取元组列表,并将其作为函数的输入。但是,starmap会自动将元组列表解压缩为函数的单个参数。第一个参数pix_id由两个元素组成,第二个参数constant_args

你可以使用 quadpy(我的一个项目(。它进行矢量化计算,因此运行速度非常快。 输出形状为 2x2 的示例:

import quadpy

def f(x):
return [[x ** 2, x**3], [x**4, x**5]]

val, err = quadpy.quad(f, 0, 1)
print(val)
[[0.33333333 0.25      ]
[0.2        0.16666667]]

f的输出必须是形状(..., x.shape)...可以是任何元组,例如(5202, 3465)元组。

相关内容

  • 没有找到相关文章