How to find the top n keys based on the value in Pyspark?
I have a PySpark dataframe whose schema looks like this:
root
|-- query: string (nullable = true)
|-- collect_list(docId): array (nullable = true)
| |-- element: string (containsNull = true)
|-- prod_count_dict: map (nullable = true)
| |-- key: string
| |-- value: integer (valueContainsNull = true)
The dataframe looks like this:
+--------------------+--------------------+--------------------+
| query| collect_list(docId)| prod_count_dict|
+--------------------+--------------------+--------------------+
|1/2 inch plywood ...|[471097-153-12CC,...|[530320-62634-100...|
|             1416445|[1416445-83-HHM5S...|[1054482-2251-FFC...|
+--------------------+--------------------+--------------------+
Note that the prod_count_dict column is a dictionary of key-value pairs, e.g.:
{x: 12, a: 16, b: 1, f: 3, ....}
What I want to do is pick only the keys with the top n largest values from those key-value pairs and store them as a list, like [x, a, ...], in another column for that row.
I tried the code below, but it gives me an error. Is there any way to solve this particular problem?
@F.udf(StringType())
def create_label(x):
    # If the length of the dictionary is less than 20, I want to return the keys of all the items in the dict.
    if len(x) >= 20:
        val_sort = sorted(list(x.values()), reverse=True)
        cutoff = {k: v for (k, v) in x.items() if v > val_sort[20]}
        return cutoff.keys()
    else:
        return x.keys()

label_df = label_count_df.withColumn("label", create_label("prod_count_dict"))
label_df.show()
First, I would explode the dictionary so that each key-value pair becomes its own row:
import pyspark.sql.functions as f

df = df.select("*", f.explode("prod_count_dict").alias("key", "value"))
After that, you can use a Window function to keep only the top n entries per query:
from pyspark.sql import Window

w = Window.partitionBy(df['query']).orderBy(df['value'].desc())
df = (df.select('*', f.rank().over(w).alias('rank'))
        .filter(f.col('rank') <= 2)  # set n here
        .drop('rank'))
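If you then need those top keys collected back into a single list column per row, as the question asks, here is a minimal sketch assuming query uniquely identifies a row (join the result back to the original frame if you also need the other columns):
label_df = (df.groupBy('query')
              .agg(f.collect_list('key').alias('label')))
label_df.show()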
The UDF you have written is basically correct; you only need to change the code that actually uses it. This can be done easily with .map on the rdd:
# Let the udf that you have written be a normal python function
def create_label(x):
    # If the dictionary has 20 entries or fewer, return all of its keys;
    # otherwise keep only the keys whose value beats the 21st-largest value.
    if len(x) > 20:
        val_sort = sorted(x.values(), reverse=True)
        cutoff = {k: v for (k, v) in x.items() if v > val_sort[20]}
        return list(cutoff.keys())
    else:
        return list(x.keys())
The part you need to change is:
label_df_col = ['query', 'label']
label_df = label_count_df.rdd.map(lambda x: (x.query, create_label(x.prod_count_dict))).toDF(label_df_col)
label_df.show()
This should work.
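Alternatively, a sketch that keeps the UDF route: the error in the question most likely comes from declaring the UDF as StringType while returning dict keys, so declaring an array return type and returning a plain list should also work, avoiding the detour through the RDD:
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

@udf(ArrayType(StringType()))
def create_label(x):
    # Keep every key when the dict has 20 entries or fewer.
    if len(x) > 20:
        val_sort = sorted(x.values(), reverse=True)
        # Keep only the keys whose value beats the 21st-largest value.
        return [k for k, v in x.items() if v > val_sort[20]]
    return list(x.keys())

label_df = label_count_df.withColumn("label", create_label("prod_count_dict"))
label_df.show()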