如何使用多线程从 Polars 数据框中过滤记录 "sequences"?

How to filter record "sequences" from a Polars dataframe using multiple threads?

我有一个数据集,每个人都有多条记录 - 每个时间段一条记录。 如果某人在一段时间内丢失了记录,我需要删除该人以后的任何记录。 因此,给出这样的示例数据集:

import polars as pl
df = pl.DataFrame({'Id': [1,1,2,2,2,2,3,3,4,4,4,5,5,5,6,6,6],
 'Age': [1,4,1,2,3,4,1,2,1,2,3,1,2,4,2,3,4],
 'Value': [1,142,4,73,109,145,6,72,-8,67,102,-1,72,150,72,111,149]})
df
Id  Age Value
i64 i64 i64
1   1   1
1   4   142
2   1   4
2   2   73
2   3   109
2   4   145
3   1   6
3   2   72
4   1   -8
4   2   67
4   3   102
5   1   -1
5   2   72
5   4   150
6   2   72
6   3   111
6   4   149

我需要过滤如下:

Id  Age Value   Keep
i64 i64 i64 bool
1   1   1   true
2   1   4   true
2   2   73  true
2   3   109 true
2   4   145 true
3   1   6   true
3   2   72  true
4   1   -8  true
4   2   67  true
4   3   102 true
5   1   -1  true
5   2   72  true

因此,年龄记录配置文件为 1,3,4 的个人最终将只有 1 记录。像 Id 6 这样年龄记录为 2,3,4 的人在筛选后最终将没有任何记录。

我可以使用下面的方法实现这一点,但是当数据集包含数百万个人时,代码似乎无法 运行 并行并且性能非常慢(最终 [= 之前​​的步骤在具有 1650 万条记录的数据集上,18=] 表达式在约 22 秒内完成,最后一个 filter 表达式又需要 12.5 分钟 才能完成)。是否有替代单线程的方法,或调整下面的代码来实现?

df2 = (
    df.sort(by=["Id","Age"])
    .with_column(
        ((pl.col("Age").diff(1).fill_null(pl.col("Age") == 1) == 1)
        .over("Id")
        .alias("Keep")
    )
    .filter(
        (pl.col("Keep").cumprod() == 1).over("Id")
    )
)

请注意,window功能非常强大,但也相对昂贵。所以你已经可以从减少工作开始了。

df.sort(by=["Id", "Age"]).filter(
    ((pl.col("Age").diff(1).fill_null(1) == 1).over("Id"))
)

而且很可能,您也可以放弃昂贵的排序:

df.filter(
    ((pl.col("Age").diff(1).fill_null(1) == 1).over("Id"))
)

多线程

filter 操作已经包含多种形式的并行性。列的实现是并行的。在这种情况下,掩码的计算也是并行的。 window 表达式(over() 语法)在计算组和执行连接操作时是多线程的。

榨取window函数的最大性能

如果您的数据已经排序,您可以通过显式添加 list 聚合然后展平该结果来使 window 表达式更快。这是因为列表聚合是免费的,因为我们已经在聚合中有一个列表(实现细节)并且展平通常也是免费的。实现细节有点复杂,但这意味着 polars 不必计算每个聚合相对于原始 DataFrame.

的位置

只有当 DataFrame 已经按组排序时才有意义

# note that that only makes sense if the df is sorted by the groups
sorted_df.filter(
    ((pl.col("Age").diff(1).fill_null(1) == 1).list().over("Id").flatten())
)

我提出以下(修改后的)代码:

df2 = df.filter(pl.col('Age').rank().over('Id') == pl.col('Age'))

此代码在您的测试数据集上产生以下结果:

shape: (12, 3)
┌─────┬─────┬───────┐
│ Id  ┆ Age ┆ Value │
│ --- ┆ --- ┆ ---   │
│ i64 ┆ i64 ┆ i64   │
╞═════╪═════╪═══════╡
│ 1   ┆ 1   ┆ 1     │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 2   ┆ 1   ┆ 4     │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 2   ┆ 2   ┆ 73    │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 2   ┆ 3   ┆ 109   │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 2   ┆ 4   ┆ 145   │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3   ┆ 1   ┆ 6     │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3   ┆ 2   ┆ 72    │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 4   ┆ 1   ┆ -8    │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 4   ┆ 2   ┆ 67    │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 4   ┆ 3   ┆ 102   │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 5   ┆ 1   ┆ -1    │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 5   ┆ 2   ┆ 72    │
└─────┴─────┴───────┘

基本上,当 Age 被跳过时(对于特定的 Id),AgerankAge 变量本身,并且对于 Id.

的所有更高 Age 值保持步调

与我之前的回答相比,这段代码有几个优点。它更简洁,更容易理解,最重要的是......它很好地利用了 Polars API,尤其是 over window 函数。

即使此代码在即将发布的 Polars 版本中的基准测试速度稍慢,出于上述原因我还是推荐它。

编辑 - Polars 0.13.15 上的基准测试

好的,哇哇哇...我刚刚下载了新发布的 Polars (0.13.15),re-benchmarked 我机器上的代码生成了 1700 万条记录,就像我之前的回答一样。

结果?

  • 题目中列出的修改代码:13.6秒
  • 我之前回答中的(丑陋的)代码:4.8 秒
  • 本答案中的one-line代码:3.3秒

并且在代码运行时观察 htop 命令,很明显新发布的 Polars 代码使用了我机器上的所有 64 个逻辑核心。 大规模并行

印象深刻!