Pyspark window 具有条件的函数来舍入旅行者的数量

Pyspark window function with conditions to round number of travelers

我正在使用 Pyspark,我想创建一个执行以下操作的函数:

描述列车用户交易的给定数据:

+----+----------+--------+-----+
|date|total_trav|num_trav|order|
+----+----------+--------+-----+
|   1|         9|     2.7|    1|
|   1|         9|     1.3|    2|
|   1|         9|     1.3|    3|
|   1|         9|     1.3|    4|
|   1|         9|     1.2|    5|
|   1|         9|     1.1|    6|
|   2|         9|     2.7|    1|
|   2|         9|     1.3|    2|
|   2|         9|     1.3|    3|
|   2|         9|     1.3|    4|
|   2|         9|     1.2|    5|
|   2|         9|     1.1|    6|
+----+----------+--------+-----+

我想根据 order 列中给定的顺序舍入 num_trav 列的数字,同时按 date 分组以获得 trav_res柱子。 它背后的逻辑是这样的:

例如,让我们考虑这个结果数据框,看看 trav_res 列是如何形成的:

+----+----------+--------+-----+--------+
|date|total_trav|num_trav|order|trav_res|
+----+----------+--------+-----+--------+
|   1|         9|     2.7|    1|       3|
|   1|         9|     1.3|    2|       2|
|   1|         9|     1.3|    3|       2|
|   1|         9|     1.3|    4|       2|
|   1|         9|     1.2|    5|       0|
|   1|         9|     1.1|    6|       0|
|   2|         9|     2.7|    1|       3|
|   2|         9|     1.3|    2|       2|
|   2|         9|     1.3|    3|       2|
|   2|         9|     1.3|    4|       2|
|   2|         9|     1.2|    5|       0|
|   2|         9|     1.1|    6|       0|
+----+----------+--------+-----+--------+

在上面的示例中,当您按日期分组时,您将有 2 个组,最大旅行者数量为 9(total_trav 列)。 例如,对于第 1 组,你将开始将 num_trav=2.7 舍入为 3(trav_res 列),然后将 num_trav=1.3 舍入为 2,然后将 num_trav=1.3 舍入为 2,num_trav=1.3 到 2(这是按照给定的顺序),然​​后对于下一个你没有旅行者离开,所以他们拥有的数量并不重要,因为没有旅行者离开,所以他们将得到 trav_res=0 在这两种情况下。

我已经尝试了一些 udf 函数,但你似乎无法完成这项工作。

您可以先将F.ceil应用于num_trav中的所有行,然后根据上限值创建cumsum列,然后在cumsum超过total_trav时将上限值设置为零,如下所示在下面的代码中

# create dataframe
import pyspark.sql.functions as F
from pyspark.sql import Window

data = [(1, 9, 2.7, 1),
        (1, 9, 1.3, 2),
        (1, 9, 1.3, 3),
        (1, 9, 1.3, 4),
        (1, 9, 1.2, 5),
        (1, 9, 1.1, 6),
        (2, 9, 2.7, 1),
        (2, 9, 1.3, 2),
        (2, 9, 1.3, 3),
        (2, 9, 1.3, 4),
        (2, 9, 1.2, 5),
        (2, 9, 1.1, 6)]

df = spark.createDataFrame(data, schema=["date", "total_trav", "num_trav", "order"])

# create ceiling column
df = df.withColumn("num_trav_ceil", F.ceil("num_trav"))

# create cumulative sum column
w = Window.partitionBy("date").orderBy("order")
df = df.withColumn("num_trav_ceil_cumsum", F.sum("num_trav_ceil").over(w))

# impose 0 in trav_res when cumsum exceeds total_trav
df = (df
  .withColumn("trav_res", 
               F.when(F.col("num_trav_ceil_cumsum")<=F.col("total_trav"), 
               F.col("num_trav_ceil"))
               .otherwise(0))
  .select("date", "total_trav", "num_trav", "order", "trav_res"))

解决方案基于@AnnaK。回答,再加一点。 这样它就考虑到必须使用的旅行者总数 (total_trav),不多也不少。

# create ceiling column
df = df_j_test_res.withColumn("num_trav_ceil", F.ceil("num_trav"))

# create cumulative sum column
w = Window.partitionBy("date").orderBy("order")
df = df.withColumn("num_trav_ceil_cumsum", F.sum("num_trav_ceil").over(w))

# impose 0 in trav_res when cumsum exceeds total_trav
df = (df
  .withColumn("trav_res", 
               F.when(F.col("num_trav_ceil_cumsum")<=F.col("total_trav"), 
               F.col("num_trav_ceil")
                     ).when((F.col('num_trav_ceil_cumsum')-F.col('total_trav')>0) & ((F.col('num_trav_ceil_cumsum')-F.col('total_trav')<=1)),
                      1)
              .otherwise(0))
  .select("date", "total_trav", "num_trav", "order", "trav_res"))