我必须使用带有数值积分的函数(scipy.integrate.quad(来评估矩阵的每个元素。矩阵的元素是 5202x3465 灰度图像的像素。
我可以使用GPU,我想并行评估尽可能多的元素,因为现在,使用线性编程,整个计算时间超过24小时。
这是示例代码:
for i in range(0, rows):
for j in range(0, columns):
img[i, j] = myFun(constant_args, i, j)
def myFunc(constant_args, i, j):
new_pixel = quad(integrand, constant_args, i, j)
... other calculations ...
return new_pixel
我尝试像这样使用多处理(作为 mp(:
arows = list(range(0, rows))
acolumns = list(range(0, columns))
with mp.Pool() as pool:
img = pool.map(myFunc, (constant_args, arows, acolumns))
或者使用 img = pool.map(myFunc(constant_args(, (arows, acolumns((但它给了我:TypeError:
myFunc((
缺少 2 个必需的位置参数:"j"和"i">
我不明白这在其他示例中是如何工作的,也不知道文档中使用的术语。
我只想将该嵌套循环划分为子线程,如果有人建议不同的方法,我会全力以赴。
附言。我尝试使用 numba,但在与某些 Scipy 库交互时会出现错误
提前感谢您的帮助!
首先,错误在于对map
- 操作的调用。它必须是:
arows = list(range(0, rows))
acolumns = list(range(0, columns))
with mp.Pool() as pool:
img = pool.map(myFunc, constant_args, arows, acolumns)
但是,这可能不会产生您要查找的内容,因为这仅通过 3 个参数(必须是列表(运行。它不是通过它们的组合运行的,尤其是arows
和acolumns
.例如,如果constant_args
有 3 个元素,Pool.map
将在 3 次迭代后停止,而不会运行较长的列表arows
或acolumns
。
首先,您应该进行行和列索引的所有组合
from itertools import product, repeat
comb = list(product(arows, acolumns))
这会产生类似的东西(所有可能的组合(
[(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3)]
接下来,我会用你的constant_args
压缩这些组合
constant_args = [10, 11]
arguments = list(zip(comb , repeat(constant_args)))
它生成一个元组列表,每个元组包含两个元素。第一个是您的像素位置,第二个是您的constant_args
[((1, 1), [10, 11]),
((1, 2), [10, 11]),
((1, 3), [10, 11]),
((2, 1), [10, 11]),
((2, 2), [10, 11]),
((2, 3), [10, 11]),
((3, 1), [10, 11]),
((3, 2), [10, 11]),
((3, 3), [10, 11])]
现在我们必须稍微修改一下您的myFunc
:
def myFunc(pix_id, constant_args):
new_pixel = quad(integrand, constant_args, pix_id[0], pix_id[1])
... other calculations ...
return new_pixel
最后,我们使用Pool.starmap来施展魔法(见这里:starmap用法(:
with mp.Pool() as pool:
img = pool.starmap(myFunc, arguments )
发生的情况是,starmap
获取元组列表,并将其作为函数的输入。但是,starmap
会自动将元组列表解压缩为函数的单个参数。第一个参数pix_id
由两个元素组成,第二个参数constant_args
。
你可以使用 quadpy(我的一个项目(。它进行矢量化计算,因此运行速度非常快。 输出形状为 2x2 的示例:
import quadpy
def f(x):
return [[x ** 2, x**3], [x**4, x**5]]
val, err = quadpy.quad(f, 0, 1)
print(val)
[[0.33333333 0.25 ]
[0.2 0.16666667]]
f
的输出必须是形状(..., x.shape)
,...
可以是任何元组,例如(5202, 3465)
元组。