Pyspark:如何查找前 5 行值并将其转换为 1,其余全部为 0?

Pyspark : How to find and convert top 5 row values to 1 and rest all to 0?

我有一个数据框,我需要在每一行中找到最多 5 个值,仅将这些值转换为 1,其余全部为 0,同时保持数据框结构,即列名应保持不变

我尝试使用 toLocalIterator,然后将每一行转换为列表,然后将前 5 行转换为值 1。 但是当我 运行 大型数据集上的代码时,它给了我一个 java.lang.outOfMemoryError。 在查看日志时,我发现提交了一个非常大的任务(大约 25000KB),而最大推荐大小为 100KB

有没有更好的方法来查找前 5 个值并将其转换为特定值(在本例中为 1)并将其余值全部转换为 0,这样会占用更少的内存

编辑 1:

例如,如果我有这 10 列和 5 行作为输入

+----+----+----+----+----+----+----+----+----+----+
|   1|   2|   3|   4|   5|   6|   7|   8|   9|  10|
+----+----+----+----+----+----+----+----+----+----+
|0.74| 0.9|0.52|0.85|0.18|0.23| 0.3| 0.0| 0.1|0.07|
|0.11|0.57|0.81|0.81|0.45|0.48|0.86|0.38|0.41|0.45|
|0.03|0.84|0.17|0.96|0.09|0.73|0.25|0.05|0.57|0.66|
| 0.8|0.94|0.06|0.44| 0.2|0.89| 0.9| 1.0|0.48|0.14|
|0.73|0.86|0.68| 1.0|0.78|0.17|0.11|0.19|0.18|0.83|
+----+----+----+----+----+----+----+----+----+----+

这就是我想要的输出

+---+---+---+---+---+---+---+---+---+---+
|  1|  2|  3|  4|  5|  6|  7|  8|  9| 10|
+---+---+---+---+---+---+---+---+---+---+
|  1|  1|  1|  1|  0|  0|  1|  0|  0|  0|
|  0|  1|  1|  1|  0|  1|  1|  0|  0|  0|
|  0|  1|  0|  1|  0|  1|  0|  0|  1|  1|
|  1|  1|  0|  0|  0|  1|  1|  1|  0|  0|
|  1|  1|  0|  1|  1|  0|  0|  0|  0|  1|
+---+---+---+---+---+---+---+---+---+---+

如您所见,我想在每行中找到前(最大)5 个值,将它们转换为 1,将其余值转换为 0,同时保持结构,即行和列

这就是我正在使用的(这会给我 outOfMemoryError)

for row in prob_df.rdd.toLocalIterator():
    rowPredDict = {}
    for cat in categories:
        rowPredDict[cat]= row[cat]
        sorted_row = sorted(rowPredDict.items(), key=lambda kv: kv[1],reverse=True)
    #print(rowPredDict)
    rowPredDict = rowPredDict.fromkeys(rowPredDict,0)
    rowPredDict[sorted_row[0:5][0][0]] = 1
    rowPredDict[sorted_row[0:5][1][0]] = 1
    rowPredDict[sorted_row[0:5][2][0]] = 1
    rowPredDict[sorted_row[0:5][3][0]] = 1
    rowPredDict[sorted_row[0:5][4][0]] = 1
    #print(count,sorted_row[0:2][0][0],",",sorted_row[0:2][1][0])
    rowPredList.append(rowPredDict)
    #count=count+1

您可以像这样轻松地做到这一点。

例如,我们想对值列执行该任务,因此首先对值列进行排序,取第 5 个值,然后使用 when 条件更改值。

df2 = sc.parallelize([("fo", 100,20),("rogerg", 110,56),("franre", 1080,297),("f11", 10100,217),("franci", 10,227),("fran", 1002,5),("fran231cis", 10007,271),("franc3is", 1030,2)]).toDF(["name", "salary","value"])
df2 = df2.orderBy("value",ascending=False)

+----------+------+-----+
|      name|salary|value|
+----------+------+-----+
|    franre|  1080|  297|
|fran231cis| 10007|  271|
|    franci|    10|  227|
|       f11| 10100|  217|
|    rogerg|   110|   56|
|        fo|   100|   20|
|      fran|  1002|    5|
|  franc3is|  1030|    2|
+----------+------+-----+

maxx = df2.take(5)[4]["value"]
dff = df2.select(when(df2['value'] >= maxx, 1).otherwise(0).alias("value"),"name", "salary")

