Rust polars:when().then().otherwise() 在 groupby-agg 上下文中的意外行为

Rust polars : unexpected befaviour of when().then().otherwise() in groupby-agg context

我有一个复杂的映射逻辑,我试图在 groupby 上下文中执行。代码编译并且不会恐慌,但结果不正确。我知道逻辑实现是正确的。因此,我想知道 when-then-otherwise 是否应该在 groupby 中使用?

use polars::prelude::*;
use polars::df;

fn main() {
    let df = df! [
        "Region" => ["EU", "EU", "EU", "EU", "EU"],
        "MonthCCY" => ["APRUSD", "MAYUSD", "JUNEUR", "JULUSD", "APRUSD"],
        "values" => [1, 2, 3, 4, 5],
    ].unwrap();

    let df = df.lazy()
        .groupby_stable([col("MonthCCY")])
        .agg( [
            month_weight().alias("Weight"),
        ]
        );
}

pub fn month_weight() -> Expr {
    when(col("Region").eq(lit("EU")))
    .then(
        // First, If MonthCCY is JUNEUR or AUGEUR - apply 0.05
        when(col("MonthCCY").map( |s|{
            Ok( s.utf8()?
            .contains("JUNEUR|AUGEUR")?
            .into_series() )
         }
            , GetOutput::from_type(DataType::Boolean)
        ))
        .then(lit::<f64>(0.05))
        .otherwise(
            // Second, If MonthCCY is JANEUR - apply 0.0225
            when(col("MonthCCY").map( |s|{
                Ok( s.utf8()?
                .contains("JANEUR")?
                .into_series() )
             }
                , GetOutput::from_type(DataType::Boolean)
            ))
            .then(lit::<f64>(0.0225))
            .otherwise(
                // Third, If MonthCCY starts with JUL or FEB (eg FEBUSD or FEBEUR)- apply 0.15
                when(col("MonthCCY").apply( |s|{
                    let x = s.utf8()?
                    .str_slice(0, Some(3))?;
                    let y = x.contains("JUL|FEB")?
                    .into_series();
                    Ok(y)
                 }
                    , GetOutput::from_type(DataType::Boolean)
                ))
                .then(lit::<f64>(0.15))
                //Finally, if none of the above matched, apply 0.2
                .otherwise(lit::<f64>(0.20))
            )
        )
    ).otherwise(lit::<f64>(0.0))
}

我得到的结果是:

┌──────────┬─────────────┐
│ MonthCCY ┆ Weight      │
│ ---      ┆ ---         │
│ str      ┆ list [f64]  │
╞══════════╪═════════════╡
│ APRUSD   ┆ [0.2, 0.15] │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ MAYUSD   ┆ [0.2]       │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ JUNEUR   ┆ [0.05]      │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ JULUSD   ┆ [0.2]       │
└──────────┴─────────────┘

显然,我预计 JULUSD 为 0.15,APRUSD 为 [0.2, 0.2]。

我对 when().then().otherwise() 如何在 groupby 中工作的期望是错误的吗?

我在 Windows11, rustc 1.60.

是的,您按错误的顺序进行分组和映射。 month_weight()不是聚合表达式,而是简单的映射表达式。

事实上,DataFrame 的每一组都agg编入一个系列,最终从原始框架中的数据顺序派生。

您首先要创建一个 Weight 列,其值由您在 month_weight() 中指定的映射给出,然后 您想要聚合此列进入每个组的列表。

所以,你想要的是:

let df = df
    .lazy()
    .with_column(month_weight().alias("Weight")) // create new column first
    .groupby_stable([col("MonthCCY")]) // then group
    .agg([col("Weight").list()]); // then collect into a list per group

println!("{:?}", df.collect().unwrap());

输出:

shape: (4, 2)
┌──────────┬────────────┐
│ MonthCCY ┆ Weight     │
│ ---      ┆ ---        │
│ str      ┆ list [f64] │
╞══════════╪════════════╡
│ APRUSD   ┆ [0.2, 0.2] │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
│ MAYUSD   ┆ [0.2]      │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
│ JUNEUR   ┆ [0.05]     │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
│ JULUSD   ┆ [0.15]     │
└──────────┴────────────┘

此外,顺便说一句,.when().then() 可以无限期地链接;你不需要嵌套它们。因此,就像您可以编写链式 if ... else if ... else if ... else 一样,您可以编写 col().when().then().when().then() ... .otherwise(),这比嵌套每个附加条件要简单得多。