python 模拟 sqlalchemy 连接

python mocking sqlalchemy connection

我有一个连接到数据库并获取一些数据的简单函数。

db.py

from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool


def _create_engine(app):
    """Build a SQLAlchemy engine for the database URL stored in ``app['DB']``.

    ``poolclass=NullPool`` turns connection pooling off, so connections are
    not kept around between uses; per the original author this avoids
    timeout issues with idle pooled connections.
    """
    return create_engine(app['DB'], poolclass=NullPool)


def get_all_pos(app):
    """Return every distinct ``(id, name)`` row from table ``p_t``, ordered by name.

    :param app: mapping that holds the database URL under the ``'DB'`` key;
        it is handed unchanged to ``_create_engine``.
    :returns: whatever ``fetchall()`` on the executed query's cursor returns.

    Any exception raised while building the engine, executing the query or
    fetching rows propagates unchanged to the caller.
    """
    engine = _create_engine(app)
    qry = """SELECT DISTINCT id, name FROM p_t ORDER BY name ASC"""
    # No try/except here: catching Exception only to re-raise it added
    # nothing, and binding it to ``re`` shadowed the stdlib ``re`` module.
    # NOTE(review): ``Engine.execute`` was deprecated in SQLAlchemy 1.4 and
    # removed in 2.0; when upgrading, switch to
    # ``with engine.connect() as conn: conn.execute(text(qry))`` and update
    # the tests that mock ``engine.execute``.
    cursor = engine.execute(qry)
    return cursor.fetchall()

我正在尝试通过模拟这个连接来编写一些测试用例：

tests.py

import unittest
from db import get_all_pos
from unittest.mock import patch
from unittest.mock import Mock


class TestPosition(unittest.TestCase):
    """The failing attempt from the question, kept as a reproduction."""

    def test_get_all_pos(self):
        # NOTE: this is the attempt that does NOT work. db.py runs
        # ``from sqlalchemy import create_engine``, which binds the name
        # ``create_engine`` in db's own namespace. Replacing an attribute on a
        # patched ``db.sqlalchemy`` therefore never touches the name that
        # ``_create_engine`` actually calls, so the real SQLAlchemy function
        # runs and rejects 'test' as a database URL. (With exactly the db.py
        # shown, ``db`` has no ``sqlalchemy`` attribute at all, so ``patch``
        # itself may raise AttributeError first.)
        # Patch ``db.create_engine`` or ``db._create_engine`` instead.
        with patch('db.sqlalchemy') as fake_sqlalchemy:
            fake_sqlalchemy.create_engine = Mock()
            get_all_pos({'DB': 'test'})




if __name__ == '__main__':
    # Run the test suite when the file is executed directly (python tests.py).
    unittest.main()

当我用 python tests.py 运行上述文件时，出现以下错误：

   "Could not parse rfc1738 URL from string '%s'" % name
sqlalchemy.exc.ArgumentError: Could not parse rfc1738 URL from string 'test'

难道 mock_sqlalchemy.create_engine = Mock() 不应该给我一个模拟对象，从而绕过 URL 检查吗？


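另一种选择是模拟 _create_engine 函数。这是针对 get_all_pos 的单元测试，不需要依赖 _create_engine 的行为，因此可以像下面这样直接对它打补丁。（原测试失败的原因是：db.py 通过 from sqlalchemy import create_engine 把 create_engine 绑定到了 db 模块自己的命名空间中。mock 必须修补名称被查找的位置，即 db.create_engine；修补 db.sqlalchemy 并不会影响这个名称。）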

import unittest
import db
from unittest.mock import patch


class TestPosition(unittest.TestCase):
    """Unit tests for ``db.get_all_pos`` with the engine factory patched out."""

    @patch.object(db, '_create_engine')
    def test_get_all_pos(self, mock_create_engine):
        # ``_create_engine`` is replaced by a MagicMock for the duration of
        # the test, so no real engine is built and 'test' is never parsed
        # as a database URL.
        args = {'DB': 'test'}
        rows = db.get_all_pos(args)

        # The factory must be called exactly once, with the app mapping.
        mock_create_engine.assert_called_once_with({'DB': 'test'})
        # get_all_pos must return what the (mocked) cursor's fetchall() yields.
        engine = mock_create_engine.return_value
        self.assertIs(rows, engine.execute.return_value.fetchall.return_value)


if __name__ == '__main__':
    # Run the test suite when the file is executed directly (python tests.py).
    unittest.main()

如果想测试返回结果，就需要正确设置所有相应的模拟属性。我建议不要把这些设置串成一个链式调用，而是分步配置，这样可读性更好，如下所示。

import unittest
import db
from unittest.mock import patch
from unittest.mock import Mock


class Cursor:
    """Hand-written stub standing in for a query result cursor.

    It remembers the rows it was created with and returns that same object
    from ``fetchall()``.
    """

    def __init__(self, vals):
        # Stored as-is (no copy), so fetchall() hands back the very object
        # the caller passed in.
        self.vals = vals

    def fetchall(self):
        """Return the rows supplied at construction time."""
        stored_rows = self.vals
        return stored_rows


class TestPosition(unittest.TestCase):
    """Unit tests for ``db.get_all_pos`` using a fully configured fake engine."""

    @patch.object(db, '_create_engine')
    def test_get_all_pos(self, mock_create_engine):
        to_test = [1, 2, 3]

        # Fake cursor: fetchall() returns the rows we expect back.
        mock_cursor = Mock()
        cursor_attrs = {'fetchall.return_value': to_test}
        mock_cursor.configure_mock(**cursor_attrs)

        # Fake engine: execute() returns the fake cursor.
        mock_engine = Mock()
        engine_attrs = {'execute.return_value': mock_cursor}
        mock_engine.configure_mock(**engine_attrs)

        # The patched _create_engine hands the fake engine to get_all_pos.
        mock_create_engine.return_value = mock_engine

        args = {'DB': 'test'}
        rows = db.get_all_pos(args)

        # Verify the whole call chain, not just the final result.
        mock_create_engine.assert_called_once_with({'DB': 'test'})
        mock_engine.execute.assert_called_once()
        mock_cursor.fetchall.assert_called_once_with()
        self.assertEqual(to_test, rows)