Spark:如何在 A 的 ID 数组列不包含 B 的 ID 列的情况下连接两个“数据集”的 A 和 B?

Spark: How to join two `Dataset`s A and B with the condition that an ID array column of A does NOT contain the ID column of B?

我的问题不是 [ 的重复问题。我的问题是关于 "not in",而不是 "is in"。这是不同的!

我有两个 Dataset

案例分类 UserProfile 定义为

case class UserProfile(userId: Int, visitedJobIds: Array[Int])

案例 class JobModel 定义为

case class JobModel(JobId: Int, Model: Map[String, Double])

我还制作了两个对象(UserProfileFieldNamesJobModelFieldNames),其中包含这两种情况的字段名称 classes.

我的 objective 是,对于 userProfileDataset 中的每个用户,找到 UserProfile.visitedJobIds[=75= 中不包含的 JobModel.JobId ]. 怎么做?

我考虑过先使用 crossJoin,然后再使用 filter。它可能会起作用。有没有更直接或更有效的方法?


我尝试了以下方法,但 none 有效:

val result = userProfileDataset.joinWith(jobModelsDataset,
      !userProfileDataset.col(UserProfileFieldNames.visitedJobIds).contains(jobModelsDataset.col(JobModelFieldNames.jobId)),
      "left_outer"
    )

它导致:

Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve 'contains(_1.visitedJobIds, CAST(_2.JobId AS STRING))' due to data type mismatch: argument 1 requires string type, however, '_1.visitedJobIds' is of array type.;;

难道是因为contains方法只能用来检测一个字符串是否包含另一个字符串?

以下条件也不成立:

!jobModelsDataset.col(JobModelFieldNames.jobId)
  .isin(userProfileDataset.col(UserProfileFieldNames.visitedJobIds))

它导致:

Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve '(_2.JobId IN (_1.visitedJobIds))' due to data type mismatch: Arguments must be same type but were: IntegerType != ArrayType(IntegerType,false);; 'Join LeftOuter, NOT _2#74.JobId IN (_1#73.visitedJobIds)

如果unique job id的个数不是太多,可以按如下方式收集广播

val jobIds = jobModelsDataset.map(_.JobId).distinct.collect().toSeq
val broadcastedJobIds = spark.sparkContext.broadcast(jobIds)

要将此广播序列与 visitedJobIds 列进行比较,您可以创建一个 UDF

val notVisited = udf((visitedJobs: Seq[Int]) => { 
  broadcastedJobIds.value.filterNot(visitedJobs.toSet)
})

val df = userProfileDataset.withColumn("jobsToDo", notVisited($"visitedJobIds"))

使用 jobIds = 1,2,3,4,5 和示例数据框

进行测试
+------+---------------+
|userId|  visitedJobIds|
+------+---------------+
|     1|      [1, 2, 3]|
|     2|      [3, 4, 5]|
|     3|[1, 2, 3, 4, 5]|
+------+---------------+

将给出最终的数据帧

+------+---------------+--------+
|userId|  visitedJobIds|jobsToDo|
+------+---------------+--------+
|     1|      [1, 2, 3]|  [4, 5]|
|     2|      [3, 4, 5]|  [1, 2]|
|     3|[1, 2, 3, 4, 5]|      []|
+------+---------------+--------+

您可以简单地 explode userProfileDataset 数组 列和 castIntegerTypejoinjobModelsDatasetJobId 已经是 IntegerType。然后最后使用collect_list内置函数得到最终结果

爆炸铸造如下

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
val temp = userProfileDataset.withColumn("visitedJobIds", explode(col("visitedJobIds")))
    .withColumn("visitedJobIds", col("visitedJobIds").cast(IntegerType))

加入收集如下

temp.join(jobModelsDataset, temp("visitedJobIds") === jobModelsDataset("JobId"), "left")
      .groupBy("userId")
      .agg(collect_list("visitedJobIds").as("visitedJobIds"), collect_list("JobId").as("ModelJobIds"))
    .show(false)

你应该得到你想要的东西

已更新

如果您正在寻找每个 userId 都没有关联的 JobIds,那么您可以按照以下步骤进行操作。

val list = jobModelsDataset.select(collect_list("JobId")).rdd.first()(0).asInstanceOf[collection.mutable.WrappedArray[Int]]
def notContained = udf((array: collection.mutable.WrappedArray[Int]) => list.filter(x => !(array.contains(x))))
temp.join(jobModelsDataset, temp("visitedJobIds") === jobModelsDataset("JobId"), "left")
      .groupBy("userId")
      .agg(collect_list("visitedJobIds").as("visitedJobIds"), collect_list("JobId").as("ModelJobIds"))
      .withColumn("ModelJobIds", notContained(col("ModelJobIds")))
    .show(false)

您可以通过 broadcasting 改进答案。

最初我有另一种方法,它使用 crossJoin 然后 filter:

val result = userProfileDataset
  .crossJoin(jobModelsDataset) // 27353040 rows
  .filter(row => !row(2).asInstanceOf[Seq[Int]].contains(row.getInt(3))) //27352633 rows

如果我使用@Shaido 的方法,那么explode,我应该能够获得与此方法相同的结果。然而,即使在我的情况下使用 filter 这种方法也非常昂贵(我已经比较了经过的时间)。 explain 方法也可以打印出 Physical Plan。

所以我不会使用 crossJoin 方法。我只想 post 并保留在这里。