具有异步引擎的外部事务中的会话

Session in an External Transaction with an async engine

我正在尝试新的(测试版)1.4 sqlalchemy,在尝试使用异步 API 和 pytest.

移植 "Joining a Session into an External Transaction (such as for test suite)" 配方时遇到困难

首先,我尝试将 zzzeekunittest 示例转换为 pytest,效果很好

import pytest
from sqlalchemy.orm import Session
from sqlalchemy import event, Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

# a model
class Thing(Base):
    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)


@pytest.fixture(scope="session")
def engine_fixture():
    engine = create_engine("postgresql://postgres:changethis@db/app_test", echo=True)
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)

    yield engine

    Base.metadata.drop_all(engine)


@pytest.fixture
def session(engine_fixture):
    conn = engine_fixture.connect()
    trans = conn.begin()
    session = Session(bind=conn)

    def _fixture(session):
        session.add_all([Thing(), Thing(), Thing()])
        session.commit()

    # load fixture data within the scope of the transaction
    _fixture(session)

    # start the session in a SAVEPOINT...
    session.begin_nested()

    # then each time that SAVEPOINT ends, reopen it
    @event.listens_for(session, "after_transaction_end")
    def restart_savepoint(session, transaction):
        if transaction.nested and not transaction._parent.nested:
            session.begin_nested()

    yield session

    # same teardown from the docs
    session.close()
    trans.rollback()
    conn.close()


def _test_thing(session, extra_rollback=0):

    rows = session.query(Thing).all()
    assert len(rows) == 3

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = session.query(Thing).all()
        assert len(rows) == 6

        session.rollback()

    # after rollbacks, still @ 3 rows
    rows = session.query(Thing).all()
    assert len(rows) == 3

    session.add_all([Thing(), Thing()])
    session.commit()

    rows = session.query(Thing).all()
    assert len(rows) == 5

    session.add(Thing())
    rows = session.query(Thing).all()
    assert len(rows) == 6

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = session.query(Thing).all()
        if elem > 0:
            # b.c. we rolled back that other "thing" too
            assert len(rows) == 8
        else:
            assert len(rows) == 9
        session.rollback()

    rows = session.query(Thing).all()
    if extra_rollback:
        assert len(rows) == 5
    else:
        assert len(rows) == 6


def test_thing_one_pytest(session):
    # run zero rollbacks
    _test_thing(session, 0)


def test_thing_two_pytest(session):
    # run two extra rollbacks
    _test_thing(session, 2)

然后我尝试使用 pytest-asyncio 版本 0.14.0

切换到 asyncio API
import pytest
from sqlalchemy import Column, Integer, create_engine, event
from sqlalchemy.future import select
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine

Base = declarative_base()

# a model
class Thing(Base):
    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)


@pytest.fixture(scope="session", autouse=True)
def meta_migration():
    # setup
    sync_engine = create_engine(
        "postgresql://postgres:changethis@db/app_test", echo=True
    )
    Base.metadata.drop_all(sync_engine)
    Base.metadata.create_all(sync_engine)

    yield sync_engine

    # teardown
    Base.metadata.drop_all(sync_engine)


@pytest.fixture(scope="session")
async def async_engine() -> AsyncEngine:
    # setup
    engine = create_async_engine(
        "postgresql+asyncpg://postgres:changethis@db/app_test", echo=True
    )

    yield engine


@pytest.fixture(scope="function")
async def session(async_engine):
    conn = await async_engine.connect()
    trans = await conn.begin()
    session = AsyncSession(bind=conn)

    async def _fixture(session: AsyncSession):
        session.add_all([Thing(), Thing(), Thing()])
        await session.commit()

    # load fixture data within the scope of the transaction
    await _fixture(session)

    # start the session in a SAVEPOINT...
    await session.begin_nested()

    # then each time that SAVEPOINT ends, reopen it
    # NOTE: no async listeners yet
    @event.listens_for(session.sync_session, "after_transaction_end")
    def restart_savepoint(session, transaction):
        if transaction.nested and not transaction._parent.nested:
            session.begin_nested()

    yield session

    # same teardown from the docs
    await session.close()
    await trans.rollback()
    await conn.close()


async def _test_thing(session: AsyncSession, extra_rollback=0):

    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 3

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = (await session.execute(select(Thing))).all()
        assert len(rows) == 6

        await session.rollback()

    # after rollbacks, still @ 3 rows
    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 3

    session.add_all([Thing(), Thing()])
    await session.commit()

    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 5

    session.add(Thing())
    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 6

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = (await session.execute(select(Thing))).all()
        if elem > 0:
            # b.c. we rolled back that other "thing" too
            assert len(rows) == 8
        else:
            assert len(rows) == 9
        await session.rollback()

    rows = (await session.execute(select(Thing))).all()
    if extra_rollback:
        assert len(rows) == 5
    else:
        assert len(rows) == 6


@pytest.mark.asyncio
async def test_thing_one_pytest(session):
    # run zero rollbacks
    await _test_thing(session, 0)


@pytest.mark.asyncio
async def test_thing_two_pytest(session):
    # run two extra rollbacks
    await _test_thing(session, 2)

但是 "FAILED test_thing_two_pytest - assert 8 == 3" 失败了,因为第一次测试后 teardown 中的事务回滚不会恢复到 setup 阶段创建的保存点。

由于我对 sqlalchemy 内部原理的了解不是很好,我正在寻求设置方面的帮助,因为它对我的测试套件性能至关重要。

难道缺少 async 事件侦听器并根据 AsyncSession.sync_session 定义 restart_savepoint 是不够的,只能等待 1.4 的稳定版本 API?

谢谢!

原来是个bug,直接联系了SA开发者

Github Issue

Fix

注意:根据@zzzek 的说法,有 API 变化,应该使用 connection.begin_nested() 而不是 session.begin_nested()

The "legacy" pattern that you have above which uses "session.begin_nested()" to create the savepoint, this is not supported for the "future" style engine which asyncio uses. The new version uses the connection itself to recreate the savepoint inside the event.