具有异步引擎的外部事务中的会话
Session in an External Transaction with an async engine
我正在试用新的(测试版)SQLAlchemy 1.4,在使用异步 API 和 pytest 移植 "Joining a Session into an External Transaction (such as for test suite)" 配方时遇到了困难。
首先,我尝试将 zzzeek 的 unittest 示例转换为 pytest,效果很好:
import pytest
from sqlalchemy.orm import Session
from sqlalchemy import event, Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
# Declarative base class shared by the models in this snippet.
Base = declarative_base()


class Thing(Base):
    """Minimal mapped model: table ``thing`` with an integer primary key."""

    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)
@pytest.fixture(scope="session")
def engine_fixture():
    """Yield a synchronous engine whose schema was freshly recreated.

    All tables of ``Base.metadata`` are dropped and created before the
    engine is yielded, and dropped again on teardown.
    """
    db_engine = create_engine(
        "postgresql://postgres:changethis@db/app_test",
        echo=True,
    )
    metadata = Base.metadata

    # setup: start from a clean schema
    metadata.drop_all(db_engine)
    metadata.create_all(db_engine)

    yield db_engine

    # teardown: remove the tables again
    metadata.drop_all(db_engine)
@pytest.fixture
def session(engine_fixture):
    """Yield a Session bound to a connection inside an outer transaction.

    Three ``Thing`` rows are added and committed through the session, then
    a SAVEPOINT is opened and re-opened every time it ends. On teardown the
    session is closed, the outer transaction is rolled back and the
    connection is closed.
    """
    connection = engine_fixture.connect()
    outer_trans = connection.begin()
    db_session = Session(bind=connection)

    def _fixture(sess):
        sess.add_all([Thing() for _ in range(3)])
        sess.commit()

    # load fixture data within the scope of the transaction
    _fixture(db_session)

    # start the session in a SAVEPOINT...
    db_session.begin_nested()

    # ...and each time that SAVEPOINT ends, reopen it
    @event.listens_for(db_session, "after_transaction_end")
    def restart_savepoint(session, transaction):
        # only react to a nested transaction whose parent is not nested
        if not transaction.nested or transaction._parent.nested:
            return
        session.begin_nested()

    yield db_session

    # same teardown as in the docs recipe
    db_session.close()
    outer_trans.rollback()
    connection.close()
def _test_thing(session, extra_rollback=0):
    """Exercise add / commit / rollback and assert the row count each step.

    Args:
        session: session to run the queries and changes through.
        extra_rollback: number of iterations for each of the two
            add-then-rollback loops.
    """

    def row_count():
        return len(session.query(Thing).all())

    assert row_count() == 3

    for _attempt in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing() for _ in range(3)])
        assert row_count() == 6
        session.rollback()

    # after rollbacks, still @ 3 rows
    assert row_count() == 3

    session.add_all([Thing() for _ in range(2)])
    session.commit()
    assert row_count() == 5

    # one more row, added but not committed
    session.add(Thing())
    assert row_count() == 6

    for attempt in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing() for _ in range(3)])
        # from the second pass on, the earlier rollback has also discarded
        # the single uncommitted Thing added before this loop
        expected = 8 if attempt > 0 else 9
        assert row_count() == expected
        session.rollback()

    assert row_count() == (5 if extra_rollback else 6)
def test_thing_one_pytest(session):
    """Run the scenario with zero extra rollbacks."""
    _test_thing(session, extra_rollback=0)
def test_thing_two_pytest(session):
    """Run the scenario with two extra rollbacks."""
    _test_thing(session, extra_rollback=2)
然后我尝试使用 pytest-asyncio(版本 0.14.0)切换到 asyncio API:
import pytest
from sqlalchemy import Column, Integer, create_engine, event
from sqlalchemy.future import select
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
# Declarative base class shared by the models in this snippet.
Base = declarative_base()


class Thing(Base):
    """Minimal mapped model: table ``thing`` with an integer primary key."""

    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)
@pytest.fixture(scope="session", autouse=True)
def meta_migration():
    """Recreate the schema with a synchronous engine around the test run.

    Drops and creates all tables of ``Base.metadata`` before yielding the
    synchronous engine, and drops them again on teardown.
    """
    # setup
    engine = create_engine(
        "postgresql://postgres:changethis@db/app_test",
        echo=True,
    )
    metadata = Base.metadata
    metadata.drop_all(engine)
    metadata.create_all(engine)

    yield engine

    # teardown
    metadata.drop_all(engine)
