有没有一种方法可以模拟在另一个函数调用中调用的函数返回?例如:
def bar():
return "baz"
def foo():
return bar()
class Tests(unittest.TestCase):
def test_code(self):
# hijack bar() here to return "bat" instead
assert(foo() == "bat")
我尝试过使用@mock.patch
,但发现它只允许模拟我正在调用的函数,而不允许模拟由于调用不同函数而被调用的函数。
通用补丁
unittest.mock.patch
完全按照我另一个答案的建议行事。您可以根据需要添加任意数量的@patch
注释,您选择的对象将被修补:
from unittest.mock import patch
def bar():
return "baz"
def foo():
return bar()
class Tests(unittest.TestCase):
@patch(__name__ + '.bar', lambda: 'bat')
def test_code(self):
assert(foo() == "bat")
在这种配置中,一旦test_code
完成,功能bar
将被恢复。如果您希望将相同的补丁应用于类中的所有测试用例,请对整个类进行注释:
@patch(__name__ + '.bar', lambda: 'bat')
class Tests(unittest.TestCase):
def test_code(self):
assert(foo() == "bat")
补丁全局
您也可以在全局命名空间上unittest.mock.patch.dict
以获得相同的结果:
class Tests(unittest.TestCase):
@patch.dict(globals(), {'bar': lambda: 'bat'})
def test_code(self):
assert(foo() == "bat")
您可以编写一个上下文管理器来临时交换全局命名空间中的对象:
class Hijack:
def __init__(self, name, replacement, namespace):
self.name = name
self.replacement = replacement
self.namespace = namespace
def __enter__(self):
self.original = self.namespace[self.name]
self.namespace[self.name] = self.replacement
def __exit__(self, *args):
self.namespace[self.name] = self.original
您可以使用劫持的方法调用mock函数:
def bar():
return "baz"
def bar_mock():
return "bat"
def foo():
return bar()
class Tests(unittest.TestCase):
def test_code(self):
with Hijack('bar', bar_mock, globals()):
assert(foo() == "bat")
这是一种非常通用的方法,可以在单元测试之外使用。事实上,将其推广到任何可以表示为某种映射的可变对象上是非常简单的:
class Hijack:
def __init__(self, name, replacement, namespace, getter=None, setter=None):
self.name = name
self.replacement = replacement
self.namespace = namespace
self.getter = type(namespace).__getitem__ if getter is None else getter
self.setter = type(namespace).__setitem__ if setter is None else setter
def __enter__(self):
self.original = self.getter(self.namespace, self.name)
self.setter(self.namespace, self.name, self.replacement)
def __exit__(self, *args):
self.setter(self.namespace, self.name, self.original)
对于类和其他对象,可以使用getter=getattr
和setter=setattr
。对于None
优于KeyError
的情况,可以使用getter=dict.get
等
此外,您可以使用monkeypatch
夹具的pytestseattr
方法。对于第一个参数,它接受要修补的对象,或者字符串将被解释为虚线导入路径,最后一部分是属性名称:
# foo_module.py
def bar():
return "baz"
def foo():
return bar()
# test_foo.py
from foo_module import foo
def test_foo(monkeypatch):
monkeypatch.setattr('foo_module.bar', lambda: 'bat')
assert foo() == "bat"