np.random.choice与multiprocessing冲突?在for循环内进行多处理? &g



我想使用np.random.choice多处理池中,但我得到IndexError:列表索引超出范围. 当我在for循环中使用choice函数时,我没有得到任何错误(然后是串联的,而不是并行的)。对于如何克服这个问题有什么想法吗?这只是我例行程序的一小部分,但肯定会大大提高它的速度。我的代码如下所示。我在例程中的其他内容之前声明X,因此它作为全局变量工作,但它是在main中动态填充的. 我也注意到,有一些冲突与多处理for循环. 对如何实现这一点有什么想法吗?

from multiprocessing import Pool
from numpy.random import choice
import numpy as np
K = 10
X = []
def function(k):

global X
np.random.RandomState(k)
aux = [i for i in np.arange(K) if i != k]
a,b,c = choice(aux,3,replace=False)
x = X[a]+0.7*(X[b]-X[c])
return x
if __name__ == '__main__':

X = np.arange(K)
for n in range(K):

pool = Pool(K)
w = pool.map(function,np.arange(K))
pool.close()

print(w)

子进程不共享父进程的内存空间。由于您在if __name__ ...子句中填充X,子进程只能访问顶部模块中定义的X,即X = []

一个快速的解决方案是将X = np.arange(K)行移出子句,如下所示:

from multiprocessing import Pool
from numpy.random import choice
import numpy as np
K = 10
X = []
X = np.arange(K)

def function(k):
global X
np.random.RandomState(k)
aux = [i for i in np.arange(K) if i != k]
a, b, c = choice(aux, 3, replace=False)
x = X[a] + 0.7 * (X[b] - X[c])
return k, x

if __name__ == '__main__':
pool = Pool(10)
w = pool.map(function, np.arange(K))
pool.close()
print(w)

[(0, 10.899999999999999), (1, 9.4), (2, 5.7), (3, 7.4), (4, 1.1000000000000005), (5, -1.0999999999999996), (6, 5.6), (7, 3.8), (8, 5.5), (9, -4.8999999999999995)]

如果你不想为所有子进程初始化X(内存限制?),你可以使用一个管理器来存储X,它可以共享给进程,而不必为每个子进程复制它。要向子进程传递多个参数,还必须使用pool。starmap代替。最后,删除global X,它没有做任何有用的事情,因为global只在您计划从局部作用域修改全局变量时使用。

from multiprocessing import Pool, Manager
from numpy.random import choice
import numpy as np
K = 10

def function(X, k):
np.random.RandomState(k)
aux = [i for i in np.arange(K) if i != k]
a, b, c = choice(aux, 3, replace=False)
x = X[a] + 0.7 * (X[b] - X[c])
return k, x

if __name__ == '__main__':
m = Manager()
X = m.list(np.arange(K))
pool = Pool(10)
args = [(X, val) for val in np.arange(K)]
w = pool.starmap(function, args)
pool.close()
print(w)

[(0, -1.5999999999999996), (1, 7.3), (2, 4.9), (3, 1.9000000000000004), (4, 5.5), (5, -1.0999999999999996), (6, 4.800000000000001), (7, 7.3), (8, 0.10000000000000053), (9, 4.7)]

最新更新