PySpark - 根据列表过滤数据框列

PySpark - Filter dataframe columns based on list

我有一个包含一些列名称的数据框,我想根据列表过滤掉一些列。

我有一个我希望在最终数据框中包含的列的列表:

final_columns = ['A','C','E']

我的数据框是这样的:

data1 = [("James",  "Lee", "Smith","36636"),
         ("Michael","Rose","Boots","40288")]

schema1 = StructType([StructField("A",StringType(),True),    
                      StructField("B",StringType(),True),    
                      StructField("C",StringType(),True),    
                      StructField("D",StringType(),True)])

df1 = spark.createDataFrame(data=data1,schema=schema1)

我想转换 df1 以获得此 final_columns 列表的列。

所以,基本上,我希望生成的数据框看起来像这样

+--------+------+------+ 
|      A |    C |    E | 
+--------+------+------+ 
|  James |Smith |      | 
|Michael |Boots |      | 
+--------+------+------+

有什么聪明的方法可以做到这一点吗?

提前致谢

这是一种方法:使用 DataFrame drop() 方法和一个列表,该列表表示 DataFrame 当前列和最终列列表之间的

df = spark.createDataFrame([(1, 1, "1", 0.1),(1, 2, "1", 0.2),(3, 3, "3", 0.3)],('a','b','c','d'))

df.show()
+---+---+---+---+
|  a|  b|  c|  d|
+---+---+---+---+
|  1|  1|  1|0.1|
|  1|  2|  1|0.2|
|  3|  3|  3|0.3|
+---+---+---+---+

# list of desired final columns
final_cols = ['a', 'c', 'd']

df2 = df.drop( *set(final_cols).symmetric_difference(df.columns) )

注意对称差分运算的另一种语法:

df2 = df.drop( *(set(final_cols) ^ set(df.columns)) )

这给了我:

+---+---+---+
|  a|  c|  d|
+---+---+---+
|  1|  1|0.1|
|  1|  1|0.2|
|  3|  3|0.3|
+---+---+---+

我相信这就是你想要的。

根据您的要求编写了动态代码。这将根据提供的列表 select 列,如果 source/original 数据框中不存在该列,也会创建具有空值的列。

data1 = [("James",  "Lee", "Smith","36636"),
         ("Michael","Rose","Boots","40288")]

schema1 = StructType([StructField("A",StringType(),True),    
                      StructField("B",StringType(),True),    
                      StructField("C",StringType(),True),    
                      StructField("D",StringType(),True)])

df1 = spark.createDataFrame(data=data1,schema=schema1)
actual_columns = df1.schema.names
final_columns = ['A','C','E']


def Diff(li1, li2):
  diff = list(set(li2) - set(li1))
  return diff
def Same(li1, li2):
  same = list(sorted(set(li1).intersection(li2)))
  return same

df1 = df1.select(*Same(actual_columns,final_columns))
for i in Diff(actual_columns,final_columns):
  df1 = df1.withColumn(""+i+"",lit(''))
display(df1)

您可以使用 select 和列表理解来做到这一点。这个想法是遍历 final_columns,如果列在 df.colums 中,则添加它,如果不存在,则使用 lit 以正确的别名添加它。

如果您发现列表推导式的可读性较差,您可以使用 for 循环编写类似的逻辑。

from pyspark.sql.functions import lit

df1.select([c if c in df1.columns else lit(None).alias(c) for c in final_columns]).show()

+-------+-----+----+                                                            
|      A|    C|   E|
+-------+-----+----+
|  James|Smith|null|
|Michael|Boots|null|
+-------+-----+----+