线程“<unnamed>”对断言感到恐慌

thread '<unnamed>' panicked at assertion

我在 Python Polars 中收到未知错误:

thread '<unnamed>' panicked at 'assertion failed: `(left == right)`
  left: `Float64[NaN, 1, NaN, NaN, NaN, ...[clip]...
  right: `Float64[NaN, 1, NaN, NaN, NaN, ...[clip]...

这是内部错误吗?

触发的代码是:

df.select([
    pl.col('total').shift().ewm_mean(half_life = 10).over('group')
])

我很难再问了,因为错误太高深莫测了。

这看起来确实像一个错误。它来自于在 window 函数 (over) 中对包含 NaN 值的表达式调用 shift

import polars as pl
import numpy as np

df = pl.DataFrame(
    {
        "group": ["a", "a", "a", "b", "b", "b"],
        "total": [1.0, 2, 3, 4, 5, np.NaN],
    }
)

df.select([
    pl.col('total').shift().over('group')
])
thread '<unnamed>' panicked at 'assertion failed: `(left == right)`
  left: `Float64[4, 5, NaN]`,
 right: `Float64[4, 5, NaN]`', /github/workspace/polars/polars-core/src/series/unstable.rs:39:9
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/corey/.virtualenvs/Whosebug3.10/lib/python3.10/site-packages/polars/internals/frame.py", line 4253, in select
    self.lazy()
  File "/home/corey/.virtualenvs/Whosebug3.10/lib/python3.10/site-packages/polars/internals/lazy_frame.py", line 476, in collect
    return self._dataframe_class._from_pydf(ldf.collect())
pyo3_runtime.PanicException: assertion failed: `(left == right)`
  left: `Float64[4, 5, NaN]`,
 right: `Float64[4, 5, NaN]`

由于您使用的是 sum 聚合,能否使用 fill_nan(0) 解决此问题?或者在这些情况下您是否需要保留 NaN 值?

df.select([
    pl.col('total').fill_nan(0).shift().sum().over('group')
])
shape: (6, 1)
┌─────────┐
│ literal │
│ ---     │
│ f64     │
╞═════════╡
│ 3.0     │
├╌╌╌╌╌╌╌╌╌┤
│ 3.0     │
├╌╌╌╌╌╌╌╌╌┤
│ 3.0     │
├╌╌╌╌╌╌╌╌╌┤
│ 9.0     │
├╌╌╌╌╌╌╌╌╌┤
│ 9.0     │
├╌╌╌╌╌╌╌╌╌┤
│ 9.0     │
└─────────┘

我会在 GitHub 上为它创建一个问题。

编辑:此问题现已在 Polars 0.13.19 及更高版本中修复,不再需要解决方法。

另一种解决此问题的临时方法是以另一种方式使用 over window 创建 shift 的结果。

假设我们有以下组、编号的观察值和总数。

import numpy as np
import polars as pl

df = pl.DataFrame(
    {
        "group": ["a", "a", "b", "a", "b", "b"],
        "obs": [1, 2, 1, 3, 2, 3],
        "total": [1.0, 2, 3, 4, 5, np.NaN],
    }
)
df
shape: (6, 3)
┌───────┬─────┬───────┐
│ group ┆ obs ┆ total │
│ ---   ┆ --- ┆ ---   │
│ str   ┆ i64 ┆ f64   │
╞═══════╪═════╪═══════╡
│ a     ┆ 1   ┆ 1.0   │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ a     ┆ 2   ┆ 2.0   │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ b     ┆ 1   ┆ 3.0   │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ a     ┆ 3   ┆ 4.0   │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ b     ┆ 2   ┆ 5.0   │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ b     ┆ 3   ┆ NaN   │
└───────┴─────┴───────┘

以下代码将得到与 shift 相同的结果:

df = (
    df.sort(["group", "obs"])
    .with_column(pl.col("total").shift().alias("total_shifted"))
    .with_column(
        pl.when(pl.col("group").is_first())
        .then(None)
        .otherwise(pl.col("total_shifted"))
        .alias("result")
    )
)
df
shape: (6, 5)
┌───────┬─────┬───────┬───────────────┬────────┐
│ group ┆ obs ┆ total ┆ total_shifted ┆ result │
│ ---   ┆ --- ┆ ---   ┆ ---           ┆ ---    │
│ str   ┆ i64 ┆ f64   ┆ f64           ┆ f64    │
╞═══════╪═════╪═══════╪═══════════════╪════════╡
│ a     ┆ 1   ┆ 1.0   ┆ null          ┆ null   │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ a     ┆ 2   ┆ 2.0   ┆ 1.0           ┆ 1.0    │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ a     ┆ 3   ┆ 4.0   ┆ 2.0           ┆ 2.0    │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ b     ┆ 1   ┆ 3.0   ┆ 4.0           ┆ null   │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ b     ┆ 2   ┆ 5.0   ┆ 3.0           ┆ 3.0    │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ b     ┆ 3   ┆ NaN   ┆ 5.0           ┆ 5.0    │
└───────┴─────┴───────┴───────────────┴────────┘

(我已将中间计算留在数据集中以供检查,以展示算法的工作原理。)

请注意,result 列与您从 shift 组中获得的值相同。然后,您可以 运行 在 result 列上进行聚合,而无需使用 shift。

df.select([
    pl.col('result').ewm_mean(half_life = 10).over('group')
])

当然,您必须根据您的特定代码对其进行调整,但它应该可以工作。