@pytest.fixture(scope="session")
async def async_engine() -> AsyncEngine:
    """Yield an async engine for the asyncpg-backed test database URL."""
    # setup
    db_url = "postgresql+asyncpg://postgres:changethis@db/app_test"
    yield create_async_engine(db_url, echo=True)
@pytest.fixture(scope="function")
async def session(async_engine):
    """Yield an AsyncSession bound to a connection in an outer transaction.

    Three ``Thing`` rows are added and committed through the session, then
    a SAVEPOINT is opened via the session and a listener on the underlying
    sync session re-opens it whenever it ends. On teardown the session is
    closed, the outer transaction rolled back and the connection closed.

    NOTE(review): this is the failing variant discussed in the surrounding
    text. The note after this snippet reports that the legacy
    ``session.begin_nested()`` pattern is not supported by the "future"
    style engine that asyncio uses, and that the SAVEPOINT should be
    recreated on the connection (``connection.begin_nested()``) instead.
    """
    connection = await async_engine.connect()
    outer_trans = await connection.begin()
    async_session = AsyncSession(bind=connection)

    async def _fixture(sess: AsyncSession):
        sess.add_all([Thing() for _ in range(3)])
        await sess.commit()

    # load fixture data within the scope of the transaction
    await _fixture(async_session)

    # start the session in a SAVEPOINT...
    await async_session.begin_nested()

    # ...and each time that SAVEPOINT ends, reopen it.
    # NOTE: no async listeners yet, so listen on the sync session instead
    @event.listens_for(async_session.sync_session, "after_transaction_end")
    def restart_savepoint(session, transaction):
        # only react to a nested transaction whose parent is not nested
        if not transaction.nested or transaction._parent.nested:
            return
        session.begin_nested()

    yield async_session

    # same teardown as in the docs recipe
    await async_session.close()
    await outer_trans.rollback()
    await connection.close()
async def _test_thing(session: AsyncSession, extra_rollback=0):
    """Exercise add / commit / rollback and assert the row count each step.

    Args:
        session: async session to run the queries and changes through.
        extra_rollback: number of iterations for each of the two
            add-then-rollback loops.
    """

    async def row_count():
        result = await session.execute(select(Thing))
        return len(result.all())

    assert await row_count() == 3

    for _attempt in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing() for _ in range(3)])
        assert await row_count() == 6
        await session.rollback()

    # after rollbacks, still @ 3 rows
    assert await row_count() == 3

    session.add_all([Thing() for _ in range(2)])
    await session.commit()
    assert await row_count() == 5

    # one more row, added but not committed
    session.add(Thing())
    assert await row_count() == 6

    for attempt in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing() for _ in range(3)])
        # from the second pass on, the earlier rollback has also discarded
        # the single uncommitted Thing added before this loop
        expected = 8 if attempt > 0 else 9
        assert await row_count() == expected
        await session.rollback()

    assert await row_count() == (5 if extra_rollback else 6)
@pytest.mark.asyncio
async def test_thing_one_pytest(session):
    """Run the async scenario with zero extra rollbacks."""
    await _test_thing(session, extra_rollback=0)
@pytest.mark.asyncio
async def test_thing_two_pytest(session):
    """Run the async scenario with two extra rollbacks."""
    await _test_thing(session, extra_rollback=2)
但是测试失败了:"FAILED test_thing_two_pytest - assert 8 == 3",因为第一次测试后 teardown 中的事务回滚没有恢复到 setup 阶段创建的保存点。
由于我对 sqlalchemy 内部原理的了解不是很好,我正在寻求设置方面的帮助,因为它对我的测试套件性能至关重要。
是不是因为缺少 async 事件侦听器,仅基于 AsyncSession.sync_session 定义 restart_savepoint 还不够,所以只能等待 1.4 稳定版的 API?
谢谢!
结果发现这是一个 bug,我已直接联系了 SQLAlchemy 的开发者。
注意:根据 @zzzeek 的说法,API 有变化,应该使用 connection.begin_nested() 而不是 session.begin_nested():
The "legacy" pattern that you have above which uses "session.begin_nested()" to create the savepoint, this is not supported for the "future" style engine which asyncio uses. The new version uses the connection itself to recreate the savepoint inside the event.
我正在试用新的(测试版)SQLAlchemy 1.4,在使用异步 API 和 pytest 移植 "Joining a Session into an External Transaction (such as for test suite)" 配方时遇到了困难。
首先,我尝试将 zzzeek 的 unittest 示例转换为 pytest,效果很好:
import pytest
from sqlalchemy.orm import Session
from sqlalchemy import event, Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
# Declarative base class shared by the models in this snippet.
Base = declarative_base()


