如何为单元测试全局设置np.random.default_rng种子



numpy创建随机数的推荐方法是创建一个类似的np.random.Generator

import numpy as np
def foo():
# Some more complex logic here, this is the top level method that creates the rng
rng = np.random.default_rng()
return rng.random()

现在假设我正在为我的代码库编写测试,并且我需要为可重复的结果播种rng。

是否可以告诉numpy每次使用相同的种子,而不管在哪里调用default_rng()?这基本上是np.random.seed()的旧行为。我之所以需要它,是因为我有很多这样的测试,并且必须模拟default_rng调用才能为每个测试使用种子,因为在pytest中,你必须在使用某个东西的位置进行模拟,而不是在定义它的位置。因此,像这样在全球范围内嘲笑它是行不通的。

使用旧方法,可以定义一个fixture,在conftest.py中自动设置每个测试的种子,如下所示:

# conftest.py
import pytest
import numpy as np
@pytest.fixture(autouse=True)
def set_random_seed():
# seeds any random state in the tests, regardless where is is defined
np.random.seed(0)
# test_foo.py
def test_foo():
assert np.isclose(foo(), 0.84123412)  # That's not the right number, just an example

随着default_rng的新使用方式,这似乎不再可能。相反,我需要在每个需要rng种子的测试模块中放置这样的固定装置。

# inside test_foo.py, but also every other test file
import pytest
from unittest import mock
import numpy as np

@pytest.fixture()
def seed_default_rng():
seeded_rng = np.random.default_rng(seed=0)
with mock.patch("module.containing.foo.np.random.default_rng") as mocked:
mocked.return_value = seeded_rng
yield 
def test_foo(seed_default_rng):
assert np.isclose(foo(), 0.84123412)

我想到的最好的办法是在conftest.py中有一个可参数化的fixture,比如这个

# conftest.py
import pytest
from unittest import mock
import numpy as np

@pytest.fixture
def seed_default_rng(request):
seeded_rng = np.random.default_rng(seed=0)
mock_location = request.node.get_closest_marker("rng_location").args[0]
with mock.patch(f"{mock_location}.np.random.default_rng") as mocked:
mocked.return_value = seeded_rng
yield

这可以在每次测试中使用,如下所示:

# test_foo.py
import pytest
from module.containing.foo import foo
@pytest.mark.rng_location("module.containing.foo")
def test_foo(seed_default_rng):
assert np.isclose(foo(), 0.84123412)  # just an example number

它仍然不如以前方便,但您只需要将标记添加到每个测试中,而不需要嘲笑default_rng方法。

如果您想要完整的numpy API,并保证在numpy版本之间有稳定的随机值,那么简单的答案是-您不能。

您可以对np.random.RandomState模块使用一个变通方法,但您牺牲了当前np.random模块的使用——没有好的、稳定的方法来解决这个问题。

为什么numpy.random在不同版本之间不稳定

从numpy v1.16开始,numpy.random.default_rng()使用默认的BitGenerator构建了一个新的Generator。但在np.random.Generator的描述中,附加了以下指南:

没有兼容性保证

Generator不提供版本兼容性保证。特别地,随着更好的算法的发展,比特流可能会改变。

因此,使用np.srandom.default_rng((将在不同平台上为相同版本的numpy保留随机数,但不会在不同版本之间保留。

自从采用NEP 0019:随机数生成器政策以来,情况就是如此。参见摘要:

在过去的十年里,NumPy对其所有随机数分布的数字流都有严格的向后兼容性策略。numpy中的其他数字组件通常被允许在修改结果时返回不同的结果,如果结果保持正确,我们有义务让随机数分布在每个版本中始终产生完全相同的数字。我们流兼容性保证的目的是为numpy版本的模拟提供精确的再现性,以促进可再现性研究。然而,这种策略使得很难用更快或更准确的算法来增强任何分布。经过十年的经验和对周围科学软件生态系统的改进,我们相信现在有更好的方法来实现这些目标。我们建议放宽我们严格的流兼容性政策,以消除阻碍我们接受对随机数生成能力的贡献的障碍。

使用pytest进行测试的变通方法

NEP的一节专门讨论了支持单元测试,并讨论了在遗留np.random.RandomState模块中保持跨版本和平台的流兼容性。从";Legacy Random Generation":

RandomState提供对遗留生成器的访问。这台发电机被认为是冻结的,不会有进一步的改进。它保证产生与NumPy v1.16的最终点发布相同的值。这些都取决于Box-Muller法线或反向CDF指数或伽玛。只有当必须具有与以前版本的NumPy所产生的随机数相同的随机数时,才应使用此类。

np.random.RandomState文档提供了一个示例用法,该用法可用于pytest。重要的一点是,使用np.random.random和其他方法的函数必须使用RandomState实例进行猴痘:

mymod.py的内容


import numpy as np
def myfunc():
return np.random.random(size=3)

test_mymod.py的内容

import pytest
import numpy as np
from numpy.random import RandomState
from mymod import myfunc
@pytest.fixture(autouse=True)
def mock_random(monkeypatch: pytest.MonkeyPatch):
def stable_random(*args, **kwargs):
rs = RandomState(12345)
return rs.random(*args, **kwargs)

monkeypatch.setattr('numpy.random.random', stable_random)
def test_myfunc():
# this test will work across numpy versions
known_result = np.array([0.929616, 0.316376, 0.183919])
np.testing.assert_allclose(myfunc(), known_result, atol=1e-6)

最新更新