如何在标志为 0 时获取最后一行值,以及在 pyspark 数据帧中标志为 1 时如何将当前行值获取到新列

How to get last row value when flag is 0 and get the current row value to new column when flag 1 in pyspark dataframe

Flag 1 时的场景 1: 对于 Flag 为 1 的行 Copy trx_date to Destination

场景 2 当 Flag 0 时: 对于 Flag 为 0 的行 Copy the previous Destination Value

输入:

+-----------+----+----------+
|customer_id|Flag|  trx_date|
+-----------+----+----------+
|          1|   1| 12/3/2020|
|          1|   0| 12/4/2020|
|          1|   1| 12/5/2020|
|          1|   1| 12/6/2020|
|          1|   0| 12/7/2020|
|          1|   1| 12/8/2020|
|          1|   0| 12/9/2020|
|          1|   0|12/10/2020|
|          1|   0|12/11/2020|
|          1|   1|12/12/2020|
|          2|   1| 12/1/2020|
|          2|   0| 12/2/2020|
|          2|   0| 12/3/2020|
|          2|   1| 12/4/2020|
+-----------+----+----------+

输出:

+-----------+----+----------+-----------+
|customer_id|Flag|  trx_date|destination|
+-----------+----+----------+-----------+
|          1|   1| 12/3/2020|  12/3/2020|
|          1|   0| 12/4/2020|  12/3/2020|
|          1|   1| 12/5/2020|  12/5/2020|
|          1|   1| 12/6/2020|  12/6/2020|
|          1|   0| 12/7/2020|  12/6/2020|
|          1|   1| 12/8/2020|  12/8/2020|
|          1|   0| 12/9/2020|  12/8/2020|
|          1|   0|12/10/2020|  12/8/2020|
|          1|   0|12/11/2020|  12/8/2020|
|          1|   1|12/12/2020| 12/12/2020|
|          2|   1| 12/1/2020|  12/1/2020|
|          2|   0| 12/2/2020|  12/1/2020|
|          2|   0| 12/3/2020|  12/1/2020|
|          2|   1| 12/4/2020|  12/4/2020|
+-----------+----+----------+-----------+

生成spark Dataframe的代码:

df = spark.createDataFrame([(1,1,'12/3/2020'),(1,0,'12/4/2020'),(1,1,'12/5/2020'),
(1,1,'12/6/2020'),(1,0,'12/7/2020'),(1,1,'12/8/2020'),(1,0,'12/9/2020'),(1,0,'12/10/2020'),
(1,0,'12/11/2020'),(1,1,'12/12/2020'),(2,1,'12/1/2020'),(2,0,'12/2/2020'),(2,0,'12/3/2020'),
(2,1,'12/4/2020')], ["customer_id","Flag","trx_date"])

您可以使用 window 函数。我不确定 spark sql 是否支持 lag().

的标准 ignore nulls 选项

如果是,你可以这样做:

select 
    t.*,
    case when flag = 1
        then trx_date
        else lag(case when flag = 1 then trx_date end ignore nulls) 
                over(partition by customer_id order by trx_date)
    end destination
from mytable t

否则,您可以先用 window 总和建立组:

select
    customer_id,
    flag,
    trx_date,
    case when flag = 1
        then trx_date
        else min(trx_date) over(partition by customer_id, grp order by trx_date)
    end destination
from (
    select t.*, sum(flag) over(partition by customer_id order by trx_date) grp
    from mytable t
) t

Pyspark 的方法。在datetype中得到trx_date后,先得到incremental sumFlag 创建我们需要的 groupings 以便使用 first window partitioned by those groupings. 函数 我们可以使用 date_format 将两列恢复为所需的日期格式。我假设您的格式是 MM/dd/yyyy,如果不同请在代码中将其更改为 dd/MM/yyyy

