我有一个重试装饰器,它根据引发的异常进行操作:
def retry(times, exceptions):
def decorator(func):
def wrapped_func(self):
attempt = 0
while attempt < times:
try:
return func(self)
except exceptions:
self.logger.info(
"Error, retry in process.. attempt {} of {}".format(
attempt, times
)
)
attempt += 1
return func(self)
return wrapped_func
return decorator
并将用于类函数,如:
@retry(times=10, exceptions=(ValueError,))
def function_that_fails(self):
raise ValueError
我试着测试这个功能,比如:
class sample_class:
def __init__(self, logger):
self._logger = logger
self._count = 0
@property
def count(self):
return self._count
@property
def logger(self):
return self._logger
@retry(times=10, exceptions=(ValueError,))
def sample_func(self):
self._count += 1
raise ValueError
class TestUtils(TestCase):
def setUp(self):
self._logger = Mock()
self._sample_class = sample_class(self._logger)
self._count = 0
def test_retry_function(self):
with self.assertRaises(ValueError):
self._sample_class.sample_func()
self.assertEqual(self._sample_class.count, 10)
def test_retry_function_2(self):
@retry(times=10, exceptions=(ValueError,))
def sample_func():
self._count += 1
raise ValueError
with self.assertRaises(ValueError):
sample_func()
self.assertEqual(self._count, 10)
第一个测试有效,但我想避免使用伪类,第二个我收到错误:
E TypeError: wrapped_func() missing 1 required positional argument: 'self'
由于重试函数中的self
参数,我尝试了多个选项,但都没有成功,我做错了什么?
您可以简单地为self
参数使用Mock
对象,而不是创建整个类:
from unittest.mock import Mock
def test_retry_function_2(self):
@retry(times=10, exceptions=(ValueError,))
def sample_func(mock_arg):
mock_arg._count += 1
raise ValueError
mock = Mock(_count=0)
with self.assertRaises(ValueError):
sample_func(mock)
self.assertEqual(mock._count, 10)
更简单的是,您可以使用Mock作为装饰函数,并使用内置的call_count
来代替必须自己递增的_count
:
def test_retry_function_3(self):
mock = Mock(side_effect=ValueError)
with self.assertRaises(ValueError):
retry(10, ValueError)(mock)(Mock())
self.assertEqual(mock.call_count, 11)
甚至可以很容易地验证装饰器在每次重试时调用logger.info
:
def test_retry_function_4(self):
mock_func = Mock(side_effect=ValueError)
mock_self = Mock()
with self.assertRaises(ValueError):
retry(10, ValueError)(mock_func)(mock_self)
self.assertEqual(mock_func.call_count, 11)
self.assertEqual(mock_self.logger.info.call_count, 10)