我想使用np.random.choice在多处理池中,但我得到IndexError:列表索引超出范围. 当我在for循环中使用choice函数时,我没有得到任何错误(然后是串联的,而不是并行的)。对于如何克服这个问题有什么想法吗?这只是我例行程序的一小部分,但肯定会大大提高它的速度。我的代码如下所示。我在例程中的其他内容之前声明X,因此它作为全局变量工作,但它是在main中动态填充的. 我也注意到,有一些冲突与多处理和for循环. 对如何实现这一点有什么想法吗?
from multiprocessing import Pool
from numpy.random import choice
import numpy as np
K = 10
X = []
def function(k):
global X
np.random.RandomState(k)
aux = [i for i in np.arange(K) if i != k]
a,b,c = choice(aux,3,replace=False)
x = X[a]+0.7*(X[b]-X[c])
return x
if __name__ == '__main__':
X = np.arange(K)
for n in range(K):
pool = Pool(K)
w = pool.map(function,np.arange(K))
pool.close()
print(w)
子进程不共享父进程的内存空间。由于您在if __name__ ...
子句中填充X
,子进程只能访问顶部模块中定义的X,即X = []
一个快速的解决方案是将X = np.arange(K)
行移出子句,如下所示:
from multiprocessing import Pool
from numpy.random import choice
import numpy as np
K = 10
X = []
X = np.arange(K)
def function(k):
global X
np.random.RandomState(k)
aux = [i for i in np.arange(K) if i != k]
a, b, c = choice(aux, 3, replace=False)
x = X[a] + 0.7 * (X[b] - X[c])
return k, x
if __name__ == '__main__':
pool = Pool(10)
w = pool.map(function, np.arange(K))
pool.close()
print(w)
[(0, 10.899999999999999), (1, 9.4), (2, 5.7), (3, 7.4), (4, 1.1000000000000005), (5, -1.0999999999999996), (6, 5.6), (7, 3.8), (8, 5.5), (9, -4.8999999999999995)]
如果你不想为所有子进程初始化X
(内存限制?),你可以使用一个管理器来存储X
,它可以共享给进程,而不必为每个子进程复制它。要向子进程传递多个参数,还必须使用pool。starmap代替。最后,删除global X
,它没有做任何有用的事情,因为global
只在您计划从局部作用域修改全局变量时使用。
from multiprocessing import Pool, Manager
from numpy.random import choice
import numpy as np
K = 10
def function(X, k):
np.random.RandomState(k)
aux = [i for i in np.arange(K) if i != k]
a, b, c = choice(aux, 3, replace=False)
x = X[a] + 0.7 * (X[b] - X[c])
return k, x
if __name__ == '__main__':
m = Manager()
X = m.list(np.arange(K))
pool = Pool(10)
args = [(X, val) for val in np.arange(K)]
w = pool.starmap(function, args)
pool.close()
print(w)
[(0, -1.5999999999999996), (1, 7.3), (2, 4.9), (3, 1.9000000000000004), (4, 5.5), (5, -1.0999999999999996), (6, 4.800000000000001), (7, 7.3), (8, 0.10000000000000053), (9, 4.7)]