聚合 Spark 数据帧内映射中的结构数组

Aggregation on an array of structs in a map inside a Spark dataframe

我为冗长的标题道歉,但我真的想不出更好的东西。

基本上,我的数据具有以下架构:

 |-- id: string (nullable = true)
 |-- mainkey: map (nullable = true)
 |    |-- key: string
 |    |-- value: array (valueContainsNull = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- price: double (nullable = true)
 |    |    |    |-- recordtype: string (nullable = true)

让我使用以下示例数据:

{"id":1, "mainkey":{"key1":[{"price":0.01,"recordtype":"BID"}],"key2":[{"price":4.3,"recordtype":"FIXED"}],"key3":[{"price":2.0,"recordtype":"BID"}]}}
{"id":2, "mainkey":{"key4":[{"price":2.50,"recordtype":"BID"}],"key5":[{"price":2.4,"recordtype":"BID"}],"key6":[{"price":0.19,"recordtype":"BID"}]}}

对于上面的两条记录,当记录类型为"BID"时,我想计算所有价格的平均值。因此,对于第一条记录("id":1),我们有 2 个这样的出价,价格分别为 0.01 和 2.0,因此四舍五入到小数点后两位的平均值为 1.01。对于第二条记录("id":2),有 3 个出价,价格分别为 2.5、2.4 和 0.19,平均值为 1.70。所以我想要以下输出:

+---+---------+
| id|meanvalue|
+---+---------+
|  1|     1.01|
|  2|      1.7|
+---+---------+

使用以下代码:

val exSchema = (new StructType().add("id", StringType).add("mainkey", MapType(StringType, new ArrayType(new StructType().add("price", DoubleType).add("recordtype", StringType), true))))
val exJsonDf = spark.read.schema(exSchema).json("file:///data/json_example")
var explodeExJson = exJsonDf.select($"id",explode($"mainkey")).explode($"value") {
    case Row(recordValue: Seq[Row] @unchecked ) => recordValue.map{ recordValue =>
    val price = recordValue(0).asInstanceOf[Double]
    val recordtype = recordValue(1).asInstanceOf[String]
    RecordValue(price, recordtype)
    }
    }.cache()

val filteredExJson = explodeExJson.filter($"recordtype"==="BID")

val aggExJson = filteredExJson.groupBy("id").agg(round(mean("price"),2).alias("meanvalue")) 

问题是它使用了 "expensive" 爆炸操作,当我处理大量数据时,尤其是当地图中可能有很多键时,它就成了一个问题。

如果您能想到更简单的解决方案(使用 UDF 或其他方式),请告诉我。还请记住,我是 Spark 的初学者,因此可能错过了一些对您来说很明显的东西。

任何帮助将不胜感激。提前致谢!

如果聚合仅限于单个 Row udf 将解决此问题:

import org.apache.spark.util.StatCounter
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.Row

val meanPrice =  udf((map: Map[String, Seq[Row]]) => {
  val prices = map.values
    .flatMap(x => x)
    .filter(_.getAs[String]("recordtype") == "BID")
    .map(_.getAs[Double]("price"))
  StatCounter(prices).mean
})

df.select($"id", meanPrice($"mainkey"))