我想让多个进程并行读取numpy数组的不同行,以加快速度。然而,当我运行以下代码时,第一个到达func的进程会抛出一个错误,就好像var不再在作用域中一样。为什么会发生这种情况?
import numpy as np
import multiprocessing as mp
# Number of rows in the shared array; one task per row is submitted to the pool.
num_procs = 16
# Number of values in each row of the shared array.
num_points = 2500000
def init_worker(X):
    """Pool initializer: store the shared buffer in a per-process global.

    Passed to ``mp.Pool`` as ``initializer`` so that ``func`` can reach the
    shared buffer through the module-level name ``var`` instead of receiving
    it with every task.

    Args:
        X: The shared buffer supplied through ``initargs``.
    """
    global var
    var = X
def func(proc):
    """Read every element of row ``proc`` of the shared array.

    The buffer is taken from the module-level ``var`` and viewed as a
    ``(num_procs, num_points)`` array.  Each value is read and discarded;
    nothing is returned.

    Args:
        proc: Index of the row to read.
    """
    X_np = np.frombuffer(var).reshape((num_procs, num_points))
    for y in range(num_points):
        # The value is only read, never used.
        z = X_np[proc][y]
if __name__ == '__main__':
    data = np.random.randn(num_procs, num_points)
    # Shared buffer of C doubles, handed to each worker through the initializer.
    X = mp.RawArray('d', num_procs*num_points)
    # View the shared buffer as a 2-D array and fill it in place.
    X_np = np.frombuffer(X).reshape((num_procs, num_points))
    np.copyto(X_np, data)
    pool = mp.Pool(processes=4, initializer=init_worker, initargs=(X,))
    # Pass the callable and its arguments separately.  Writing
    # ``pool.apply_async(func(proc))`` calls ``func`` right here in the parent
    # process, where ``init_worker`` never ran -- hence the NameError on
    # ``var`` -- and would submit its return value (None) as the task.
    results = [pool.apply_async(func, (proc,)) for proc in range(num_procs)]
    pool.close()
    pool.join()
    # An exception raised in a worker is only re-raised when the result is
    # fetched, so fetch each one to avoid failing silently.
    for result in results:
        result.get()
Traceback (most recent call last):
File "parallel_test.py", line 26, in <module>
pool.apply_async(func(proc))
File "parallel_test.py", line 13, in func
X_np = np.frombuffer(var).reshape((num_procs, num_points))
NameError: global name 'var' is not defined
更新:出于某种原因,如果我用Pool.map代替在for循环中调用Pool.apply_async,它似乎可以正常工作。我不明白为什么。
是否有理由不在顶级作用域中将X声明为global?这样可以消除NameError。
import numpy as np
import multiprocessing as mp
# Number of rows in the shared array; one task per row is submitted to the pool.
num_procs = 16
# Number of values in each row of the shared array.
num_points = 25000000
def func(proc):
    """Read every element of row ``proc`` of the shared array.

    The buffer is taken from the module-level ``X`` and viewed as a
    ``(num_procs, num_points)`` array.  Each value is read and discarded;
    nothing is returned.

    Args:
        proc: Index of the row to read.
    """
    X_np = np.frombuffer(X).reshape((num_procs, num_points))
    for y in range(num_points):
        # The value is only read, never used.
        z = X_np[proc][y]
if __name__ == '__main__':
    data = np.random.randn(num_procs, num_points)
    # ``X`` is already a module-level name here, so no ``global`` statement is
    # needed.  NOTE: workers see it only when they are forked from this
    # process; under the "spawn" start method the module is re-imported
    # without running this guarded section and ``X`` is undefined in workers.
    X = mp.RawArray('d', num_procs*num_points)
    # View the shared buffer as a 2-D array and fill it in place.
    X_np = np.frombuffer(X).reshape((num_procs, num_points))
    np.copyto(X_np, data)
    pool = mp.Pool(processes=4)
    # Pass the callable and its arguments separately.  Writing
    # ``pool.apply_async(func(proc))`` runs ``func`` in the parent process and
    # submits its return value (None) as the task, so nothing runs in parallel.
    results = [pool.apply_async(func, (proc,)) for proc in range(num_procs)]
    pool.close()
    pool.join()
    # An exception raised in a worker is only re-raised when the result is
    # fetched, so fetch each one to avoid failing silently.
    for result in results:
        result.get()
当我运行这个问题的简化实例时,n=20:
import numpy as np
import multiprocessing as mp
# Number of rows in the shared array; one task per row is submitted to the pool.
num_procs = 4
# Number of values in each row of the shared array.
num_points = 5
def func(proc):
    """Read every element of row ``proc`` of the shared array.

    The buffer is taken from the module-level ``X`` and viewed as a
    ``(num_procs, num_points)`` array.  Each value is read and discarded;
    nothing is returned.

    Args:
        proc: Index of the row to read.
    """
    X_np = np.frombuffer(X).reshape((num_procs, num_points))
    for y in range(num_points):
        # The value is only read, never used.
        z = X_np[proc][y]
if __name__ == '__main__':
    data = np.random.randn(num_procs, num_points)
    # ``X`` is already a module-level name here, so no ``global`` statement is
    # needed.  NOTE: workers see it only when they are forked from this
    # process; under the "spawn" start method the module is re-imported
    # without running this guarded section and ``X`` is undefined in workers.
    X = mp.RawArray('d', num_procs*num_points)
    # View the shared buffer as a 2-D array and fill it in place.
    X_np = np.frombuffer(X).reshape((num_procs, num_points))
    np.copyto(X_np, data)
    pool = mp.Pool(processes=4)
    # Pass the callable and its arguments separately.  Writing
    # ``pool.apply_async(func(proc))`` runs ``func`` in the parent process and
    # submits its return value (None) as the task, so nothing runs in parallel.
    results = [pool.apply_async(func, (proc,)) for proc in range(num_procs)]
    pool.close()
    pool.join()
    # An exception raised in a worker is only re-raised when the result is
    # fetched, so fetch each one to avoid failing silently.
    for result in results:
        result.get()
    # One value per line (the separator is a newline, not the letter "n").
    print("\n".join(map(str, X)))
我得到以下输出:
-0.63460378046191621.10057247100661070.334587633571652550.64093457149718890.71248887668519820.3676004592133329630.23593304931386933-0.8668969562941349-0.8842756219234690.0059790361056204221.386422154089567-0.877098878222145080.25187448339771057-0.2473967968471952-0.49097088839785210.54235214897502440.018749603863338020.0353047925043780551.32638726689566161.0199839603829742
你尚未提供预期输出的样例。这个结果和你期望的相似吗?