Avoiding duplicate columns after nullSafeJoin scala spark

I have a use case where I need to join on nullable columns. I am doing it like this:

  import org.apache.spark.sql.{Column, DataFrame}
  import org.apache.spark.sql.functions.col

  def nullSafeJoin(leftDF: DataFrame, rightDF: DataFrame, joinOnColumns: Seq[String]): DataFrame = {

    val dataset1 = leftDF.alias("dataset1")
    val dataset2 = rightDF.alias("dataset2")

    // Null-safe equality on the first join column ...
    val firstColumn = joinOnColumns.head
    val colExpression: Column = col(s"dataset1.$firstColumn").eqNullSafe(col(s"dataset2.$firstColumn"))

    // ... AND-ed with null-safe equality on the remaining join columns
    val fullExpr = joinOnColumns.tail.foldLeft(colExpression) {
      (expr, p) => expr && col(s"dataset1.$p").eqNullSafe(col(s"dataset2.$p"))
    }
    dataset1.join(dataset2, fullExpr)
  }
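
For context, with the sample data shown in the edit below the call looks roughly like this (the join columns A, B, C are an assumption taken from that sample; the point is where the duplication shows up):

  // Hypothetical call, join columns taken from the sample data below
  val joined = nullSafeJoin(leftDF, rightDF, Seq("A", "B", "C"))

  // joined now contains A, B and C twice (once from dataset1, once from dataset2),
  // so a plain joined.select("A") fails with an "ambiguous reference" AnalysisException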

The final joined dataset has duplicate columns. I tried dropping the column using the alias like this:

dataset1.join(dataset2, fullExpr).drop(s"dataset2.$firstColumn")

But that didn't work.
I know we could do a select of the columns instead of a drop.

I am trying to keep the code base generic, so I don't want to pass the list of columns to be selected into the function (in the case of drop, I would only have to drop the joinOnColumns list that we already pass to the function).
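
For reference, a minimal sketch of what such a drop-based variant could look like, built only from the joinOnColumns list (the helper name nullSafeJoinDropRight is made up; it assumes leftDF and rightDF are distinct DataFrames, uses <=>, Spark's null-safe equality operator, and drops by Column reference, since drop with an alias-qualified string is a no-op):

  import org.apache.spark.sql.DataFrame

  def nullSafeJoinDropRight(leftDF: DataFrame, rightDF: DataFrame, joinOnColumns: Seq[String]): DataFrame = {
    // Null-safe equality on every join column, AND-ed together
    val joinExpr = joinOnColumns
      .map(c => leftDF(c) <=> rightDF(c))
      .reduce(_ && _)

    // Drop the right-hand copy of each join column by Column reference
    joinOnColumns.foldLeft(leftDF.join(rightDF, joinExpr)) { (df, c) =>
      df.drop(rightDF(c))
    }
  }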

Any suggestions on how to solve this would be very helpful. Thanks!

Edit: (sample data)

leftDF :
+------------------+-----------+---------+---------+-------+
|                 A|          B|        C|        D| status|
+------------------+-----------+---------+---------+-------+
|             14567|         37|        1|     game|Enabled|
|             14567|       BASE|        1|      toy| Paused|
|             13478|       null|        5|     game|Enabled|
|              2001|       BASE|        1|     null| Paused|
|              null|         37|        1|     home|Enabled|
+------------------+-----------+---------+---------+-------+

rightDF :
+------------------+-----------+---------+
|                 A|          B|        C|
+------------------+-----------+---------+
|               140|         37|        1|
|               569|       BASE|        1|
|             13478|       null|        5|
|              2001|       BASE|        1|
|              null|         37|        1|
+------------------+-----------+---------+

Final Join (Required):
+------------------+-----------+---------+---------+-------+
|                 A|          B|        C|        D| status|
+------------------+-----------+---------+---------+-------+
|             13478|       null|        5|     game|Enabled|
|              2001|       BASE|        1|     null| Paused|
|              null|         37|        1|     home|Enabled|
+------------------+-----------+---------+---------+-------+

Your final DataFrame has duplicate columns from both leftDF and rightDF, and there is no identifier to check whether a given column came from leftDF or from rightDF.

So I renamed the leftDF and rightDF columns: the leftDF columns are prefixed as left_[column_name] and the rightDF columns as right_[column_name].

Hope the code below helps you.

scala> :paste
// Entering paste mode (ctrl-D to finish)

  val left = Seq(("14567", "37", "1", "game", "Enabled"), ("14567", "BASE", "1", "toy", "Paused"), ("13478", "null", "5", "game", "Enabled"), ("2001", "BASE", "1", "null", "Paused"), ("null", "37", "1", "home", "Enabled")).toDF("a", "b", "c", "d", "status")
  val right = Seq(("140", "37", 1), ("569", "BASE", 1), ("13478", "null", 5), ("2001", "BASE", 1), ("null", "37", 1)).toDF("a", "b", "c")

  import org.apache.spark.sql.DataFrame
  def nullSafeJoin(leftDF: DataFrame, rightDF: DataFrame, joinOnColumns: Seq[String]):DataFrame = {
    val leftRenamedDF = leftDF
      .columns
      .map(c => (c, s"left_${c}"))
      .foldLeft(leftDF){ (df, c) =>
        df.withColumnRenamed(c._1, c._2)
      }
    val rightRenamedDF = rightDF
      .columns
      .map(c => (c, s"right_${c}"))
      .foldLeft(rightDF){(df, c) =>
        df.withColumnRenamed(c._1, c._2)
      }

    val fullExpr = joinOnColumns
      .tail
      .foldLeft($"left_${joinOnColumns.head}".eqNullSafe($"right_${joinOnColumns.head}")) { (cee, p) =>
        cee && ($"left_${p}".eqNullSafe($"right_${p}"))
      }

    val finalColumns = joinOnColumns
      .map(c => col(s"left_${c}").as(c)) ++ // Taking All columns from Join columns
      leftDF.columns.diff(joinOnColumns).map(c => col(s"left_${c}").as(c)) ++ // Taking missing columns from leftDF
      rightDF.columns.diff(joinOnColumns).map(c => col(s"right_${c}").as(c)) // Taking missing columns from rightDF

    leftRenamedDF.join(rightRenamedDF, fullExpr).select(finalColumns: _*)
  }

// Exiting paste mode, now interpreting.

The final DataFrame result is:

scala> nullSafeJoin(left, right, Seq("a", "b", "c")).show(false)

+-----+----+---+----+-------+
|a    |b   |c  |d   |status |
+-----+----+---+----+-------+
|13478|null|5  |game|Enabled|
|2001 |BASE|1  |null|Paused |
|null |37  |1  |home|Enabled|
+-----+----+---+----+-------+