How to count distinct elements over multiple columns and a rolling window in PySpark

Suppose we have the following data frame:

port | flag | timestamp
---------------------------------------
20   | S    | 2009-04-24T17:13:14+00:00
30   | R    | 2009-04-24T17:14:14+00:00
32   | S    | 2009-04-24T17:15:14+00:00
21   | R    | 2009-04-24T17:16:14+00:00
54   | R    | 2009-04-24T17:17:14+00:00
24   | R    | 2009-04-24T17:18:14+00:00

I want to count the number of distinct (port, flag) pairs over the last 3 hours in PySpark.
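
For reproducibility, here is a minimal sketch that builds this sample DataFrame (assuming an active SparkSession named spark, which is not part of the original post):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Sample data matching the table above; timestamps are kept as ISO 8601 strings
df = spark.createDataFrame(
    [
        (20, "S", "2009-04-24T17:13:14+00:00"),
        (30, "R", "2009-04-24T17:14:14+00:00"),
        (32, "S", "2009-04-24T17:15:14+00:00"),
        (21, "R", "2009-04-24T17:16:14+00:00"),
        (54, "R", "2009-04-24T17:17:14+00:00"),
        (24, "R", "2009-04-24T17:18:14+00:00"),
    ],
    ["port", "flag", "timestamp"],
)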

The result would look something like:

port | flag | timestamp                 | distinct_port_flag_overs_3h
----------------------------------------------------------------------
20   | S    | 2009-04-24T17:13:14+00:00 | 1
30   | R    | 2009-04-24T17:14:14+00:00 | 1
32   | S    | 2009-04-24T17:15:14+00:00 | 2
21   | R    | 2009-04-24T17:16:14+00:00 | 2
54   | R    | 2009-04-24T17:17:14+00:00 | 2
24   | R    | 2009-04-24T17:18:14+00:00 | 3

The SQL query would look like:

SELECT     
COUNT(DISTINCT port) OVER my_window AS distinct_port_flag_overs_3h
FROM my_table
WINDOW my_window AS (
    PARTITION BY flag
    ORDER BY CAST(timestamp AS timestamp)
    RANGE BETWEEN INTERVAL 3 HOUR PRECEDING AND CURRENT ROW
)

I found a solution that works, but only when counting distinct elements over a single column.

Does anyone know how to achieve this?

Just collect a set of (port, flag) structs and take its size, like this:

from pyspark.sql.functions import col, collect_set, size, struct, to_timestamp
from pyspark.sql.window import Window

# 3-hour rolling window (10800 seconds) per flag, ordered by the epoch timestamp
w = Window.partitionBy("flag").orderBy("timestamp").rangeBetween(-10800, Window.currentRow)

df.withColumn("timestamp", to_timestamp("timestamp").cast("long"))\
  .withColumn("distinct_port_flag_overs_3h", size(collect_set(struct("port", "flag")).over(w)))\
  .orderBy(col("timestamp"))\
  .show()
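
A note on the design choice: Spark does not support COUNT(DISTINCT ...) as a window function, which is why this snippet collects a set of (port, flag) structs and takes its size. If an approximate count were acceptable, a hedged alternative (not part of the original answer, reusing the df and window w defined above) would be approx_count_distinct over a concatenation of the columns:

from pyspark.sql.functions import approx_count_distinct, concat_ws

# Approximate distinct count of (port, flag) pairs over the same 3-hour window;
# concat_ws adds a separator so that e.g. (2, "0S") and (20, "S") do not collide
df.withColumn("timestamp", to_timestamp("timestamp").cast("long"))\
  .withColumn("approx_distinct_port_flag_over_3h",
              approx_count_distinct(concat_ws("_", "port", "flag")).over(w))\
  .orderBy("timestamp")\
  .show()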

I just wrote similar code that does the job:


import re
import traceback

from pyspark.sql import functions as F
from pyspark.sql.window import Window


def hive_time(time: str) -> int:
    """
    Convert a string duration to a number of seconds.
    time : str : must be in the format <number><unit>,
    for example 1hour, 4day (supported units: second, minute, hour, day)
    """
    match = re.match(r"([0-9]+)([a-z]+)", time, re.I)
    if match:
        nb, kind = match.groups()
        try:
            nb = int(nb)
        except ValueError as e:
            print(e, traceback.format_exc())
            print("The format of {}, which is your time aggregation, is not recognized. Please read the doc.".format(time))

        if kind == "second":
            return nb
        if kind == "minute":
            return 60 * nb
        if kind == "hour":
            return 3600 * nb
        if kind == "day":
            return 24 * 3600 * nb

    assert False, "The format of {}, which is your time aggregation, is not recognized. \
    Please read the doc.".format(time)


# Rolling window in Spark
def distinct_count_over(data, window_size: str, out_column: str, *input_columns, time_column: str = 'timestamp'):
    """
    data : PySpark DataFrame
    window_size : size of the rolling window, see hive_time for format information
    out_column : name of the column in which to store the result
    input_columns : the columns over which to count distinct combinations
    time_column : name of the column holding the time field (must be ISO 8601)

    return : a new DataFrame with the result stored in out_column
    """
    # Concatenate the input columns so that distinct combinations can be counted together
    concatenated_columns = F.concat(*input_columns)

    # Rolling window of hive_time(window_size) seconds, ordered by the epoch timestamp
    w = (Window
         .orderBy(F.col(time_column).cast('long'))
         .rangeBetween(-hive_time(window_size), 0))

    return (data
            .withColumn(time_column, F.col(time_column).cast('timestamp'))
            .withColumn(out_column, F.size(F.collect_set(concatenated_columns).over(w))))

It runs fine, though I haven't done any performance monitoring yet.
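
For completeness, a hedged usage sketch; the DataFrame df and the column names are assumptions based on the sample data at the top of the question (with the default time_column='timestamp'):

# Count distinct (port, flag) combinations over a rolling 3-hour window
result = distinct_count_over(df, "3hour", "distinct_port_flag_overs_3h", "port", "flag")
result.orderBy("timestamp").show(truncate=False)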