我有这样的uuid的数据类:
import uuid
from dataclasses import dataclass, field
from typing import Union
@dataclass
class Foo:
id: Union[uuid.UUID, None] = field(default_factory=uuid.uuid4)
当我调用Foo()
时,它会创建一个具有生成的 UUID 的对象 - 很好。
现在我想在测试中模拟这个 UUID 工厂,就像这样:
from unittest.mock import patch
TEST_UUIDS = ["uuid1", "uuid2"]
with patch.object(uuid, "uuid4", side_effects=TEST_UUIDS):
print(uuid.uuid4()) # Output: uuid1
print(Foo().id) # Output: an actual UUID
我的预期输出是uuid2
.那么问题来了:我如何才能正确修补工厂?在文档或这里找不到任何内容...
TL;DR最好只做print(Foo(id="uuid1").id)
解释
这是不起作用的:(除了side_effect(s(的错别字(
with patch.object(uuid, "uuid4", side_effect=TEST_UUIDS):
print(id(uuid.uuid4))
print(id(Foo.__dataclass_fields__["id"].default_factory))
输出:
140470453723216
140470456506848
这是因为patch 不会更改实际的uuid.uuid4
函数,而是创建一个新函数,覆盖命名空间中的正常uuid.uuid4
,如 id 的变化所示。由于工厂已经初始化,因此此更改不起作用,除非您在with
块中创建类,这对于实际测试是不切实际的。
但即使这样做了,也不会有什么不同,因为 dataclass生成__init__
函数,并在类定义期间只创建一次对默认工厂的引用,而不是每次__init__
调用。因此,以下内容也不起作用:
with patch.object(Foo.__dataclass_fields__["id"], "default_factory",
side_effect=TEST_UUIDS):
print(Foo().id)
您可以做的是覆盖__init__
函数:
from copy import copy
# Copy is required to avoid recursion after rewrite
foo_init_copy = copy(Foo.__init__)
def mock_init(foo):
foo_init_copy(foo)
foo.id = "x"
with patch.object(Foo, "__init__", new=mock_init):
...
这对我来说似乎有点矫枉过正,如果工厂通常不执行复杂的逻辑。如果是这种情况,还可以考虑创建一个普通类而不是数据类。