Pivoting ArrayType columns in pyspark

I have a pyspark dataframe with the following schema:

+----------+-------------------+-----------------------------------+------------------+
|      date|         numeric_id|                     feature_column|              city|
+----------+-------------------+-----------------------------------+------------------+
|2017-08-01|         2343434545|               [0.0, 0.0, 0.0, 0...|            Berlin|
|2017-08-01|         2343434545|               [0.0, 0.0, 0.0, 0...|              Rome|
|2017-08-01|         2343434545|               [0.0, 0.0, 0.0, 0...|           NewYork|
|2017-08-01|         2343434545|               [0.0, 0.0, 0.0, 0...|           Beijing|
|2019-12-01|         6455534545|               [0.0, 0.0, 0.0, 0...|            Berlin|
|2019-12-01|         6455534545|               [0.0, 0.0, 0.0, 0...|              Rome|
|2019-12-01|         6455534545|               [0.0, 0.0, 0.0, 0...|           NewYork|
|2019-12-01|         6455534545|               [0.0, 0.0, 0.0, 0...|           Beijing|
+----------+-------------------+-----------------------------------+------------------+

I want to pivot the dataframe so that each feature_column x city combination becomes a new column, grouping by date and numeric_id. The output dataframe should look like this:

+----------+-------------+----------------------+--------------------+-----------------------+----------------------+
|      date|   numeric_id| feature_column_Berlin| feature_column_Rome| feature_column_NewYork|feature_column_Beijing|
+----------+-------------+----------------------+--------------------+-----------------------+----------------------+
|2017-08-01|   2343434545|  [0.0, 0.0, 0.0, 0...|[0.0, 0.0, 0.0, 0...|[0.0, 0.0, 0.0, 0...   |[0.0, 0.0, 0.0, 0...  |
|2019-12-01|   6455534545|  [0.0, 0.0, 0.0, 0...|[0.0, 0.0, 0.0, 0...|[0.0, 0.0, 0.0, 0...   |[0.0, 0.0, 0.0, 0...  |
+----------+-------------+----------------------+--------------------+-----------------------+----------------------+

This differs from the question posted on pivoting strings, because I am dealing with ArrayType columns. I think this would be easier to do in Pandas (although handling the ArrayType columns would be tricky), so I am curious how to achieve it with Spark SQL. Any suggestions?

// Create sample data to load into a dataframe (run in spark-shell).
import org.apache.spark.sql.functions._

val df = Seq(
  ("2017-08-01", "2343434545", Array("0.0", "0.0", "0.0", "0.0"), "Berlin"),
  ("2017-08-01", "2343434545", Array("0.0", "0.0", "0.0", "0.0"), "Rome"),
  ("2017-08-01", "2343434545", Array("0.0", "0.0", "0.0", "0.0"), "NewYork"),
  ("2017-08-01", "2343434545", Array("0.0", "0.0", "0.0", "0.0"), "Beijing"),
  ("2019-12-01", "6455534545", Array("0.0", "0.0", "0.0", "0.0"), "Berlin"),
  ("2019-12-01", "6455534545", Array("0.0", "0.0", "0.0", "0.0"), "Rome"),
  ("2019-12-01", "6455534545", Array("0.0", "0.0", "0.0", "0.0"), "NewYork"),
  ("2019-12-01", "6455534545", Array("0.0", "0.0", "0.0", "0.0"), "Beijing")
).toDF("date", "numeric_id", "feature_column", "city")

// Group by date and numeric_id, pivot on city, then rename the pivoted columns.
df.groupBy("date", "numeric_id").pivot("city")
  .agg(collect_list("feature_column"))
  .withColumnRenamed("Beijing", "feature_column_Beijing")
  .withColumnRenamed("Berlin", "feature_column_Berlin")
  .withColumnRenamed("NewYork", "feature_column_NewYork")
  .withColumnRenamed("Rome", "feature_column_Rome")
  .show()

The show() call at the end prints the pivoted dataframe, with one feature_column_<city> column per city.
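Since the question asks about pyspark, here is a minimal sketch of the same approach in pyspark, assuming a SparkSession is already available as spark (as in the pyspark shell) and recreating the sample data above. It swaps collect_list for first() as the pivot aggregate so each pivoted cell keeps its original array shape (collect_list wraps every array in an extra outer list), and it passes the known city values to pivot() so Spark can skip the pass that computes distinct values. The names cities and pivoted are just for illustration.

# Sketch of the pivot in pyspark; assumes `spark` (a SparkSession) already exists,
# e.g. in the pyspark shell. Sample data mirrors the dataframe shown above.
from pyspark.sql import functions as F

df = spark.createDataFrame(
    [
        ("2017-08-01", "2343434545", [0.0, 0.0, 0.0, 0.0], "Berlin"),
        ("2017-08-01", "2343434545", [0.0, 0.0, 0.0, 0.0], "Rome"),
        ("2017-08-01", "2343434545", [0.0, 0.0, 0.0, 0.0], "NewYork"),
        ("2017-08-01", "2343434545", [0.0, 0.0, 0.0, 0.0], "Beijing"),
        ("2019-12-01", "6455534545", [0.0, 0.0, 0.0, 0.0], "Berlin"),
        ("2019-12-01", "6455534545", [0.0, 0.0, 0.0, 0.0], "Rome"),
        ("2019-12-01", "6455534545", [0.0, 0.0, 0.0, 0.0], "NewYork"),
        ("2019-12-01", "6455534545", [0.0, 0.0, 0.0, 0.0], "Beijing"),
    ],
    ["date", "numeric_id", "feature_column", "city"],
)

cities = ["Berlin", "Rome", "NewYork", "Beijing"]

# Pivot on city; first() keeps each ArrayType cell as-is (collect_list would nest it),
# and listing the city values up front avoids the extra distinct-values scan.
pivoted = (
    df.groupBy("date", "numeric_id")
      .pivot("city", cities)
      .agg(F.first("feature_column"))
)

# Rename the pivoted columns to feature_column_<city>.
for c in cities:
    pivoted = pivoted.withColumnRenamed(c, "feature_column_" + c)

pivoted.show(truncate=False)

This should print the two grouped rows with feature_column_Berlin, feature_column_Rome, feature_column_NewYork and feature_column_Beijing columns, matching the expected layout in the question.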