Pyspark:如何从 collect_set 中删除项目?
Pyspark: How to remove an item from a collect_set?
在以下数据框中:
from pyspark.sql import functions as F
df = sqlContext.createDataFrame([
("a", "code1", "name"),
("a", "code1", "name2"),
("a", "code2", "name2"),
], ["id", "code", "name"])
df.show()
您可以运行此命令获取不同值的列表:
df.groupby("id").agg(F.collect_set("code")).show()
+---+-----------------+
| id|collect_set(code)|
+---+-----------------+
| a| [code2, code1]|
+---+-----------------+
如何删除上面的项目collect_set?例如。如何删除 'code2'
Spark 2.4+ 更新: 您可以通过 array_remove
:
实现
df_grouped = df.groupby("id")\
.agg(F.array_remove(F.collect_set("code"), "code2").alias("codes"))
Spark 2.3 及以下版本的原始答案
AFAIK 无法动态 ,因此如果您的数据已经在数组中,您有两个选择:
选项 1:分解、过滤、收集
使用pyspark.sql.functions.explode()
to turn the elements of the array into separate rows. Then use pyspark.sql.DataFrame.where()
筛选出所需的值。最后执行 groupBy()
和 collect_set()
将数据收集回一行。
df_grouped = df.groupby("id").agg(F.collect_set("code").alias("codes"))
df_grouped.select("*", F.explode("codes").alias("exploded"))\
.where(~F.col("exploded").isin(["code2"]))\
.groupBy("id")\
.agg(F.collect_set("exploded").alias("codes"))\
.show()
#+---+-------+
#| id| codes|
#+---+-------+
#| a|[code1]|
#+---+-------+
选项 2:使用 UDF
def filter_code(array):
bad_values={"code2"}
return [x for x in array if x not in bad_values]
filter_code_udf = F.udf(lambda x: filter_code(x), ArrayType(StringType()))
df_grouped = df.groupby("id").agg(F.collect_set("code").alias("codes"))
df_grouped.withColumn("codes_filtered", filter_code_udf("codes")).show()
#+---+--------------+--------------+
#| id| codes|codes_filtered|
#+---+--------------+--------------+
#| a|[code2, code1]| [code1]|
#+---+--------------+--------------+
当然,如果您从原始数据框开始(在 groupBy()
和 collect_set()
之前),您可以先过滤所需的值:
df.where(~F.col("code").isin(["code2"])).groupby("id").agg(F.collect_set("code")).show()
#+---+-----------------+
#| id|collect_set(code)|
#+---+-----------------+
#| a| [code1]|
#+---+-----------------+
在以下数据框中:
from pyspark.sql import functions as F
df = sqlContext.createDataFrame([
("a", "code1", "name"),
("a", "code1", "name2"),
("a", "code2", "name2"),
], ["id", "code", "name"])
df.show()
您可以运行此命令获取不同值的列表:
df.groupby("id").agg(F.collect_set("code")).show()
+---+-----------------+
| id|collect_set(code)|
+---+-----------------+
| a| [code2, code1]|
+---+-----------------+
如何删除上面的项目collect_set?例如。如何删除 'code2'
Spark 2.4+ 更新: 您可以通过 array_remove
:
df_grouped = df.groupby("id")\
.agg(F.array_remove(F.collect_set("code"), "code2").alias("codes"))
Spark 2.3 及以下版本的原始答案
AFAIK 无法动态
选项 1:分解、过滤、收集
使用pyspark.sql.functions.explode()
to turn the elements of the array into separate rows. Then use pyspark.sql.DataFrame.where()
筛选出所需的值。最后执行 groupBy()
和 collect_set()
将数据收集回一行。
df_grouped = df.groupby("id").agg(F.collect_set("code").alias("codes"))
df_grouped.select("*", F.explode("codes").alias("exploded"))\
.where(~F.col("exploded").isin(["code2"]))\
.groupBy("id")\
.agg(F.collect_set("exploded").alias("codes"))\
.show()
#+---+-------+
#| id| codes|
#+---+-------+
#| a|[code1]|
#+---+-------+
选项 2:使用 UDF
def filter_code(array):
bad_values={"code2"}
return [x for x in array if x not in bad_values]
filter_code_udf = F.udf(lambda x: filter_code(x), ArrayType(StringType()))
df_grouped = df.groupby("id").agg(F.collect_set("code").alias("codes"))
df_grouped.withColumn("codes_filtered", filter_code_udf("codes")).show()
#+---+--------------+--------------+
#| id| codes|codes_filtered|
#+---+--------------+--------------+
#| a|[code2, code1]| [code1]|
#+---+--------------+--------------+
当然,如果您从原始数据框开始(在 groupBy()
和 collect_set()
之前),您可以先过滤所需的值:
df.where(~F.col("code").isin(["code2"])).groupby("id").agg(F.collect_set("code")).show()
#+---+-----------------+
#| id|collect_set(code)|
#+---+-----------------+
#| a| [code1]|
#+---+-----------------+