numpy rng线程安全吗?



我实现了一个函数,它使用numpy随机生成器来模拟一些过程。下面是这样一个函数的最小示例:

def thread_func(cnt, gen):
s = 0.0
for _ in range(cnt):
s += gen.integers(6)
return s

现在我写了一个函数,使用python的starmap来调用thread_func。如果我这样写(将相同的rng引用传递给所有进程):

from multiprocessing import Pool
import numpy as np    
def evaluate(total_cnt, thread_cnt):
gen = np.random.default_rng()
cnt_per_thread = total_cnt // thread_cnt
with Pool(thread_cnt) as p:
vals = p.starmap(thread_func, [(cnt_per_thread,gen) for _ in range(thread_cnt)])
return vals

evaluate(100000, 5)的结果是一个包含5个相同值的数组,例如:

[49870.0, 49870.0, 49870.0, 49870.0, 49870.0]

然而,如果我传递一个不同的循环给所有进程,例如:

vals = p.starmap(thread_func, [(cnt_per_thread,np.random.default_rng()) for _ in range(thread_cnt)])
我得到预期的结果(5个不同的值),例如:
[49880.0, 49474.0, 50232.0, 50038.0, 50191.0]

为什么会发生这种情况?

TL;DR正如@MichaelSzczesny所指出的,主要问题似乎是你使用的进程在具有相同初始状态的相同RNG对象的副本上操作。


随机数生成器(RNG)对象初始化为一个称为种子的整数,当使用迭代操作生成新数字时修改该整数。(seed * huge_number) % another_huge_number)。

使用同一个RNG对象对多个线程的操作是固有顺序的,这不是一个好主意。在最好的情况下,如果两个线程以受保护的方式访问它(例如。使用临界区时,结果取决于线程的顺序。此外,性能也会受到影响,因为这样做会导致缓存线反弹,减慢访问同一对象的线程的执行速度。在最坏的情况下,RNG对象是不受保护的,这会导致竞争条件. 这样的问题可能会导致多个线程的种子是相同的,因此结果(应该是随机的)。

CPython使用名为全局解释器锁(GIL)的巨型互斥锁来保护对Python对象的访问。它可以防止多个线程同时执行Python字节码。目标是保护解释器,而不是对象状态。Numpy的许多函数释放GIL,因此代码可以并行扩展。问题是,如果你在同一个线程中使用它们,会导致竞争条件。你有责任使用锁来保护Numpy对象.

在你的情况下,我不能用线程重现问题,但我可以用进程。因此,我认为您在示例中使用了进程。对于进程,您应该使用:

from multiprocessing import Pool

对于线程,你应该使用:

from multiprocessing.pool import ThreadPool as Pool

进程的行为不同于线程,因为它们而不是操作共享对象(至少默认情况下不是)。相反,进程操作对象副本进程产生相同的输出,因为RNG对象的初始状态在所有进程中是相同的.

简而言之,请每个线程使用一个不同的RNG。一个典型的解决方案是用它们自己的RNG对象创建N个线程,然后与它们通信以发送一些工作(例如:使用队列)。这被称为线程池。另一种选择可能是使用线程本地存储。

请注意,Numpy文档在多线程生成一节中提供了一个示例。

最新更新