How to select rows from a list in PySpark

Suppose we have two dataframes, df1 and df2, where df1 has columns [a, b, c, p, q, r] and df2 has columns [d, e, f, a, b, c]. The common columns are stored in a list, common_cols = ['a', 'b', 'c'].

How can the common_cols list be used to join the two dataframes in a SQL command? The code below attempts this.

common_cols = ['a', 'b', 'c']
# This does not work: common_cols is a Python list, not a column,
# so df1.common_cols and df2.common_cols do not exist
filter_df = spark.sql("""
    select * from df1 inner join df2
    on df1.common_cols = df2.common_cols
""")

You can use USING instead of ON. See the documentation.

common_cols = ['a', 'b', 'c']

spark.sql(
    f'''
    SELECT *
    FROM
    (SELECT 1 a, 2 b, 3 c, 10 val1)
    JOIN
    (SELECT 1 a, 2 b, 3 c, 20 val2)
    USING ({','.join(common_cols)})
    '''
).show()

+---+---+---+----+----+
|  a|  b|  c|val1|val2|
+---+---+---+----+----+
|  1|  2|  3|  10|  20|
+---+---+---+----+----+
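
For comparison, the DataFrame API accepts the column-name list directly and, like USING, keeps a single copy of the join keys in the result. A minimal sketch reusing the same toy rows (dfa and dfb are hypothetical names, not from the question):

dfa = spark.createDataFrame([(1, 2, 3, 10)], ['a', 'b', 'c', 'val1'])
dfb = spark.createDataFrame([(1, 2, 3, 20)], ['a', 'b', 'c', 'val2'])

# passing a list of names joins on equality of each column and
# deduplicates a, b, c in the output, just like USING in SQL
dfa.join(dfb, on=common_cols, how='inner').show()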

Demo setup

df1 = spark.createDataFrame([(1,2,3,4,5,6)],['a','b','c','p','q','r'])
df2 = spark.createDataFrame([(7,8,9,1,2,3)],['d','e','f','a','b','c'])
common_cols = ['a','b','c']

df1.show()

+---+---+---+---+---+---+
|  a|  b|  c|  p|  q|  r|
+---+---+---+---+---+---+
|  1|  2|  3|  4|  5|  6|
+---+---+---+---+---+---+


df2.show()

+---+---+---+---+---+---+
|  d|  e|  f|  a|  b|  c|
+---+---+---+---+---+---+
|  7|  8|  9|  1|  2|  3|
+---+---+---+---+---+---+

Solution, based on USING (the SQL join syntax)

df1.createOrReplaceTempView('df1')
df2.createOrReplaceTempView('df2')
common_cols_csv = ','.join(common_cols)

query = f'''\
select  * 
from    df1 inner join df2 using ({common_cols_csv})
'''

print(query)

select  * 
from    df1 inner join df2 using (a,b,c)

filter_df = spark.sql(query)

filter_df.show()

+---+---+---+---+---+---+---+---+---+
|  a|  b|  c|  p|  q|  r|  d|  e|  f|
+---+---+---+---+---+---+---+---+---+
|  1|  2|  3|  4|  5|  6|  7|  8|  9|
+---+---+---+---+---+---+---+---+---+
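
If you prefer an explicit ON clause, it can also be generated from the list. Unlike USING, ON keeps both copies of the join columns, so the select list below qualifies them to avoid ambiguity. A sketch reusing the df1/df2 temp views registered above:

# builds 'df1.a = df2.a and df1.b = df2.b and df1.c = df2.c'
on_clause = ' and '.join(f'df1.{c} = df2.{c}' for c in common_cols)

query = f'''\
select  df1.*, df2.d, df2.e, df2.f
from    df1 inner join df2 on {on_clause}
'''
spark.sql(query).show()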

Adding to @David דודו Markovitz's answer: to derive the common columns dynamically, you can do something like this -

Input data

df1 = spark.createDataFrame([(1,2,3,4,5,6)],['a','b','c','p','q','r'])
df2 = spark.createDataFrame([(7,8,9,1,2,3)],['d','e','f','a','b','c'])

df1.createOrReplaceTempView("df1")
df2.createOrReplaceTempView("df2")

Finding the common columns using set intersection

common_cols = set(df1.columns).intersection(set(df2.columns))
print(common_cols)

{'a', 'b', 'c'}
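
One caveat: a Python set has no guaranteed iteration order, so the column order in the generated query may vary between runs. If a stable order matters, a list comprehension that preserves df1's column order is a simple alternative (a sketch):

# deterministic: keeps the columns in df1's original order
common_cols = [c for c in df1.columns if c in set(df2.columns)]
print(common_cols)   # ['a', 'b', 'c']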

Creating the query string -

query = '''
select  * 
from    df1 inner join df2 using ({common_cols})
'''.format(common_cols=', '.join(map(str, common_cols)))

print(query)

select  * 
from    df1 inner join df2 using (a, b, c)

Finally, execute the query with spark.sql -

spark.sql(query).show()

+---+---+---+---+---+---+---+---+---+
|  a|  b|  c|  p|  q|  r|  d|  e|  f|
+---+---+---+---+---+---+---+---+---+
|  1|  2|  3|  4|  5|  6|  7|  8|  9|
+---+---+---+---+---+---+---+---+---+
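
If this pattern comes up often, it can be wrapped in a small helper. join_using below is a hypothetical name, a sketch assuming both dataframes are already registered as temp views:

def join_using(spark, left_view, right_view, cols, how='inner'):
    # hypothetical helper: joins two registered temp views on the
    # given column list via the SQL USING syntax
    cols_csv = ', '.join(cols)
    return spark.sql(
        f'select * from {left_view} {how} join {right_view} using ({cols_csv})'
    )

# sorted() gives a stable column order even if cols is a set
join_using(spark, 'df1', 'df2', sorted(common_cols)).show()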