Spark:如何在 A 的 ID 数组列不包含 B 的 ID 列的情况下连接两个“数据集”的 A 和 B?
Spark: How to join two `Dataset`s A and B with the condition that an ID array column of A does NOT contain the ID column of B?
我的问题不是 [ 的重复问题。我的问题是关于 "not in",而不是 "is in"。这是不同的!
我有两个 Dataset
:
userProfileDataset
: Dataset[UserProfile]
jobModelsDataset
: Dataset[JobModel]
案例分类 UserProfile
定义为
case class UserProfile(userId: Int, visitedJobIds: Array[Int])
案例 class JobModel
定义为
case class JobModel(JobId: Int, Model: Map[String, Double])
我还制作了两个对象(UserProfileFieldNames
和 JobModelFieldNames
),其中包含这两种情况的字段名称 classes.
我的 objective 是,对于 userProfileDataset
中的每个用户,找到 UserProfile.visitedJobIds
[=75= 中不包含的 JobModel.JobId
].
怎么做?
我考虑过先使用 crossJoin
,然后再使用 filter
。它可能会起作用。有没有更直接或更有效的方法?
我尝试了以下方法,但 none 有效:
val result = userProfileDataset.joinWith(jobModelsDataset,
!userProfileDataset.col(UserProfileFieldNames.visitedJobIds).contains(jobModelsDataset.col(JobModelFieldNames.jobId)),
"left_outer"
)
它导致:
Exception in thread "main" org.apache.spark.sql.AnalysisException:
cannot resolve 'contains(_1
.visitedJobIds
, CAST(_2
.JobId
AS
STRING))' due to data type mismatch: argument 1 requires string type,
however, '_1
.visitedJobIds
' is of array type.;;
难道是因为contains
方法只能用来检测一个字符串是否包含另一个字符串?
以下条件也不成立:
!jobModelsDataset.col(JobModelFieldNames.jobId)
.isin(userProfileDataset.col(UserProfileFieldNames.visitedJobIds))
它导致:
Exception in thread "main" org.apache.spark.sql.AnalysisException:
cannot resolve '(_2
.JobId
IN (_1
.visitedJobIds
))' due to data
type mismatch: Arguments must be same type but were: IntegerType !=
ArrayType(IntegerType,false);; 'Join LeftOuter, NOT _2#74.JobId IN
(_1#73.visitedJobIds)
如果unique job id的个数不是太多,可以按如下方式收集广播
val jobIds = jobModelsDataset.map(_.JobId).distinct.collect().toSeq
val broadcastedJobIds = spark.sparkContext.broadcast(jobIds)
要将此广播序列与 visitedJobIds
列进行比较,您可以创建一个 UDF
val notVisited = udf((visitedJobs: Seq[Int]) => {
broadcastedJobIds.value.filterNot(visitedJobs.toSet)
})
val df = userProfileDataset.withColumn("jobsToDo", notVisited($"visitedJobIds"))
使用 jobIds = 1,2,3,4,5
和示例数据框
进行测试
+------+---------------+
|userId| visitedJobIds|
+------+---------------+
| 1| [1, 2, 3]|
| 2| [3, 4, 5]|
| 3|[1, 2, 3, 4, 5]|
+------+---------------+
将给出最终的数据帧
+------+---------------+--------+
|userId| visitedJobIds|jobsToDo|
+------+---------------+--------+
| 1| [1, 2, 3]| [4, 5]|
| 2| [3, 4, 5]| [1, 2]|
| 3|[1, 2, 3, 4, 5]| []|
+------+---------------+--------+
您可以简单地 explode
userProfileDataset
的 数组 列和 cast
到 IntegerType
到 join
与 jobModelsDataset
的 JobId
列 已经是 IntegerType。然后最后使用collect_list
内置函数得到最终结果
爆炸和铸造如下
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
val temp = userProfileDataset.withColumn("visitedJobIds", explode(col("visitedJobIds")))
.withColumn("visitedJobIds", col("visitedJobIds").cast(IntegerType))
加入和收集如下
temp.join(jobModelsDataset, temp("visitedJobIds") === jobModelsDataset("JobId"), "left")
.groupBy("userId")
.agg(collect_list("visitedJobIds").as("visitedJobIds"), collect_list("JobId").as("ModelJobIds"))
.show(false)
你应该得到你想要的东西
已更新
如果您正在寻找每个 userId
都没有关联的 JobIds
,那么您可以按照以下步骤进行操作。
val list = jobModelsDataset.select(collect_list("JobId")).rdd.first()(0).asInstanceOf[collection.mutable.WrappedArray[Int]]
def notContained = udf((array: collection.mutable.WrappedArray[Int]) => list.filter(x => !(array.contains(x))))
temp.join(jobModelsDataset, temp("visitedJobIds") === jobModelsDataset("JobId"), "left")
.groupBy("userId")
.agg(collect_list("visitedJobIds").as("visitedJobIds"), collect_list("JobId").as("ModelJobIds"))
.withColumn("ModelJobIds", notContained(col("ModelJobIds")))
.show(false)
您可以通过 broadcasting
改进答案。
最初我有另一种方法,它使用 crossJoin
然后 filter
:
val result = userProfileDataset
.crossJoin(jobModelsDataset) // 27353040 rows
.filter(row => !row(2).asInstanceOf[Seq[Int]].contains(row.getInt(3))) //27352633 rows
如果我使用@Shaido 的方法,那么explode
,我应该能够获得与此方法相同的结果。然而,即使在我的情况下使用 filter
这种方法也非常昂贵(我已经比较了经过的时间)。 explain
方法也可以打印出 Physical Plan。
所以我不会使用 crossJoin 方法。我只想 post 并保留在这里。
我的问题不是 [ 的重复问题。我的问题是关于 "not in",而不是 "is in"。这是不同的!
我有两个 Dataset
:
userProfileDataset
:Dataset[UserProfile]
jobModelsDataset
:Dataset[JobModel]
案例分类 UserProfile
定义为
case class UserProfile(userId: Int, visitedJobIds: Array[Int])
案例 class JobModel
定义为
case class JobModel(JobId: Int, Model: Map[String, Double])
我还制作了两个对象(UserProfileFieldNames
和 JobModelFieldNames
),其中包含这两种情况的字段名称 classes.
我的 objective 是,对于 userProfileDataset
中的每个用户,找到 UserProfile.visitedJobIds
[=75= 中不包含的 JobModel.JobId
].
怎么做?
我考虑过先使用 crossJoin
,然后再使用 filter
。它可能会起作用。有没有更直接或更有效的方法?
我尝试了以下方法,但 none 有效:
val result = userProfileDataset.joinWith(jobModelsDataset,
!userProfileDataset.col(UserProfileFieldNames.visitedJobIds).contains(jobModelsDataset.col(JobModelFieldNames.jobId)),
"left_outer"
)
它导致:
Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve 'contains(
_1
.visitedJobIds
, CAST(_2
.JobId
AS STRING))' due to data type mismatch: argument 1 requires string type, however, '_1
.visitedJobIds
' is of array type.;;
难道是因为contains
方法只能用来检测一个字符串是否包含另一个字符串?
以下条件也不成立:
!jobModelsDataset.col(JobModelFieldNames.jobId)
.isin(userProfileDataset.col(UserProfileFieldNames.visitedJobIds))
它导致:
Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve '(
_2
.JobId
IN (_1
.visitedJobIds
))' due to data type mismatch: Arguments must be same type but were: IntegerType != ArrayType(IntegerType,false);; 'Join LeftOuter, NOT _2#74.JobId IN (_1#73.visitedJobIds)
如果unique job id的个数不是太多,可以按如下方式收集广播
val jobIds = jobModelsDataset.map(_.JobId).distinct.collect().toSeq
val broadcastedJobIds = spark.sparkContext.broadcast(jobIds)
要将此广播序列与 visitedJobIds
列进行比较,您可以创建一个 UDF
val notVisited = udf((visitedJobs: Seq[Int]) => {
broadcastedJobIds.value.filterNot(visitedJobs.toSet)
})
val df = userProfileDataset.withColumn("jobsToDo", notVisited($"visitedJobIds"))
使用 jobIds = 1,2,3,4,5
和示例数据框
+------+---------------+
|userId| visitedJobIds|
+------+---------------+
| 1| [1, 2, 3]|
| 2| [3, 4, 5]|
| 3|[1, 2, 3, 4, 5]|
+------+---------------+
将给出最终的数据帧
+------+---------------+--------+
|userId| visitedJobIds|jobsToDo|
+------+---------------+--------+
| 1| [1, 2, 3]| [4, 5]|
| 2| [3, 4, 5]| [1, 2]|
| 3|[1, 2, 3, 4, 5]| []|
+------+---------------+--------+
您可以简单地 explode
userProfileDataset
的 数组 列和 cast
到 IntegerType
到 join
与 jobModelsDataset
的 JobId
列 已经是 IntegerType。然后最后使用collect_list
内置函数得到最终结果
爆炸和铸造如下
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
val temp = userProfileDataset.withColumn("visitedJobIds", explode(col("visitedJobIds")))
.withColumn("visitedJobIds", col("visitedJobIds").cast(IntegerType))
加入和收集如下
temp.join(jobModelsDataset, temp("visitedJobIds") === jobModelsDataset("JobId"), "left")
.groupBy("userId")
.agg(collect_list("visitedJobIds").as("visitedJobIds"), collect_list("JobId").as("ModelJobIds"))
.show(false)
你应该得到你想要的东西
已更新
如果您正在寻找每个 userId
都没有关联的 JobIds
,那么您可以按照以下步骤进行操作。
val list = jobModelsDataset.select(collect_list("JobId")).rdd.first()(0).asInstanceOf[collection.mutable.WrappedArray[Int]]
def notContained = udf((array: collection.mutable.WrappedArray[Int]) => list.filter(x => !(array.contains(x))))
temp.join(jobModelsDataset, temp("visitedJobIds") === jobModelsDataset("JobId"), "left")
.groupBy("userId")
.agg(collect_list("visitedJobIds").as("visitedJobIds"), collect_list("JobId").as("ModelJobIds"))
.withColumn("ModelJobIds", notContained(col("ModelJobIds")))
.show(false)
您可以通过 broadcasting
改进答案。
最初我有另一种方法,它使用 crossJoin
然后 filter
:
val result = userProfileDataset
.crossJoin(jobModelsDataset) // 27353040 rows
.filter(row => !row(2).asInstanceOf[Seq[Int]].contains(row.getInt(3))) //27352633 rows
如果我使用@Shaido 的方法,那么explode
,我应该能够获得与此方法相同的结果。然而,即使在我的情况下使用 filter
这种方法也非常昂贵(我已经比较了经过的时间)。 explain
方法也可以打印出 Physical Plan。
所以我不会使用 crossJoin 方法。我只想 post 并保留在这里。