class Thing(Base):
    """Minimal mapped model: table ``thing`` with an integer primary key."""

    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)
@pytest.fixture(scope="session")
def engine_fixture():
    """Yield a synchronous engine whose schema was freshly recreated.

    All tables of ``Base.metadata`` are dropped and created before the
    engine is yielded, and dropped again on teardown.
    """
    db_engine = create_engine(
        "postgresql://postgres:changethis@db/app_test",
        echo=True,
    )
    metadata = Base.metadata

    # setup: start from a clean schema
    metadata.drop_all(db_engine)
    metadata.create_all(db_engine)

    yield db_engine

    # teardown: remove the tables again
    metadata.drop_all(db_engine)
@pytest.fixture
def session(engine_fixture):
    """Yield a Session bound to a connection inside an outer transaction.

    Three ``Thing`` rows are added and committed through the session, then
    a SAVEPOINT is opened and re-opened every time it ends. On teardown the
    session is closed, the outer transaction is rolled back and the
    connection is closed.
    """
    connection = engine_fixture.connect()
    outer_trans = connection.begin()
    db_session = Session(bind=connection)

    def _fixture(sess):
        sess.add_all([Thing() for _ in range(3)])
        sess.commit()

    # load fixture data within the scope of the transaction
    _fixture(db_session)

    # start the session in a SAVEPOINT...
    db_session.begin_nested()

    # ...and each time that SAVEPOINT ends, reopen it
    @event.listens_for(db_session, "after_transaction_end")
    def restart_savepoint(session, transaction):
        # only react to a nested transaction whose parent is not nested
        if not transaction.nested or transaction._parent.nested:
            return
        session.begin_nested()

    yield db_session

    # same teardown as in the docs recipe
    db_session.close()
    outer_trans.rollback()
    connection.close()
def _test_thing(session, extra_rollback=0):
    """Exercise add / commit / rollback and assert the row count each step.

    Args:
        session: session to run the queries and changes through.
        extra_rollback: number of iterations for each of the two
            add-then-rollback loops.
    """

    def row_count():
        return len(session.query(Thing).all())

    assert row_count() == 3

    for _attempt in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing() for _ in range(3)])
        assert row_count() == 6
        session.rollback()

    # after rollbacks, still @ 3 rows
    assert row_count() == 3

    session.add_all([Thing() for _ in range(2)])
    session.commit()
    assert row_count() == 5

    # one more row, added but not committed
    session.add(Thing())
    assert row_count() == 6

    for attempt in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing() for _ in range(3)])
        # from the second pass on, the earlier rollback has also discarded
        # the single uncommitted Thing added before this loop
        expected = 8 if attempt > 0 else 9
        assert row_count() == expected
        session.rollback()

    assert row_count() == (5 if extra_rollback else 6)
def test_thing_one_pytest(session):
    """Run the scenario with zero extra rollbacks."""
    _test_thing(session, extra_rollback=0)
def test_thing_two_pytest(session):
    """Run the scenario with two extra rollbacks."""
    _test_thing(session, extra_rollback=2)
然后我尝试使用 pytest-asyncio(版本 0.14.0)切换到 asyncio API:
import pytest
from sqlalchemy import Column, Integer, create_engine, event
from sqlalchemy.future import select
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
# Declarative base class shared by the models in this snippet.
Base = declarative_base()


class Thing(Base):
    """Minimal mapped model: table ``thing`` with an integer primary key."""

    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)
@pytest.fixture(scope="session", autouse=True)
def meta_migration():
    """Recreate the schema with a synchronous engine around the test run.

    Drops and creates all tables of ``Base.metadata`` before yielding the
    synchronous engine, and drops them again on teardown.
    """
    # setup
    engine = create_engine(
        "postgresql://postgres:changethis@db/app_test",
        echo=True,
    )
    metadata = Base.metadata
    metadata.drop_all(engine)
    metadata.create_all(engine)

    yield engine

    # teardown
    metadata.drop_all(engine)
