通过条件pyspark相互减去列内的多个值

Subtracting multiple values inside a column from each other by conditional pyspark

我正在为一些我认为应该非常微不足道的事情而苦苦挣扎。

我有一个输入数据框:

id months feature_a feature_b feature_c
1 1 2 1 3
1 6 3 2 4
2 1 2 1 3
2 6 3 2 4

现在我想按 id 分组,并为每个特征减去月份为 6 的值减去月份为 1 的值。得到的输出数据帧为:

id feature_a feature_b feature_c
1 1 1 1
2 1 1 1

现在我设法用下面的代码做到这一点:

def get_month_diff(df, start_month=1, end_month=6):

    columns_to_agg = ['feature_a', 'feature_b', 'feature_c']
    result = (df
               .groupby('id')
               .pivot('months')
               .agg(*[F.sum(col).alias(f'{col}') for col in columns_to_agg])
               )
    #  Pyspark doesnt work nice with columns that have '.'s in them
    result = result.toDF(*(c.replace('.', '_') for c in result.columns))

    for col in columns_to_agg:
        result = result.withColumn(col, result[f"{end_month}_0_{col}"] - result[f"{start_month}_0_{col}"])

    return result

我不喜欢我必须从另一列中减去一列并在第一个火花转换之外创建这些新列。所以我正在寻找解决方案。

因此我想问是否有人可以在正确的方向上帮助我解决这个问题?

如果使用 maxwhen 来获取第 1 个月和第 6 个月的值,则可以避免数据透视:

import pyspark.sql.functions as F

df2 = df.groupBy('id').agg(*[
    (
        F.max(F.when(F.col('months') == 6, F.col(c))) - 
        F.max(F.when(F.col('months') == 1, F.col(c)))
    ).alias(c) 
    for c in df.columns[2:]
])

df2.show()
+---+---------+---------+---------+
| id|feature_a|feature_b|feature_c|
+---+---------+---------+---------+
|  1|        1|        1|        1|
|  2|        1|        1|        1|
+---+---------+---------+---------+