spark scala - 按数组列分组

spark scala - Group by Array column

我是 spark scala 的新手。感谢你的帮助.. 我有一个数据框

val df = Seq(
  ("a", "a1", Array("x1","x2")), 
  ("a", "b1", Array("x1")),
  ("a", "c1", Array("x2")),
  ("c", "c3", Array("x2")),
  ("a", "d1", Array("x3")),
  ("a", "e1", Array("x2","x1"))
).toDF("k1", "k2", "k3")

我正在寻找一种方法来将它按 k1 和 k3 分组并将 k2 收集到一个数组中。 但是,k3 是一个数组,我需要应用包含(而不是精确 匹配)进行分组。换句话说,我正在寻找结果 像这样

k1   k3       k2                count
a   (x1,x2)   (a1,b1,c1,e1)     4
a    (x3)      (d1)             1
c    (x2)      (c3)             1

有人可以建议如何实现吗?

提前致谢!

我建议您按 k1 列分组收集 k2 和 k3 的结构列表pass收集到的列表到 udf 函数,用于计算 k3 中的数组何时包含在另一个 k3 数组中并添加 k2 的元素。

然后您可以使用 explodeselect 表达式来获得所需的输出

以下是完整的工作解决方案

val df = Seq(
  ("a", "a1", Array("x1","x2")),
  ("a", "b1", Array("x1")),
  ("a", "c1", Array("x2")),
  ("c", "c3", Array("x2")),
  ("a", "d1", Array("x3")),
  ("a", "e1", Array("x2","x1"))
  ).toDF("k1", "k2", "k3")

import org.apache.spark.sql.functions._
def containsGoupingUdf = udf((arr: Seq[Row]) => {
  val firstStruct =  arr.head
  val tailStructs =  arr.tail
  var result = Array((collection.mutable.Set(firstStruct.getAs[String]("k2")), firstStruct.getAs[scala.collection.mutable.WrappedArray[String]]("k3").toSet, 1))
  for(str <- tailStructs){
    var added = false
    for((res, index) <- result.zipWithIndex) {
      if (str.getAs[scala.collection.mutable.WrappedArray[String]]("k3").exists(res._2) || res._2.exists(x => str.getAs[scala.collection.mutable.WrappedArray[String]]("k3").contains(x))) {
        result(index) = (res._1 + str.getAs[String]("k2"), res._2 ++ str.getAs[scala.collection.mutable.WrappedArray[String]]("k3").toSet, res._3 + 1)
        added = true
      }
    }
    if(!added){
      result = result ++ Array((collection.mutable.Set(str.getAs[String]("k2")), str.getAs[scala.collection.mutable.WrappedArray[String]]("k3").toSet, 1))
    }
  }
  result.map(tuple => (tuple._1.toArray, tuple._2.toArray, tuple._3))
})

df.groupBy("k1").agg(containsGoupingUdf(collect_list(struct(col("k2"), col("k3")))).as("aggregated"))
    .select(col("k1"), explode(col("aggregated")).as("aggregated"))
    .select(col("k1"), col("aggregated._2").as("k3"), col("aggregated._1").as("k2"), col("aggregated._3").as("count"))
  .show(false)

哪个应该给你

+---+--------+----------------+-----+
|k1 |k3      |k2              |count|
+---+--------+----------------+-----+
|c  |[x2]    |[c3]            |1    |
|a  |[x1, x2]|[b1, e1, c1, a1]|4    |
|a  |[x3]    |[d1]            |1    |
+---+--------+----------------+-----+ 

希望回答对您有所帮助,您可以根据自己的需要进行修改