Mocking sqlobject function call for test db

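I am trying to mock sqlbuilder.func for test cases with pytest.

I successfully mocked sqlbuilder.func.TO_BASE64 and got the correct output, but when I mock sqlbuilder.func.FROM_UNIXTIME I get no error, yet the generated query is wrong. Below is a minimal working example of the problem.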

models.py

from sqlobject import (
    sqlbuilder,
    sqlhub,
    SQLObject,
    StringCol,
    BLOBCol,
    TimestampCol,
)

class Store(SQLObject):
    name = StringCol()
    sample = BLOBCol()
    createdAt = TimestampCol()

DATE_FORMAT = "%Y-%m-%d"
def retrieve(name):
    query = sqlbuilder.Select([
            sqlbuilder.func.TO_BASE64(Store.q.sample),
        ],
        sqlbuilder.AND(
            Store.q.name == name,
            sqlbuilder.func.FROM_UNIXTIME(Store.q.createdAt, DATE_FORMAT) >= sqlbuilder.func.FROM_UNIXTIME("2018-10-12", DATE_FORMAT)
        )
    )

    connection = sqlhub.getConnection()
    query = connection.sqlrepr(query)
    print(query)
    queryResult = connection.queryAll(query)
    return queryResult

conftest.py

import pytest

from models import Store
from sqlobject import sqlhub
from sqlobject.sqlite import sqliteconnection

@pytest.fixture(autouse=True, scope="session")
def sqlite_db_session(tmpdir_factory):
    file = tmpdir_factory.mktemp("db").join("sqlite.db")
    conn = sqliteconnection.SQLiteConnection(str(file))
    sqlhub.processConnection = conn
    init_tables()
    yield conn
    conn.close()

def init_tables():
    Store.createTable(ifNotExists=True)

test_ex1.py

import pytest

from sqlobject import sqlbuilder
from models import retrieve

try:
    import mock
    from mock import MagicMock
except ImportError:
    from unittest import mock
    from unittest.mock import MagicMock

def TO_BASE64(x):
    return x

def FROM_UNIXTIME(x, y):
    return 'strftime("%Y%m%d", datetime({},"unixepoch", "localtime"))'.format(x)

# @mock.patch("sqlobject.sqlbuilder.func.TO_BASE64")
# @mock.patch("sqlobject.sqlbuilder.func.TO_BASE64", MagicMock(side_effect=lambda x: x))
# @mock.patch("sqlobject.sqlbuilder.func.TO_BASE64", new_callable=MagicMock(side_effect=lambda x: x))
@mock.patch("sqlobject.sqlbuilder.func.TO_BASE64", TO_BASE64)
@mock.patch("sqlobject.sqlbuilder.func.FROM_UNIXTIME", FROM_UNIXTIME)
def test_retrieve():
    result = retrieve('Some')
    assert result == []

Current SQL:

SELECT store.sample FROM store WHERE (((store.name) = ('Some')) AND (1))

Expected SQL:

SELECT
  store.sample
FROM 
  store
WHERE
  store.name = 'Some'
AND
  strftime(
    '%Y%m%d',
    datetime(store.created_at, 'unixepoch', 'localtime')
  ) >= strftime(
    '%Y%m%d',
    datetime('2018-10-12', 'unixepoch', 'localtime')
  )

Edited example

#! /usr/bin/env python

from sqlobject import *

__connection__ = "sqlite:/:memory:?debug=1&debugOutput=1"

try:
    import mock
    from mock import MagicMock
except ImportError:
    from unittest import mock
    from unittest.mock import MagicMock

class Store(SQLObject):
    name = StringCol()
    sample = BLOBCol()
    createdAt = TimestampCol()

Store.createTable()

DATE_FORMAT = "%Y-%m-%d"
def retrieve(name):
    query = sqlbuilder.Select([
            sqlbuilder.func.TO_BASE64(Store.q.sample),
        ],
        sqlbuilder.AND(
            Store.q.name == name,
            sqlbuilder.func.FROM_UNIXTIME(Store.q.createdAt, DATE_FORMAT) >= sqlbuilder.func.FROM_UNIXTIME("2018-10-12", DATE_FORMAT)
        )
    )

    connection = Store._connection
    query = connection.sqlrepr(query)
    queryResult = connection.queryAll(query)
    return queryResult


def TO_BASE64(x):
    return x

def FROM_UNIXTIME(x, y):
    return 'strftime("%Y%m%d", datetime({},"unixepoch", "localtime"))'.format(x)

for p in [
    mock.patch("sqlobject.sqlbuilder.func.TO_BASE64",TO_BASE64),
    mock.patch("sqlobject.sqlbuilder.func.FROM_UNIXTIME",FROM_UNIXTIME),
]:
    p.start()

retrieve('Some')

mock.patch.stopall()

By default, sqlbuilder.func is an SQLExpression that passes its attributes (sqlbuilder.func.datetime, for example) as constants to the SQL backend (sqlbuilder.func is actually an alias for sqlbuilder.ConstantSpace). See the docs about SQLExpression, the FAQ and the code for func.
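To illustrate why patching an attribute on such a namespace works at all, here is a minimal sketch using a toy model. `ToyNamespace` and `ToyConstant` are hypothetical stand-ins written for this example, not SQLObject's classes; they only mimic the idea of a namespace that invents an attribute for any name you ask for.

```python
from unittest import mock


class ToyConstant:
    """Hypothetical stand-in for an SQL constant: renders as its own name."""

    def __init__(self, name):
        self.name = name

    def __call__(self, *args):
        # Calling a constant renders a function call with the given arguments.
        return "{}({})".format(self.name, ", ".join(str(a) for a in args))


class ToyNamespace:
    """Hypothetical stand-in for a namespace that creates attributes on demand."""

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        return ToyConstant(name)


func = ToyNamespace()

# Without a patch, any name is passed through as a function call.
unpatched = func.FROM_UNIXTIME("created_at", "%Y-%m-%d")

# The patch stores a real attribute on the instance, which shadows __getattr__.
with mock.patch.object(func, "FROM_UNIXTIME", lambda x, y: "strftime(" + x + ")"):
    patched = func.FROM_UNIXTIME("created_at", "%Y-%m-%d")

# After the patch is undone, lookup falls back to __getattr__ again.
restored = func.FROM_UNIXTIME("created_at", "%Y-%m-%d")

print(unpatched)
print(patched)
print(restored)
```

In this toy model the replacement function is called directly with the original arguments, so whatever it returns is what ends up in the query being built.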

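When you mock an attribute in the func namespace, its return value is evaluated by SQLObject and passed to the backend in reduced form. If you want to return a string literal from the mocked function, you need to tell SQLObject that it is a value that must be passed to the backend as is, unevaluated. The way to do that is to wrap the literal in SQLConstant, like this: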

def FROM_UNIXTIME(x, y):
    return sqlbuilder.SQLConstant('strftime("%Y%m%d", datetime({},"unixepoch", "localtime"))'.format(x))

See SQLConstant
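The `AND (1)` in the question's query is consistent with ordinary Python semantics: two plain strings compared with `>=` give a boolean immediately, whereas an expression object can overload the operator and keep the comparison as SQL. The sketch below shows that difference with a toy model. `ToyExpr` and `render` are hypothetical names made up for this example and are not SQLObject's implementation.

```python
class ToyExpr:
    """Hypothetical stand-in for an SQL expression object."""

    def __init__(self, sql):
        self.sql = sql

    def __ge__(self, other):
        # Build a new expression instead of computing a boolean.
        return ToyExpr("(({}) >= ({}))".format(self.sql, render(other)))


def render(value):
    """Render a toy expression, or a plain Python value, as SQL text."""
    if isinstance(value, ToyExpr):
        return value.sql
    if isinstance(value, bool):
        return "1" if value else "0"
    return repr(value)


def mock_returning_str(column):
    # Like the mock in the question: returns a plain Python string.
    return 'strftime("%Y%m%d", datetime({}))'.format(column)


def mock_returning_expr(column):
    # Like the fixed mock: wraps the same text in an expression object.
    return ToyExpr('strftime("%Y%m%d", datetime({}))'.format(column))


# Two plain strings: Python compares them lexicographically, right away.
as_str = mock_returning_str("store.created_at") >= mock_returning_str("2018-10-12")

# Two expression objects: the overloaded operator keeps the comparison as SQL.
as_expr = mock_returning_expr("store.created_at") >= mock_returning_expr("2018-10-12")

print(render(as_str))
print(render(as_expr))
```

The string version collapses to `True` before any SQL is built, so only a constant is left to render. Wrapping the text in an expression type is what keeps the comparison intact.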

The entire test script now looks like this:

#! /usr/bin/env python3.7

from sqlobject import *

__connection__ = "sqlite:/:memory:?debug=1&debugOutput=1"

try:
    import mock
    from mock import MagicMock
except ImportError:
    from unittest import mock
    from unittest.mock import MagicMock

class Store(SQLObject):
    name = StringCol()
    sample = BLOBCol()
    createdAt = TimestampCol()

Store.createTable()

DATE_FORMAT = "%Y-%m-%d"
def retrieve(name):
    query = sqlbuilder.Select([
            sqlbuilder.func.TO_BASE64(Store.q.sample),
        ],
        sqlbuilder.AND(
            Store.q.name == name,
            sqlbuilder.func.FROM_UNIXTIME(Store.q.createdAt, DATE_FORMAT) >= sqlbuilder.func.FROM_UNIXTIME("2018-10-12", DATE_FORMAT)
        )
    )

    connection = Store._connection
    query = connection.sqlrepr(query)
    queryResult = connection.queryAll(query)
    return queryResult


def TO_BASE64(x):
    return x

def FROM_UNIXTIME(x, y):
    return sqlbuilder.SQLConstant('strftime("%Y%m%d", datetime({},"unixepoch", "localtime"))'.format(x))

for p in [
    mock.patch("sqlobject.sqlbuilder.func.TO_BASE64",TO_BASE64),
    mock.patch("sqlobject.sqlbuilder.func.FROM_UNIXTIME",FROM_UNIXTIME),
]:
    p.start()

retrieve('Some')

mock.patch.stopall()

The output is:

 1/Query   :  CREATE TABLE store (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT,
    sample TEXT,
    created_at TIMESTAMP
)
 1/QueryR  :  CREATE TABLE store (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT,
    sample TEXT,
    created_at TIMESTAMP
)
 2/QueryAll:  SELECT store.sample FROM store WHERE (((store.name) = ('Some')) AND ((strftime("%Y%m%d", datetime(store.created_at,"unixepoch", "localtime"))) >= (strftime("%Y%m%d", datetime(2018-10-12,"unixepoch", "localtime")))))
 2/QueryR  :  SELECT store.sample FROM store WHERE (((store.name) = ('Some')) AND ((strftime("%Y%m%d", datetime(store.created_at,"unixepoch", "localtime"))) >= (strftime("%Y%m%d", datetime(2018-10-12,"unixepoch", "localtime")))))
 2/QueryAll-> []

PS. Full disclosure: I am the current maintainer of SQLObject.

As @phd pointed out, SQLObject evaluates the expression before passing it to the backend in reduced form.

So instead of returning a string literal, we can also return an expression directly and let SQLObject evaluate it, like this:

def FROM_UNIXTIME(x, y):
    return sqlbuilder.func.strftime("%Y%m%d", sqlbuilder.func.datetime(x, "unixepoch", "localtime"))

Output:

SELECT store.sample FROM store WHERE (((store.name) = ('Some')) AND ((strftime("%Y%m%d", datetime(store.created_at,"unixepoch", "localtime"))) >= (strftime("%Y%m%d", datetime(2018-10-12,"unixepoch", "localtime")))))
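One thing worth checking in the generated SQL is the unquoted `2018-10-12`: SQLite reads that as integer subtraction, not as a date. The sketch below uses only the standard `sqlite3` module to show this, along with one way to compare a Unix timestamp against a day. The table is a simplified, hypothetical version of `store` with `created_at` stored as integer epoch seconds, and `'localtime'` is left out so the result does not depend on the machine's time zone.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE store (name TEXT, created_at INTEGER)")
conn.executemany(
    "INSERT INTO store (name, created_at) VALUES (?, ?)",
    [
        ("old", 1500000000),  # 2017-07-14 02:40:00 UTC
        ("new", 1539345600),  # 2018-10-12 12:00:00 UTC
    ],
)

# Unquoted, 2018-10-12 is arithmetic: 2018 minus 10 minus 12.
arithmetic = conn.execute("SELECT 2018-10-12").fetchone()[0]

# Format the timestamp as YYYYMMDD and compare with a quoted literal
# written in the same format.
rows = conn.execute(
    "SELECT name FROM store "
    "WHERE strftime('%Y%m%d', datetime(created_at, 'unixepoch')) >= '20181012' "
    "ORDER BY name"
).fetchall()

conn.close()

print(arithmetic)
print(rows)
```

Running the generated query against a few known rows like this is a cheap way to confirm that the mocked functions produce SQL that filters as intended, not just SQL that executes without error.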