在使用 SqlAlchemy 创建的视图的字段上限制小数位数

Limit decimal digits on field of a view created with SqlAlchemy

我创建了如下视图:

class OpenPositionMetric(Base):
    """Class whose ``__table__`` is the ``vw_open_positions_metrics`` view.

    The view is built from a SELECT over ``OpenPosition`` that exposes four
    identifying columns plus a computed ``cost_value`` column
    (``actual_shares * avg_cost_per_share``).
    """

    # SELECT feeding the view: identifying columns plus the computed cost.
    stmt = (
        select(
            [
                OpenPosition.belongs_to.label("belongs_to"),
                OpenPosition.account_number.label("account_number"),
                OpenPosition.exchange.label("exchange"),
                OpenPosition.symbol.label("symbol"),
                # NOTE(review): `round` here is presumably Python's builtin
                # round(), not a SQL function -- confirm. The intent is to
                # keep 3 decimal places in "cost_value".
                round(OpenPosition.actual_shares * OpenPosition.avg_cost_per_share,3).label(
                    "cost_value"
                ),
            ]
        )
        .select_from(OpenPosition)
        # Ordering refers to the labels defined in the SELECT list above.
        .order_by("belongs_to", "account_number", "exchange", "symbol")
    )

    # create_materialized_view is defined outside this snippet (presumably
    # sqlalchemy-utils -- confirm); the view is registered on Base.metadata
    # and no indexes are requested.
    view = create_materialized_view(
        name="vw_open_positions_metrics",
        selectable=stmt,
        metadata=Base.metadata,
        indexes=None,
    )
    # Use the object returned above as this class's table.
    __table__ = view

字段 cost_value 的一个示例结果是:1067.2500060000000000。

有没有办法限制该视图字段的小数位数?

round() 函数不起作用。可能是因为 round 是 Python 的内置函数,而 SQLAlchemy 期望的是 SQL 表达式语言中的函数(例如 func.sum)?

更新:

我找到了一个解决方案,但它并不完美。我相信还有更好的...

(text("ROUND (operations.tb_open_positions.actual_shares * operations.tb_open_positions.avg_cost_per_share,3) AS cost_value"))),

上面的值现在在视图中显示为 1067.250

限制小数位数的一种方法是将结果转换为 Numeric:

import sqlalchemy as sa

# …

class OpenPosition(Base):
    """Minimal demo model for the ``open_position`` table.

    Holds only the two Float columns whose product is queried in the
    example, plus an integer primary key.
    """

    __tablename__ = "open_position"
    # Integer primary key with autoincrement disabled.
    id = sa.Column(sa.Integer, primary_key=True, autoincrement=False)
    # Both operands of the cost computation are stored as Float.
    actual_shares = sa.Column(sa.Float)
    avg_cost_per_share = sa.Column(sa.Float)


# Drop and recreate the tables so the demo starts from an empty schema.
Base.metadata.drop_all(engine, checkfirst=True)
Base.metadata.create_all(engine)

with sa.orm.Session(engine, future=True) as session:
    # Insert one row whose product has more than three decimal places.
    session.add(
        OpenPosition(id=1, actual_shares=1, avg_cost_per_share=1067.250606)
    )
    session.commit()

    # Plain product of the two Float columns: decimal places are not limited.
    result = session.query(
        (OpenPosition.actual_shares * OpenPosition.avg_cost_per_share).label(
            "cost_value"
        )
    ).all()
    print(result)  # [(1067.250606,)]

    # Cast the product to Numeric(10, 3) to limit it to three decimal places.
    # The label is applied to the cast itself: a label nested inside cast()
    # does not name the resulting column.
    result = session.query(
        sa.cast(
            OpenPosition.actual_shares * OpenPosition.avg_cost_per_share,
            sa.Numeric(10, 3),
        ).label("cost_value")
    ).all()
    print(result)  # [(Decimal('1067.251'),)]

解决方案(感谢Gord Thompson):

from sqlalchemy import cast, Numeric

class OpenPositionMetric(Base):
    """Class whose ``__table__`` is the ``vw_open_positions_metrics`` view.

    The view is built from a SELECT over ``OpenPosition`` that exposes four
    identifying columns plus ``cost_value``, the product
    ``actual_shares * avg_cost_per_share`` cast to ``Numeric(10, 3)``.
    """

    # SELECT feeding the view: identifying columns plus the computed cost.
    stmt = (
        select(
            [
                OpenPosition.belongs_to.label("belongs_to"),
                OpenPosition.account_number.label("account_number"),
                OpenPosition.exchange.label("exchange"),
                OpenPosition.symbol.label("symbol"),
                # Cast the product to Numeric(10, 3) to limit the value to
                # three decimal places; the label is applied to the cast
                # result so the column is named "cost_value".
                (
                    cast(
                        OpenPosition.actual_shares * OpenPosition.avg_cost_per_share,
                        Numeric(10, 3),
                    )
                ).label("cost_value")

            ]
        )
        .select_from(OpenPosition)
        # Ordering refers to the labels defined in the SELECT list above.
        .order_by("belongs_to", "account_number", "exchange", "symbol")
    )

    # create_materialized_view is defined outside this snippet (presumably
    # sqlalchemy-utils -- confirm); the view is registered on Base.metadata
    # and no indexes are requested.
    view = create_materialized_view(
        name="vw_open_positions_metrics",
        selectable=stmt,
        metadata=Base.metadata,
        indexes=None,
    )
    # Use the object returned above as this class's table.
    __table__ = view