如何使用python获取mock对象的内部状态



我想在插入操作中使用mock对象来模拟数据库。

例如,假设我有一个类似insert(1(的方法,它调用一个数据库连接对象(我们称之为db_obj(,并对mytable(col1(值(%s(执行一些插入操作,其中%s是1。

我的想法是:为db_obj创建一些存储col1值的mock对象,所以当调用db_obj.insert(1(时,mock db_obj存储这个col1=1,然后,我可以获得mock对象col1值,并断言它是1。

这种方法有意义吗?如果是,我如何使用pytest来完成此操作?

下面是我尝试的一个例子

from hub.scripts.tasks.notify_job_module import http_success
from hub.scripts.tasks.notify_job_module import proceed_to
from hub.core.connector.mysql_db import MySql
from unittest.mock import patch
import logging
def proceed_to(conn,a,b,c,d,e):
conn.run_update('insert into table_abc (a,b,c,d,e) values (%s,%s,%s,%s,%s)',[a,b,c,d,e])
class MySql:
# this is the real conn implementation
def connect(self):
# ... create the database connection here
def run_update(self, query, params=None):
# ... perform some insert into the database here
class TestMySql:
# this is the mock implementation I want to get rid of
def connect(self):
pass
def run_update(self, query, params=None):
self.result = params
def get_result(self):
return self.result
def test_proceed_to():
logger = logging.getLogger("")
conn = TestMySql() ## originally, my code uses MySql() here
conn.connect()
proceed_to(conn,1,'2',3,4,5)
assert conn.get_result()[1] == 4

请注意,我不得不用TestMySql((替换MySql(((,所以我所做的是手动实现我自己的Mock对象。

这样做是有效的,但我觉得这显然不是最好的方法。为什么?

因为我们谈论的是mock对象,processed_to的定义在这里无关紧要:-(问题是:我必须实现TestMySql.get_result((并将我想要的数据存储在self.result中才能获得我想要的结果,而MySql本身没有get_result(!

我想避免在这里创建自己的mock对象,并在这里使用unittest.mock 使用一些更聪明的方法

您测试的基本上是用什么参数调用run_update。您可以对连接进行模拟,并在模拟上使用assert_colled_xxx方法,如果您想检查特定的参数而不是所有的参数,您可以在模拟上检查call_args。下面是一个与示例代码匹配的示例:

@mock.patch("some_module.MySql")
def test_proceed_to(mocked):
# we need the mocked instance, not the class
sql_mock = mocked.return_value  
conn = sql_mock.connect()  # conn is now a mock - you also could have created a mock manually
proceed_to(conn, 1, '2', 3, 4, 5)
# assert that the function was called with the correct arguments
conn.run_update.assert_called_once()
# conn.run_update.assert_called_once_with() would compare all arguments
assert conn.run_update.call_args[0][1] == [1, '2', 3, 4, 5]
# call_args[0] are the positional arguments, so this checks the second positional argument

最新更新