获取一列中下一个非零值的行数并在另一列中求和 - Pyspark

Get the number of rows to the next non-zero value in one column and sum in another column - Pyspark

我在 Pyspark 中有以下 table:

-----------------------------------------------------------------
|  sku  | distribution center | leadtime | ind_abt |    date    | 
-----------------------------------------------------------------
|  1234 |      New York       |    10    |    0    | 2022-01-01 |
|  1234 |      New York       |    10    |    0    | 2022-01-02 |
|  1234 |      New York       |    10    |    0    | 2022-01-03 |
|  1234 |      New York       |    10    |    1    | 2022-01-04 |
-----------------------------------------------------------------

对于每一行,我想用非零“ind_abt”值计算到下一行的距离,并创建一个名为 leadtime_aux 的新列,它对提前期求和该距离的值。此距离必须在 ("sku", "distribution_center") window 中计算并查找当前行下方的 15 行。

例如,在第一行中,到第 ind_abt 列 != 0 的下一行的距离为 3。因此第 leadtime_aux 列将为 13(交货时间+3)。 对于第二行,到下一个非零行的距离是 2,所以 leadtime_aux = 12.

结果 table 看起来像这样:

--------------------------------------------------------------------------------
|  sku  | distribution center | leadtime | ind_abt |    date    | leadtime_aux | 
--------------------------------------------------------------------------------
|  1234 |      New York       |    10    |    0    | 2022-01-01 |      13      |
|  1234 |      New York       |    10    |    0    | 2022-01-02 |      12      |
|  1234 |      New York       |    10    |    0    | 2022-01-03 |      11      |
|  1234 |      New York       |    10    |    1    | 2022-01-04 |      10      |
--------------------------------------------------------------------------------

我想我在这里找到了解决办法。我会把它贴在这里,它可能对以后的人有用:

win = Window.partitionBy("sku", "distribution_center").orderBy("date")
win2 = Window.partitionBy("sku", "distribution_center").orderBy("date").rowsBetween(0, 15)
df = df.withColumn("rnum", row_number().over(win)) \
     .withColumn("delta", first(when(col("ind_abt") == 0, None).otherwise(col("rnum")), ignorenulls=True).over(win))\
     .withColumn("leadtime_aux", when(col("delta").isNull(), None).otherwise(abs(col("rnum")-col("delta")))+col("leadtime"))