Pyspark:如何查找前 5 行值并将其转换为 1,其余全部为 0?
Pyspark : How to find and convert top 5 row values to 1 and rest all to 0?
我有一个数据框,我需要在每一行中找到最多 5 个值,仅将这些值转换为 1,其余全部为 0,同时保持数据框结构,即列名应保持不变
我尝试使用 toLocalIterator,然后将每一行转换为列表,然后将前 5 行转换为值 1。
但是当我 运行 大型数据集上的代码时,它给了我一个 java.lang.outOfMemoryError。
在查看日志时,我发现提交了一个非常大的任务(大约 25000KB),而最大推荐大小为 100KB
有没有更好的方法来查找前 5 个值并将其转换为特定值(在本例中为 1)并将其余值全部转换为 0,这样会占用更少的内存
编辑 1:
例如,如果我有这 10 列和 5 行作为输入
+----+----+----+----+----+----+----+----+----+----+
| 1| 2| 3| 4| 5| 6| 7| 8| 9| 10|
+----+----+----+----+----+----+----+----+----+----+
|0.74| 0.9|0.52|0.85|0.18|0.23| 0.3| 0.0| 0.1|0.07|
|0.11|0.57|0.81|0.81|0.45|0.48|0.86|0.38|0.41|0.45|
|0.03|0.84|0.17|0.96|0.09|0.73|0.25|0.05|0.57|0.66|
| 0.8|0.94|0.06|0.44| 0.2|0.89| 0.9| 1.0|0.48|0.14|
|0.73|0.86|0.68| 1.0|0.78|0.17|0.11|0.19|0.18|0.83|
+----+----+----+----+----+----+----+----+----+----+
这就是我想要的输出
+---+---+---+---+---+---+---+---+---+---+
| 1| 2| 3| 4| 5| 6| 7| 8| 9| 10|
+---+---+---+---+---+---+---+---+---+---+
| 1| 1| 1| 1| 0| 0| 1| 0| 0| 0|
| 0| 1| 1| 1| 0| 1| 1| 0| 0| 0|
| 0| 1| 0| 1| 0| 1| 0| 0| 1| 1|
| 1| 1| 0| 0| 0| 1| 1| 1| 0| 0|
| 1| 1| 0| 1| 1| 0| 0| 0| 0| 1|
+---+---+---+---+---+---+---+---+---+---+
如您所见,我想在每行中找到前(最大)5 个值,将它们转换为 1,将其余值转换为 0,同时保持结构,即行和列
这就是我正在使用的(这会给我 outOfMemoryError)
for row in prob_df.rdd.toLocalIterator():
rowPredDict = {}
for cat in categories:
rowPredDict[cat]= row[cat]
sorted_row = sorted(rowPredDict.items(), key=lambda kv: kv[1],reverse=True)
#print(rowPredDict)
rowPredDict = rowPredDict.fromkeys(rowPredDict,0)
rowPredDict[sorted_row[0:5][0][0]] = 1
rowPredDict[sorted_row[0:5][1][0]] = 1
rowPredDict[sorted_row[0:5][2][0]] = 1
rowPredDict[sorted_row[0:5][3][0]] = 1
rowPredDict[sorted_row[0:5][4][0]] = 1
#print(count,sorted_row[0:2][0][0],",",sorted_row[0:2][1][0])
rowPredList.append(rowPredDict)
#count=count+1
您可以像这样轻松地做到这一点。
例如,我们想对值列执行该任务,因此首先对值列进行排序,取第 5 个值,然后使用 when 条件更改值。
df2 = sc.parallelize([("fo", 100,20),("rogerg", 110,56),("franre", 1080,297),("f11", 10100,217),("franci", 10,227),("fran", 1002,5),("fran231cis", 10007,271),("franc3is", 1030,2)]).toDF(["name", "salary","value"])
df2 = df2.orderBy("value",ascending=False)
+----------+------+-----+
| name|salary|value|
+----------+------+-----+
| franre| 1080| 297|
|fran231cis| 10007| 271|
| franci| 10| 227|
| f11| 10100| 217|
| rogerg| 110| 56|
| fo| 100| 20|
| fran| 1002| 5|
| franc3is| 1030| 2|
+----------+------+-----+
maxx = df2.take(5)[4]["value"]
dff = df2.select(when(df2['value'] >= maxx, 1).otherwise(0).alias("value"),"name", "salary")
+---+----------+------+
|value| name|salary|
+---+----------+------+
| 1| franre| 1080|
| 1|fran231cis| 10007|
| 1| franci| 10|
| 1| f11| 10100|
| 1| rogerg| 110|
| 0| fo| 100|
| 0| fran| 1002|
| 0| franc3is| 1030|
+---+----------+------+
我没有足够的容量来进行性能测试,但您可以使用 spark functions array
api
尝试以下方法吗
1.准备数据集:
import pyspark.sql.functions as f
l1 = [(0.74,0.9,0.52,0.85,0.18,0.23,0.3,0.0,0.1,0.07),
(0.11,0.57,0.81,0.81,0.45,0.48,0.86,0.38,0.41,0.45),
(0.03,0.84,0.17,0.96,0.09,0.73,0.25,0.05,0.57,0.66),
(0.8,0.94,0.06,0.44,0.2,0.89,0.9,1.0,0.48,0.14),
(0.73,0.86,0.68,1.0,0.78,0.17,0.11,0.19,0.18,0.83)]
df = spark.createDataFrame(l1).toDF('col_1','col_2','col_3','col_4','col_5','col_6','col_7','col_8','col_9','col_10')
df.show()
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
|col_1|col_2|col_3|col_4|col_5|col_6|col_7|col_8|col_9|col_10|
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
| 0.74| 0.9| 0.52| 0.85| 0.18| 0.23| 0.3| 0.0| 0.1| 0.07|
| 0.11| 0.57| 0.81| 0.81| 0.45| 0.48| 0.86| 0.38| 0.41| 0.45|
| 0.03| 0.84| 0.17| 0.96| 0.09| 0.73| 0.25| 0.05| 0.57| 0.66|
| 0.8| 0.94| 0.06| 0.44| 0.2| 0.89| 0.9| 1.0| 0.48| 0.14|
| 0.73| 0.86| 0.68| 1.0| 0.78| 0.17| 0.11| 0.19| 0.18| 0.83|
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
2。每行获得前 5 名
按照 df
上的以下步骤
- 创建数组并对元素排序
- 将前 5 个元素放入名为
all
的新列中
UDF 从已排序的元素中获取最多 5 个元素:
注: spark >= 2.4.0有slice
功能可以做类似的任务。我目前使用的是 2.2,因此创建了 UDF,但是如果您有 2.4 或更高版本,那么您可以尝试使用 slice
def get_n_elements_(arr, n):
return arr[:n]
get_n_elements = f.udf(get_n_elements_, t.ArrayType(t.DoubleType()))
df_all = df.withColumn('all', get_n_elements(f.sort_array(f.array(df.columns), False),f.lit(5)))
df_all.show()
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------------------------------+
|col_1|col_2|col_3|col_4|col_5|col_6|col_7|col_8|col_9|col_10|all |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------------------------------+
|0.74 |0.9 |0.52 |0.85 |0.18 |0.23 |0.3 |0.0 |0.1 |0.07 |[0.9, 0.85, 0.74, 0.52, 0.3] |
|0.11 |0.57 |0.81 |0.81 |0.45 |0.48 |0.86 |0.38 |0.41 |0.45 |[0.86, 0.81, 0.81, 0.57, 0.48]|
|0.03 |0.84 |0.17 |0.96 |0.09 |0.73 |0.25 |0.05 |0.57 |0.66 |[0.96, 0.84, 0.73, 0.66, 0.57]|
|0.8 |0.94 |0.06 |0.44 |0.2 |0.89 |0.9 |1.0 |0.48 |0.14 |[1.0, 0.94, 0.9, 0.89, 0.8] |
|0.73 |0.86 |0.68 |1.0 |0.78 |0.17 |0.11 |0.19 |0.18 |0.83 |[1.0, 0.86, 0.83, 0.78, 0.73] |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------------------------------+
3。创建动态 sql 并执行 selectExpr
sql_stmt = ''' case when array_contains(all, {0}) then 1 else 0 end AS `{0}` '''
df_all.selectExpr(*[sql_stmt.format(c) for c in df.columns]).show()
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
|col_1|col_2|col_3|col_4|col_5|col_6|col_7|col_8|col_9|col_10|
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
| 1| 1| 1| 1| 0| 0| 1| 0| 0| 0|
| 0| 1| 1| 1| 0| 1| 1| 0| 0| 0|
| 0| 1| 0| 1| 0| 1| 0| 0| 1| 1|
| 1| 1| 0| 0| 0| 1| 1| 1| 0| 0|
| 1| 1| 0| 1| 1| 0| 0| 0| 0| 1|
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
我有一个数据框,我需要在每一行中找到最多 5 个值,仅将这些值转换为 1,其余全部为 0,同时保持数据框结构,即列名应保持不变
我尝试使用 toLocalIterator,然后将每一行转换为列表,然后将前 5 行转换为值 1。 但是当我 运行 大型数据集上的代码时,它给了我一个 java.lang.outOfMemoryError。 在查看日志时,我发现提交了一个非常大的任务(大约 25000KB),而最大推荐大小为 100KB
有没有更好的方法来查找前 5 个值并将其转换为特定值(在本例中为 1)并将其余值全部转换为 0,这样会占用更少的内存
编辑 1:
例如,如果我有这 10 列和 5 行作为输入
+----+----+----+----+----+----+----+----+----+----+
| 1| 2| 3| 4| 5| 6| 7| 8| 9| 10|
+----+----+----+----+----+----+----+----+----+----+
|0.74| 0.9|0.52|0.85|0.18|0.23| 0.3| 0.0| 0.1|0.07|
|0.11|0.57|0.81|0.81|0.45|0.48|0.86|0.38|0.41|0.45|
|0.03|0.84|0.17|0.96|0.09|0.73|0.25|0.05|0.57|0.66|
| 0.8|0.94|0.06|0.44| 0.2|0.89| 0.9| 1.0|0.48|0.14|
|0.73|0.86|0.68| 1.0|0.78|0.17|0.11|0.19|0.18|0.83|
+----+----+----+----+----+----+----+----+----+----+
这就是我想要的输出
+---+---+---+---+---+---+---+---+---+---+
| 1| 2| 3| 4| 5| 6| 7| 8| 9| 10|
+---+---+---+---+---+---+---+---+---+---+
| 1| 1| 1| 1| 0| 0| 1| 0| 0| 0|
| 0| 1| 1| 1| 0| 1| 1| 0| 0| 0|
| 0| 1| 0| 1| 0| 1| 0| 0| 1| 1|
| 1| 1| 0| 0| 0| 1| 1| 1| 0| 0|
| 1| 1| 0| 1| 1| 0| 0| 0| 0| 1|
+---+---+---+---+---+---+---+---+---+---+
如您所见,我想在每行中找到前(最大)5 个值,将它们转换为 1,将其余值转换为 0,同时保持结构,即行和列
这就是我正在使用的(这会给我 outOfMemoryError)
for row in prob_df.rdd.toLocalIterator():
rowPredDict = {}
for cat in categories:
rowPredDict[cat]= row[cat]
sorted_row = sorted(rowPredDict.items(), key=lambda kv: kv[1],reverse=True)
#print(rowPredDict)
rowPredDict = rowPredDict.fromkeys(rowPredDict,0)
rowPredDict[sorted_row[0:5][0][0]] = 1
rowPredDict[sorted_row[0:5][1][0]] = 1
rowPredDict[sorted_row[0:5][2][0]] = 1
rowPredDict[sorted_row[0:5][3][0]] = 1
rowPredDict[sorted_row[0:5][4][0]] = 1
#print(count,sorted_row[0:2][0][0],",",sorted_row[0:2][1][0])
rowPredList.append(rowPredDict)
#count=count+1
您可以像这样轻松地做到这一点。
例如,我们想对值列执行该任务,因此首先对值列进行排序,取第 5 个值,然后使用 when 条件更改值。
df2 = sc.parallelize([("fo", 100,20),("rogerg", 110,56),("franre", 1080,297),("f11", 10100,217),("franci", 10,227),("fran", 1002,5),("fran231cis", 10007,271),("franc3is", 1030,2)]).toDF(["name", "salary","value"])
df2 = df2.orderBy("value",ascending=False)
+----------+------+-----+
| name|salary|value|
+----------+------+-----+
| franre| 1080| 297|
|fran231cis| 10007| 271|
| franci| 10| 227|
| f11| 10100| 217|
| rogerg| 110| 56|
| fo| 100| 20|
| fran| 1002| 5|
| franc3is| 1030| 2|
+----------+------+-----+
maxx = df2.take(5)[4]["value"]
dff = df2.select(when(df2['value'] >= maxx, 1).otherwise(0).alias("value"),"name", "salary")
+---+----------+------+
|value| name|salary|
+---+----------+------+
| 1| franre| 1080|
| 1|fran231cis| 10007|
| 1| franci| 10|
| 1| f11| 10100|
| 1| rogerg| 110|
| 0| fo| 100|
| 0| fran| 1002|
| 0| franc3is| 1030|
+---+----------+------+
我没有足够的容量来进行性能测试,但您可以使用 spark functions array
api
1.准备数据集:
import pyspark.sql.functions as f
l1 = [(0.74,0.9,0.52,0.85,0.18,0.23,0.3,0.0,0.1,0.07),
(0.11,0.57,0.81,0.81,0.45,0.48,0.86,0.38,0.41,0.45),
(0.03,0.84,0.17,0.96,0.09,0.73,0.25,0.05,0.57,0.66),
(0.8,0.94,0.06,0.44,0.2,0.89,0.9,1.0,0.48,0.14),
(0.73,0.86,0.68,1.0,0.78,0.17,0.11,0.19,0.18,0.83)]
df = spark.createDataFrame(l1).toDF('col_1','col_2','col_3','col_4','col_5','col_6','col_7','col_8','col_9','col_10')
df.show()
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
|col_1|col_2|col_3|col_4|col_5|col_6|col_7|col_8|col_9|col_10|
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
| 0.74| 0.9| 0.52| 0.85| 0.18| 0.23| 0.3| 0.0| 0.1| 0.07|
| 0.11| 0.57| 0.81| 0.81| 0.45| 0.48| 0.86| 0.38| 0.41| 0.45|
| 0.03| 0.84| 0.17| 0.96| 0.09| 0.73| 0.25| 0.05| 0.57| 0.66|
| 0.8| 0.94| 0.06| 0.44| 0.2| 0.89| 0.9| 1.0| 0.48| 0.14|
| 0.73| 0.86| 0.68| 1.0| 0.78| 0.17| 0.11| 0.19| 0.18| 0.83|
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
2。每行获得前 5 名
按照 df
- 创建数组并对元素排序
- 将前 5 个元素放入名为
all
的新列中
UDF 从已排序的元素中获取最多 5 个元素:
注: spark >= 2.4.0有slice
功能可以做类似的任务。我目前使用的是 2.2,因此创建了 UDF,但是如果您有 2.4 或更高版本,那么您可以尝试使用 slice
def get_n_elements_(arr, n):
return arr[:n]
get_n_elements = f.udf(get_n_elements_, t.ArrayType(t.DoubleType()))
df_all = df.withColumn('all', get_n_elements(f.sort_array(f.array(df.columns), False),f.lit(5)))
df_all.show()
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------------------------------+
|col_1|col_2|col_3|col_4|col_5|col_6|col_7|col_8|col_9|col_10|all |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------------------------------+
|0.74 |0.9 |0.52 |0.85 |0.18 |0.23 |0.3 |0.0 |0.1 |0.07 |[0.9, 0.85, 0.74, 0.52, 0.3] |
|0.11 |0.57 |0.81 |0.81 |0.45 |0.48 |0.86 |0.38 |0.41 |0.45 |[0.86, 0.81, 0.81, 0.57, 0.48]|
|0.03 |0.84 |0.17 |0.96 |0.09 |0.73 |0.25 |0.05 |0.57 |0.66 |[0.96, 0.84, 0.73, 0.66, 0.57]|
|0.8 |0.94 |0.06 |0.44 |0.2 |0.89 |0.9 |1.0 |0.48 |0.14 |[1.0, 0.94, 0.9, 0.89, 0.8] |
|0.73 |0.86 |0.68 |1.0 |0.78 |0.17 |0.11 |0.19 |0.18 |0.83 |[1.0, 0.86, 0.83, 0.78, 0.73] |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------------------------------+
3。创建动态 sql 并执行 selectExpr
sql_stmt = ''' case when array_contains(all, {0}) then 1 else 0 end AS `{0}` '''
df_all.selectExpr(*[sql_stmt.format(c) for c in df.columns]).show()
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
|col_1|col_2|col_3|col_4|col_5|col_6|col_7|col_8|col_9|col_10|
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
| 1| 1| 1| 1| 0| 0| 1| 0| 0| 0|
| 0| 1| 1| 1| 0| 1| 1| 0| 0| 0|
| 0| 1| 0| 1| 0| 1| 0| 0| 1| 1|
| 1| 1| 0| 0| 0| 1| 1| 1| 0| 0|
| 1| 1| 0| 1| 1| 0| 0| 0| 0| 1|
+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+