使用带条件的 PySpark window 函数添加行

Using PySpark window functions with conditions to add rows

我需要能够根据具有公共 ID 的其他行的内容向 PySpark df will 值添加新行。最终会有数百万个 ID,每个 ID 都有很多行。我尝试了以下有效但似乎过于复杂的方法。

我从以下格式的 df 开始(但实际上有更多列):

+-------+----------+-------+
|   id  | variable | value |
+-------+----------+-------+
|     1 | varA     |    30 |
|     1 | varB     |     1 |
|     1 | varC     |    -9 |
+-------+----------+-------+

目前我正在旋转这个 df 以获取以下格式:

+-----+------+------+------+
|  id | varA | varB | varC |
+-----+------+------+------+
|   1 |   30 |    1 |   -9 |
+-----+------+------+------+

在这个 df 上,我可以使用标准的 withColumn 和 when 功能来根据其他列中的值添加新列。例如:

df = df.withColumn("varD", when((col("varA") > 16) & (col("varC") != -9)), 2).otherwise(1)

这导致:

+-----+------+------+------+------+
|  id | varA | varB | varC | varD |
+-----+------+------+------+------+
|   1 |   30 |    1 |   -9 |    1 |
+-----+------+------+------+------+

然后我可以将此 df 旋转回原始格式,从而导致此:

+-------+----------+-------+
|   id  | variable | value |
+-------+----------+-------+
|     1 | varA     |    30 |
|     1 | varB     |     1 |
|     1 | varC     |    -9 |
|     1 | varD     |     1 |
+-------+----------+-------+

这行得通,但对于数百万行,它似乎会导致昂贵且不必要的操作。感觉它应该是可行的,而不需要旋转和取消旋转数据。我需要这样做吗?

我读过 Window 函数,听起来它们可能是实现相同结果的另一种方法,但老实说,我正在努力开始使用它们。我可以看到如何使用它们为每个 id 生成一个值,比如一个总和,或者找到一个最大值,但还没有找到一种方法来开始应用导致新行的复杂条件。

任何开始解决这个问题的帮助都将不胜感激。

您可以对分组数据使用pandas_udf for adding/deleting rows/col,并在pandas udf 中实现您的处理逻辑。

import pyspark.sql.functions as F

row_schema = StructType(
    [StructField("id", IntegerType(), True),
     StructField("variable", StringType(), True),
     StructField("value", IntegerType(), True)]
)

@F.pandas_udf(row_schema, F.PandasUDFType.GROUPED_MAP)
def addRow(pdf):
    val = 1
    if  (len(pdf.loc[(pdf['variable'] == 'varA') & (pdf['value'] > 16)]) > 0 ) & \
        (len(pdf.loc[(pdf['variable'] == 'varC') & (pdf['value'] != -9)]) > 0):
        val = 2
    return pdf.append(pd.Series([1, 'varD', val], index=['id', 'variable', 'value']), ignore_index=True)

df = spark.createDataFrame([[1, 'varA', 30],
                            [1, 'varB', 1],
                            [1, 'varC', -9]
                            ], schema=['id', 'variable', 'value'])

df.groupBy("id").apply(addRow).show()

结果

+---+--------+-----+
| id|variable|value|
+---+--------+-----+
|  1|    varA|   30|
|  1|    varB|    1|
|  1|    varC|   -9|
|  1|    varD|    1|
+---+--------+-----+