+---+----------+------+
|value|      name|salary|
+---+----------+------+
|  1|    franre|  1080|
|  1|fran231cis| 10007|
|  1|    franci|    10|
|  1|       f11| 10100|
|  1|    rogerg|   110|
|  0|        fo|   100|
|  0|      fran|  1002|
|  0|  franc3is|  1030|
+---+----------+------+

我没有足够的容量来进行性能测试,但您可以使用 spark functions array api

尝试以下方法吗

1.准备数据集:

import pyspark.sql.functions as f

l1 = [(0.74,0.9,0.52,0.85,0.18,0.23,0.3,0.0,0.1,0.07),
    (0.11,0.57,0.81,0.81,0.45,0.48,0.86,0.38,0.41,0.45),
    (0.03,0.84,0.17,0.96,0.09,0.73,0.25,0.05,0.57,0.66),
    (0.8,0.94,0.06,0.44,0.2,0.89,0.9,1.0,0.48,0.14),
    (0.73,0.86,0.68,1.0,0.78,0.17,0.11,0.19,0.18,0.83)]

df = spark.createDataFrame(l1).toDF('col_1','col_2','col_3','col_4','col_5','col_6','col_7','col_8','col_9','col_10')
df.show()
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
|col_1|col_2|col_3|col_4|col_5|col_6|col_7|col_8|col_9|col_10|
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
| 0.74|  0.9| 0.52| 0.85| 0.18| 0.23|  0.3|  0.0|  0.1|  0.07|
| 0.11| 0.57| 0.81| 0.81| 0.45| 0.48| 0.86| 0.38| 0.41|  0.45|
| 0.03| 0.84| 0.17| 0.96| 0.09| 0.73| 0.25| 0.05| 0.57|  0.66|
|  0.8| 0.94| 0.06| 0.44|  0.2| 0.89|  0.9|  1.0| 0.48|  0.14|
| 0.73| 0.86| 0.68|  1.0| 0.78| 0.17| 0.11| 0.19| 0.18|  0.83|
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+

2。每行获得前 5 名

按照 df

上的以下步骤
  • 创建数组并对元素排序
  • 将前 5 个元素放入名为 all
  • 的新列中

UDF 从已排序的元素中获取最多 5 个元素:

注: spark >= 2.4.0slice功能可以做类似的任务。我目前使用的是 2.2,因此创建了 UDF,但是如果您有 2.4 或更高版本,那么您可以尝试使用 slice

def get_n_elements_(arr, n):
                return arr[:n]

get_n_elements = f.udf(get_n_elements_, t.ArrayType(t.DoubleType()))

df_all = df.withColumn('all', get_n_elements(f.sort_array(f.array(df.columns), False),f.lit(5)))

df_all.show()
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------------------------------+
|col_1|col_2|col_3|col_4|col_5|col_6|col_7|col_8|col_9|col_10|all                           |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------------------------------+
|0.74 |0.9  |0.52 |0.85 |0.18 |0.23 |0.3  |0.0  |0.1  |0.07  |[0.9, 0.85, 0.74, 0.52, 0.3]  |
|0.11 |0.57 |0.81 |0.81 |0.45 |0.48 |0.86 |0.38 |0.41 |0.45  |[0.86, 0.81, 0.81, 0.57, 0.48]|
|0.03 |0.84 |0.17 |0.96 |0.09 |0.73 |0.25 |0.05 |0.57 |0.66  |[0.96, 0.84, 0.73, 0.66, 0.57]|
|0.8  |0.94 |0.06 |0.44 |0.2  |0.89 |0.9  |1.0  |0.48 |0.14  |[1.0, 0.94, 0.9, 0.89, 0.8]   |
|0.73 |0.86 |0.68 |1.0  |0.78 |0.17 |0.11 |0.19 |0.18 |0.83  |[1.0, 0.86, 0.83, 0.78, 0.73] |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------------------------------+

3。创建动态 sql 并执行 selectExpr

sql_stmt = ''' case when array_contains(all, {0}) then 1 else 0 end AS `{0}` '''
df_all.selectExpr(*[sql_stmt.format(c) for c in df.columns]).show()
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
|col_1|col_2|col_3|col_4|col_5|col_6|col_7|col_8|col_9|col_10|
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
|    1|    1|    1|    1|    0|    0|    1|    0|    0|     0|
|    0|    1|    1|    1|    0|    1|    1|    0|    0|     0|
|    0|    1|    0|    1|    0|    1|    0|    0|    1|     1|
|    1|    1|    0|    0|    0|    1|    1|    1|    0|     0|
|    1|    1|    0|    1|    1|    0|    0|    0|    0|     1|
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+