df.show() #sample data
#+-----------+----+----------+
#|customer_id|Flag|  trx_date|
#+-----------+----+----------+
#|          1|   1| 12/3/2020|
#|          1|   0| 12/4/2020|
#|          1|   1| 12/5/2020|
#|          1|   1| 12/6/2020|
#|          1|   0| 12/7/2020|
#|          1|   1| 12/8/2020|
#|          1|   0| 12/9/2020|
#|          1|   0|12/10/2020|
#|          1|   0|12/11/2020|
#|          1|   1|12/12/2020|
#|          2|   1| 12/1/2020|
#|          2|   0| 12/2/2020|
#|          2|   0| 12/3/2020|
#|          2|   1| 12/4/2020|
#+-----------+----+----------+

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w=Window().orderBy("customer_id","trx_date")
w1=Window().partitionBy("Flag2").orderBy("trx_date").rowsBetween(Window.unboundedPreceding,Window.unboundedFollowing)
df.withColumn("trx_date", F.to_date("trx_date", "MM/dd/yyyy"))\
  .withColumn("Flag2", F.sum("Flag").over(w))\
  .withColumn("destination", F.when(F.col("Flag")==0, F.first("trx_date").over(w1)).otherwise(F.col("trx_date")))\
  .withColumn("trx_date", F.date_format("trx_date","MM/dd/yyyy"))\
  .withColumn("destination", F.date_format("destination", "MM/dd/yyyy"))\
  .orderBy("customer_id","trx_date").drop("Flag2").show()

#+-----------+----+----------+-----------+
#|customer_id|Flag|  trx_date|destination|
#+-----------+----+----------+-----------+
#|          1|   1|12/03/2020| 12/03/2020|
#|          1|   0|12/04/2020| 12/03/2020|
#|          1|   1|12/05/2020| 12/05/2020|
#|          1|   1|12/06/2020| 12/06/2020|
#|          1|   0|12/07/2020| 12/06/2020|
#|          1|   1|12/08/2020| 12/08/2020|
#|          1|   0|12/09/2020| 12/08/2020|
#|          1|   0|12/10/2020| 12/08/2020|
#|          1|   0|12/11/2020| 12/08/2020|
#|          1|   1|12/12/2020| 12/12/2020|
#|          2|   1|12/01/2020| 12/01/2020|
#|          2|   0|12/02/2020| 12/01/2020|
#|          2|   0|12/03/2020| 12/01/2020|
#|          2|   1|12/04/2020| 12/04/2020|
#+-----------+----+----------+-----------+

如果您正在考虑数据帧,您可以通过以下方式实现此目的 API

#Convert date format while creating window itself

window = Window().orderBy("customer_id",f.to_date('trx_date','MM/dd/yyyy'))

df1 = df.withColumn('destination', f.when(f.col('Flag')==1,f.col('trx_date'))).\
withColumn('destination',f.last(f.col('destination'),ignorenulls=True).over(window))

df1.show()

+-----------+----+----------+-----------+
|customer_id|Flag|  trx_date|destination|
+-----------+----+----------+-----------+
|          1|   1| 12/3/2020|  12/3/2020|
|          1|   0| 12/4/2020|  12/3/2020|
|          1|   1| 12/5/2020|  12/5/2020|
|          1|   1| 12/6/2020|  12/6/2020|
|          1|   0| 12/7/2020|  12/6/2020|
|          1|   1| 12/8/2020|  12/8/2020|
|          1|   0| 12/9/2020|  12/8/2020|
|          1|   0|12/10/2020|  12/8/2020|
|          1|   0|12/11/2020|  12/8/2020|
|          1|   1|12/12/2020| 12/12/2020|
|          2|   1| 12/1/2020|  12/1/2020|
|          2|   0| 12/2/2020|  12/1/2020|
|          2|   0| 12/3/2020|  12/1/2020|
|          2|   1| 12/4/2020|  12/4/2020|
+-----------+----+----------+-----------+

希望对您有所帮助。