如何使用连接作为上下文管理器模拟 class 与 Unittest 的数据库连接?
How to mock db connection with Unittest for class with connection as context manager?
这是 class 在 class 初始化时使用其中一种方法 运行:
class StatCollector:
def __init__(self, poll_stat) -> None:
self.polls = self.__get_polls()
def __get_polls(self) -> Dict[str, Poll]:
with pyodbc.connect(MSSQL_CONNECTION_PARAMS) as cnxn:
polls = dict()
cursor = cnxn.cursor()
query = self.poll_stat.poll_ids_query_getter()
cursor.execute(query, self.poll_stat.products_block)
for poll in map(Poll._make, cursor.fetchall()):
polls[poll.poll_id] = poll
return polls
我想测试这个 class 的其他方法,我的第一个目标是用初始值填充 self.polls
,与 db 没有真正的连接,并使用 __get_polls
方法。我的尝试:
@patch("pyodbc.connect")
class testStatCollector(unittest.TestCase):
def test_initial_values_setted(self, mock_connect):
cursor = MagicMock(name="my_cursor")
cursor.fetchall.return_value = [("2", "А", "B")]
cnxn = MagicMock(name="my_cnxn_mfk")
cnxn.cursor.return_value = cursor
mock_connect.return_value.__enter__ = cnxn
self.test_class = PollsStatCollector(IVR)
self.assertEqual(
self.test_class.polls, {"2": Poll("2", "A", "B")}
)
self.assertIsInstance(self.test_class.period_start_time, datetime)
但是self.polls
执行后是空的。我有:
AssertionError: {} != {'2': Poll(poll_id='2', product='A', products_block='B')}
我在调试中看到,当 __get_polls
执行时,cnxn 名称 = my_cnxn_mfk
,但随后游标的默认名称 = <MagicMock name='my_cnxn_mfk().cursor()' id='1883689785424'>
。
所以我猜我在这部分犯了错误cnxn.cursor.return_value = cursor
,但我不知道如何修正。
错误在这里:
mock_connect.return_value.__enter__ = cnxn
应替换为
mock_connect.return_value.__enter__.return_value = cnxn
这是 class 在 class 初始化时使用其中一种方法 运行:
class StatCollector:
def __init__(self, poll_stat) -> None:
self.polls = self.__get_polls()
def __get_polls(self) -> Dict[str, Poll]:
with pyodbc.connect(MSSQL_CONNECTION_PARAMS) as cnxn:
polls = dict()
cursor = cnxn.cursor()
query = self.poll_stat.poll_ids_query_getter()
cursor.execute(query, self.poll_stat.products_block)
for poll in map(Poll._make, cursor.fetchall()):
polls[poll.poll_id] = poll
return polls
我想测试这个 class 的其他方法,我的第一个目标是用初始值填充 self.polls
,与 db 没有真正的连接,并使用 __get_polls
方法。我的尝试:
@patch("pyodbc.connect")
class testStatCollector(unittest.TestCase):
def test_initial_values_setted(self, mock_connect):
cursor = MagicMock(name="my_cursor")
cursor.fetchall.return_value = [("2", "А", "B")]
cnxn = MagicMock(name="my_cnxn_mfk")
cnxn.cursor.return_value = cursor
mock_connect.return_value.__enter__ = cnxn
self.test_class = PollsStatCollector(IVR)
self.assertEqual(
self.test_class.polls, {"2": Poll("2", "A", "B")}
)
self.assertIsInstance(self.test_class.period_start_time, datetime)
但是self.polls
执行后是空的。我有:
AssertionError: {} != {'2': Poll(poll_id='2', product='A', products_block='B')}
我在调试中看到,当 __get_polls
执行时,cnxn 名称 = my_cnxn_mfk
,但随后游标的默认名称 = <MagicMock name='my_cnxn_mfk().cursor()' id='1883689785424'>
。
所以我猜我在这部分犯了错误cnxn.cursor.return_value = cursor
,但我不知道如何修正。
错误在这里:
mock_connect.return_value.__enter__ = cnxn
应替换为
mock_connect.return_value.__enter__.return_value = cnxn