@pytest.fixture(scope="session")
async def async_engine() -> AsyncEngine:
    """Yield an async engine for the asyncpg-backed test database URL."""
    # setup
    db_url = "postgresql+asyncpg://postgres:changethis@db/app_test"
    yield create_async_engine(db_url, echo=True)
@pytest.fixture(scope="function")
async def session(async_engine):
    """Yield an AsyncSession bound to a connection in an outer transaction.

    Three ``Thing`` rows are added and committed through the session, then
    a SAVEPOINT is opened via the session and a listener on the underlying
    sync session re-opens it whenever it ends. On teardown the session is
    closed, the outer transaction rolled back and the connection closed.

    NOTE(review): this is the failing variant discussed in the surrounding
    text. The note after this snippet reports that the legacy
    ``session.begin_nested()`` pattern is not supported by the "future"
    style engine that asyncio uses, and that the SAVEPOINT should be
    recreated on the connection (``connection.begin_nested()``) instead.
    """
    connection = await async_engine.connect()
    outer_trans = await connection.begin()
    async_session = AsyncSession(bind=connection)

    async def _fixture(sess: AsyncSession):
        sess.add_all([Thing() for _ in range(3)])
        await sess.commit()

    # load fixture data within the scope of the transaction
    await _fixture(async_session)

    # start the session in a SAVEPOINT...
    await async_session.begin_nested()

    # ...and each time that SAVEPOINT ends, reopen it.
    # NOTE: no async listeners yet, so listen on the sync session instead
    @event.listens_for(async_session.sync_session, "after_transaction_end")
    def restart_savepoint(session, transaction):
        # only react to a nested transaction whose parent is not nested
        if not transaction.nested or transaction._parent.nested:
            return
        session.begin_nested()

    yield async_session

    # same teardown as in the docs recipe
    await async_session.close()
    await outer_trans.rollback()
    await connection.close()
async def _test_thing(session: AsyncSession, extra_rollback=0):
    """Exercise add / commit / rollback and assert the row count each step.

    Args:
        session: async session to run the queries and changes through.
        extra_rollback: number of iterations for each of the two
            add-then-rollback loops.
    """

    async def row_count():
        result = await session.execute(select(Thing))
        return len(result.all())

    assert await row_count() == 3

    for _attempt in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing() for _ in range(3)])
        assert await row_count() == 6
        await session.rollback()

    # after rollbacks, still @ 3 rows
    assert await row_count() == 3

    session.add_all([Thing() for _ in range(2)])
    await session.commit()
    assert await row_count() == 5

    # one more row, added but not committed
    session.add(Thing())
    assert await row_count() == 6

    for attempt in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing() for _ in range(3)])
        # from the second pass on, the earlier rollback has also discarded
        # the single uncommitted Thing added before this loop
        expected = 8 if attempt > 0 else 9
        assert await row_count() == expected
        await session.rollback()

    assert await row_count() == (5 if extra_rollback else 6)
@pytest.mark.asyncio
async def test_thing_one_pytest(session):
    """Run the async scenario with zero extra rollbacks."""
    await _test_thing(session, extra_rollback=0)
@pytest.mark.asyncio
async def test_thing_two_pytest(session):
    """Run the async scenario with two extra rollbacks."""
    await _test_thing(session, extra_rollback=2)
但是测试失败了:"FAILED test_thing_two_pytest - assert 8 == 3",因为第一次测试后 teardown 中的事务回滚没有恢复到 setup 阶段创建的保存点。
由于我对 sqlalchemy 内部原理的了解不是很好,我正在寻求设置方面的帮助,因为它对我的测试套件性能至关重要。
是不是因为缺少 async 事件侦听器,仅基于 AsyncSession.sync_session 定义 restart_savepoint 还不够,所以只能等待 1.4 稳定版的 API?
谢谢!
结果发现这是一个 bug,我已直接联系了 SQLAlchemy 的开发者。
注意:根据 @zzzeek 的说法,API 有变化,应该使用 connection.begin_nested() 而不是 session.begin_nested():
The "legacy" pattern that you have above which uses "session.begin_nested()" to create the savepoint, this is not supported for the "future" style engine which asyncio uses. The new version uses the connection itself to recreate the savepoint inside the event.