Pyspark 中按另一列分组的列上的 Softmax 函数

Softmax function on a column groupby another column in Pyspark

我有一个 pyspark 数据框如下:

Variant Category Score Record
A 915 11 Record-1
A 907 10 Record-2
A 914 10 Record-3
B 914 9 Record-1
B 907 2 Record-1

我想计算按 Variant 列分组的 Score 列的 softmax 分数。这将使每个变体的总分达到 100,如下所示。变体可以按行重复 3、2 或 1 次。

Variant Category Score Record Softmax_Score
A 915 11 Record-1 0.35
A 907 10 Record-2 0.32
A 914 10 Record-3 0.32
B 914 9 Record-1 0.82
B 907 2 Record-1 0.18

我知道我们在 python 中有 softmax 函数,但不确定 Pyspark 是如何实现的。

Softmax公式:

def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=0)

Pandas中的方法:

test['Softmax_Score'] = test.groupby('Variant')['Score'].transform(softmax)

您可以在 pyspark 中使用函数 expsum 在被 Variant 分区的 window 上以相同的方式计算它,如下所示:

from pyspark.sql import functions as F

result = df.withColumn(
    "Softmax_Score",
    F.exp("Score") / F.sum(F.exp("Score")).over(Window.partitionBy("Variant"))
)

result.show()
# +-------+--------+-----+--------+--------------------+
# |Variant|Category|Score|  Record|       Softmax_Score|
# +-------+--------+-----+--------+--------------------+
# |      A|     915|   11|Record-1|   0.576116884765829|
# |      A|     907|   10|Record-2| 0.21194155761708544|
# |      A|     914|   10|Record-3| 0.21194155761708544|
# |      B|     914|    9|Record-1|  0.9990889488055994|
# |      B|     907|    2|Record-1|9.110511944006454E-4|
# +-------+--------+-----+--------+--------------------+