如何用 pytest 修补（mock）类的构造函数？

How to patch class constructor with pytest?

我有一个类 SiteManager，我想在 main 函数的单元测试中对它打补丁（pytest==6.2.5，Python 3.10.1）：

from src.data_interface import DataInterface

class SiteManager:
    """Holds a DataInterface and the sites it creates for one time span.

    Only the constructor is shown in this excerpt; further members are elided.
    """

    def __init__(self, time_span, process_id):
        """Create the data interface and load its sites.

        Args:
            time_span: Time span forwarded to DataInterface; main() builds it
                with create_time_span (e.g. ['Y2015']).
            process_id: Identifier forwarded to DataInterface.

        Note: DataInterface.__init__ connects to an SQLite database file, so
        unit tests of callers such as main() should patch SiteManager rather
        than construct it for real.
        """
        self._time_span = time_span
        self._data_interface = DataInterface(self._time_span, process_id)
        self._sites = self._data_interface.create_sites()

    # Further members are elided in the original snippet; the Ellipsis must sit
    # at class level (not at a stray 1-column indent) for the code to parse.
    ...

类 SiteManager 使用了另一个类 DataInterface。

我原以为给 SiteManager 打补丁后，测试中就不会再调用 SiteManager 的原始构造函数。但是，我对 main 的测试报错了：它运行了上面所示的 SiteManager 原始构造函数（而 DataInterface 的代码中包含错误）。

=> 我该如何正确地给 SiteManager 类打补丁，使 main 的单元测试不依赖于被修补的类及其下层依赖？

如果 DataInterface 包含错误，应该只有 DataInterface 的单元测试失败，而 main 和 SiteManager 的单元测试不应失败。

我对 main 的测试：

from src.main import main
from mock import patch


class TestMain:
    """Unit test for main() with all of its collaborators replaced by mocks."""

    # patch() must target the name where main() *looks it up*. main.py does
    # ``from src.site.site_manager import SiteManager`` (and likewise for
    # Simulation and create_time_span), which copies the names into src.main
    # at import time. Patching the defining modules therefore leaves main()'s
    # references untouched, and the real SiteManager/DataInterface still run.
    @patch('src.main.SiteManager')
    @patch('src.main.Simulation')
    @patch('src.main.create_time_span', return_value=['Y2015'])
    @patch('builtins.print')
    def test_main(self, mocked_print, mocked_create_time_span, mocked_simulation, mocked_site_manager):
        # Stacked @patch decorators apply bottom-up, so the mocks arrive in
        # reverse order of the decorator list.
        main()

        mocked_create_time_span.assert_called_once_with(2015, 2017, 5)

        # Calling a mocked class *is* the constructor call, so assert on the
        # mock itself instead of on its __init__.
        mocked_site_manager.assert_called_once_with(['Y2015'], 39)

        mocked_simulation.assert_called_once()
        sim_args, _ = mocked_simulation.call_args
        # The first argument is the SimulationMode; the rest must be the time
        # span, the CO2 cost and the (mocked) SiteManager instance.
        assert sim_args[1:] == (['Y2015'], 0, mocked_site_manager.return_value)
        mocked_simulation.return_value.run.assert_called_once_with()

main.py:

from src.site.site_manager import SiteManager
from src.simulation.simulation_mode import SimulationMode
from src.simulation.simulation import Simulation
from src.time_utils import create_time_span


def main():
    """Run one deterministic simulation for a fixed scenario.

    Builds the time span, a SiteManager for process 39 and a Simulation with
    a CO2 cost of 0, then runs the simulation.
    """
    mode = SimulationMode.DETERMINISTIC  # original alternative: SimulationMode.select()
    span = create_time_span(2015, 2017, 5)  # for example ['Y2015']
    process_id, co2_cost = 39, 0  # co2_cost alternative: input('CO2 cost: ')

    sites = SiteManager(span, process_id)
    Simulation(mode, span, co2_cost, sites).run()
    # simulation.show_evaluation() is intentionally not called here.


if __name__ == '__main__':
    main()

错误:

self = <src.data_interface.DataInterface object at 0x0000023603665C60>, time_span = ['Y2015'], process_id = 39

    def __init__(self, time_span, process_id):
        years = ', '.join(time_span)
        self._scenario_id = '40100'
        self._country_id = '9'

>       connection = sqlite3.connect('../input/industrial_database.sqlite')
E       sqlite3.OperationalError: unable to open database file

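使用 @patch.object 而不是 @patch 解决了问题：patch.object 直接替换类对象本身的 __init__，而 main.py 通过 from ... import 导入的 SiteManager 和 Simulation 指向的正是同一个类对象，所以替换会生效；相反，@patch('src.site.site_manager.SiteManager') 只重新绑定了定义模块中的名称，main.py 中已导入的名称仍指向原始类。（更直接的做法是在名称被查找的位置打补丁，例如 @patch('src.main.SiteManager')。）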

from src.main import main
from src.site.site_manager import SiteManager
from src.simulation.simulation import Simulation
import src.time_utils as time_utils
from mock import patch, MagicMock


class MockedSiteManager:
    """Empty placeholder stand-in for SiteManager (currently unused here)."""


class TestMain:
    """Unit test for main() that stubs out the expensive constructors."""

    # The two functions below are passed as replacement __init__ methods to
    # patch.object, so their ``self`` is the SiteManager / Simulation instance.
    # patch.object edits the class objects themselves, which src.main also
    # references, so main() picks the replacements up.
    def site_manager_init_mock(self, time_span, process_id):
        """Replacement for SiteManager.__init__ that skips creating a DataInterface."""

    def simulation_init_mock(self, simulation_mode, time_span, co2_cost, site_manager):
        """Replacement for Simulation.__init__; gives the instance a MagicMock run()."""
        self.run = MagicMock()

    @patch.object(SiteManager, '__init__', site_manager_init_mock)
    @patch.object(Simulation, '__init__', simulation_init_mock)
    # main.py does ``from src.time_utils import create_time_span``, so the name
    # must be patched in src.main; patching the time_utils module attribute
    # would not affect the reference main() already holds.
    @patch('src.main.create_time_span', return_value=['Y2015'])
    def test_main(self, mocked_create_time_span):
        # Run main() exactly once. Any exception propagates and fails the test
        # with its full traceback, which is more informative than converting
        # it into ``assert False`` (which is also stripped under ``python -O``).
        main()

        mocked_create_time_span.assert_called_once_with(2015, 2017, 5)