使用条件结果列连接 PySpark 数据框

Joining PySpark dataframes with conditional result column

我有这些表:

df1                  df2
+---+------------+   +---+---------+
| id|   many_cols|   | id|criterion|
+---+------------+   +---+---------+
|  1|lots_of_data|   |  1|    false|
|  2|lots_of_data|   |  1|     true|
|  3|lots_of_data|   |  1|     true|
+---+------------+   |  3|    false|
                     +---+---------+

我打算在 df1 中创建额外的列:

+---+------------+------+
| id|   many_cols|result|
+---+------------+------+
|  1|lots_of_data|     1|
|  2|lots_of_data|  null|
|  3|lots_of_data|     0|
+---+------------+------+
如果df2
中有对应的true

result应该是1 如果df2
中没有对应的trueresult应该是0 如果df2

中没有对应的idresult应该是null

我想不出有效的方法。加入后我只坚持第三个条件:

df = df1.join(df2, 'id', 'full')
df.show()

#  +---+------------+---------+
#  | id|   many_cols|criterion|
#  +---+------------+---------+
#  |  1|lots_of_data|    false|
#  |  1|lots_of_data|     true|
#  |  1|lots_of_data|     true|
#  |  3|lots_of_data|    false|
#  |  2|lots_of_data|     null|
#  +---+------------+---------+

PySpark 数据帧是这样创建的:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

df1cols = ['id', 'many_cols']
df1data = [(1, 'lots_of_data'),
           (2, 'lots_of_data'),
           (3, 'lots_of_data')]
df2cols = ['id', 'criterion']
df2data = [(1, False),
           (1, True),
           (1, True),
           (3, None)]
df1 = spark.createDataFrame(df1data, df1cols)
df2 = spark.createDataFrame(df2data, df2cols)

查看此解决方案。加入后。您可以根据您的要求使用多个条件检查,并使用 when 子句相应地分配值,然后按 id 和其他列获取结果分组的最大值。如果您只对分区使用 id,也可以使用 window 函数来计算结果的最大值。

from pyspark.sql import functions as F
from pyspark.sql.window import Window

df1cols = ['id', 'many_cols']
df1data = [(1, 'lots_of_data'),
           (2, 'lots_of_data'),
           (3, 'lots_of_data')]
df2cols = ['id', 'criterion']
df2data = [(1, False),
           (1, True),
           (1, True),
           (3, False)]
df1 = spark.createDataFrame(df1data, df1cols)
df2 = spark.createDataFrame(df2data, df2cols)

df2_mod =df2.withColumnRenamed("id", "id_2")

df3=df1.join(df2_mod, on=df1.id== df2_mod.id_2, how='left')

cond1 = (F.col("id")== F.col("id_2"))& (F.col("criterion")==1)
cond2 = (F.col("id")== F.col("id_2"))& (F.col("criterion")==0)
cond3 = (F.col("id_2").isNull())

df3.select("id", "many_cols", F.when(cond1, 1).when(cond2,0).when(cond3, F.lit(None)).alias("result"))\
    .groupBy("id", "many_cols").agg(F.max(F.col("result")).alias("result")).orderBy("id").show()

Result:
------

+---+------------+------+
| id|   many_cols|result|
+---+------------+------+
|  1|lots_of_data|     1|
|  2|lots_of_data|  null|
|  3|lots_of_data|     0|
+---+------------+------+

使用window函数

w=Window().partitionBy("id")

df3.select("id", "many_cols", F.when(cond1, 1).when(cond2,0).when(cond3, F.lit(None)).alias("result"))\
    .select("id", "many_cols", F.max("result").over(w).alias("result")).drop_duplicates().show()

一个简单的方法是 groupby df2 通过 iddf1 的连接获得最大值 criterion,这样你就可以减少行数加入。如果至少有一个对应的真值,则布尔列的最大值为真:

from pyspark.sql import functions as F

df2_group = df2.groupBy("id").agg(F.max("criterion").alias("criterion"))

result = df1.join(df2_group, ["id"], "left").withColumn(
    "result",
    F.col("criterion").cast("int")
).drop("criterion")

result.show()
#+---+------------+------+
#| id|   many_cols|result|
#+---+------------+------+
#|  1|lots_of_data|     1|
#|  3|lots_of_data|     0|
#|  2|lots_of_data|  null|
#+---+------------+------+

您可以尝试使用相关子查询从 df2 中获取最大布尔值,并将其转换为整数。

df1.createOrReplaceTempView('df1') 
df2.createOrReplaceTempView('df2') 

df = spark.sql("""
    select
        df1.*,
        (select int(max(criterion)) from df2 where df1.id = df2.id) as result
    from df1
""")

df.show()
+---+------------+------+
| id|   many_cols|result|
+---+------------+------+
|  1|lots_of_data|     1|
|  3|lots_of_data|     0|
|  2|lots_of_data|  null|
+---+------------+------+

我不得不合并建议答案的想法以获得最适合我的解决方案。

# The `cond` variable is very useful, here it represents several complex conditions
cond = F.col('criterion') == True
df2_grp = df2.select(
    'id',
    F.when(cond, 1).otherwise(0).alias('c')
).groupBy('id').agg(F.max(F.col('c')).alias('result'))
df = df1.join(df2_grp, 'id', 'left')

df.show()
#+---+------------+------+
#| id|   many_cols|result|
#+---+------------+------+
#|  1|lots_of_data|     1|
#|  3|lots_of_data|     0|
#|  2|lots_of_data|  null|
#+---+------------+------+