Rename Duplicate Columns of a Spark DataFrame?

There are several good answers about managing duplicate columns produced by a join, e.g. (How to avoid duplicate columns after join?), but what if I'm simply handed a DataFrame that already contains duplicate columns I have to deal with? I have no control over the process that produced it.

What I have:

val data = Seq((1,2),(3,4)).toDF("a","a")
data.show

+---+---+
|  a|  a|
+---+---+
|  1|  2|
|  3|  4|
+---+---+

What I want:

+---+---+
|  a|a_2|
+---+---+
|  1|  2|
|  3|  4|
+---+---+

withColumnRenamed("a","a_2") doesn't work, for the obvious reason: with two columns named "a", it cannot target just one of them.

The easiest method I have found is to rename every column positionally with toDF:

val data = Seq((1,2),(3,4)).toDF("a","a")
val deduped = data.toDF("a","a_2")
deduped.show

+---+---+
|  a|a_2|
+---+---+
|  1|  2|
|  3|  4|
+---+---+

A more general solution:

val data = Seq(
  (1,2,3,4,5,6,7,8),
  (9,0,1,2,3,4,5,6)
).toDF("a","b","c","a","d","b","e","b")
data.show

+---+---+---+---+---+---+---+---+
|  a|  b|  c|  a|  d|  b|  e|  b|
+---+---+---+---+---+---+---+---+
|  1|  2|  3|  4|  5|  6|  7|  8|
|  9|  0|  1|  2|  3|  4|  5|  6|
+---+---+---+---+---+---+---+---+

import org.apache.spark.sql.DataFrame
import scala.annotation.tailrec

def dedupeColumnNames(df: DataFrame): DataFrame = {

  // Walk the column list right-to-left; suffix each later duplicate with the
  // number of occurrences seen so far, leaving the first occurrence untouched.
  @tailrec
  def dedupe(fixedColumns: List[String], columns: List[String]): List[String] = {
    if (columns.isEmpty) fixedColumns
    else {
      // Occurrences of the current name in the remaining (reversed) list
      val count = columns.count(_ == columns.head)
      if (count == 1) dedupe(columns.head :: fixedColumns, columns.tail)
      else dedupe(s"${columns.head}_$count" :: fixedColumns, columns.tail)
    }
  }

  val newColumns = dedupe(List.empty[String], df.columns.reverse.toList)
  df.toDF(newColumns: _*)
}

data
  .transform(dedupeColumnNames)
  .show

+---+---+---+---+---+---+---+---+
|  a|  b|  c|a_2|  d|b_2|  e|b_3|
+---+---+---+---+---+---+---+---+
|  1|  2|  3|  4|  5|  6|  7|  8|
|  9|  0|  1|  2|  3|  4|  5|  6|
+---+---+---+---+---+---+---+---+
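The renaming logic itself is pure string manipulation, so it can be exercised without a Spark session. Below is a minimal Spark-free sketch of the same scheme (the `DedupeNames` object is hypothetical, introduced only for illustration): scan the names right-to-left and suffix each later duplicate with its running occurrence count, so the first occurrence keeps its name.

```scala
import scala.annotation.tailrec

object DedupeNames {

  // `rest` holds the still-unprocessed names in reverse order, so
  // `rest.count(_ == head)` is the occurrence count of `head` counting
  // from the start of the original list up to the current position.
  @tailrec
  private def loop(fixed: List[String], rest: List[String]): List[String] =
    rest match {
      case Nil => fixed
      case head :: tail =>
        val count = rest.count(_ == head)
        val name  = if (count == 1) head else s"${head}_$count"
        loop(name :: fixed, tail)
    }

  def dedupe(names: Seq[String]): Seq[String] =
    loop(Nil, names.reverse.toList)
}
```

With this in place, `DedupeNames.dedupe(Seq("a","b","c","a","d","b","e","b"))` yields `Seq("a","b","c","a_2","d","b_2","e","b_3")`, matching the DataFrame output above; the Spark version just feeds `df.columns` in and hands the result to `toDF`.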