聚合 Spark 数据帧内映射中的结构数组
Aggregation on an array of structs in a map inside a Spark dataframe
我为冗长的标题道歉,但我真的想不出更好的东西。
基本上,我的数据具有以下架构:
|-- id: string (nullable = true)
|-- mainkey: map (nullable = true)
| |-- key: string
| |-- value: array (valueContainsNull = true)
| | |-- element: struct (containsNull = true)
| | | |-- price: double (nullable = true)
| | | |-- recordtype: string (nullable = true)
让我使用以下示例数据:
{"id":1, "mainkey":{"key1":[{"price":0.01,"recordtype":"BID"}],"key2":[{"price":4.3,"recordtype":"FIXED"}],"key3":[{"price":2.0,"recordtype":"BID"}]}}
{"id":2, "mainkey":{"key4":[{"price":2.50,"recordtype":"BID"}],"key5":[{"price":2.4,"recordtype":"BID"}],"key6":[{"price":0.19,"recordtype":"BID"}]}}
对于上面的两条记录,当记录类型为"BID"时,我想计算所有价格的平均值。因此,对于第一条记录("id":1),我们有 2 个这样的出价,价格分别为 0.01 和 2.0,因此四舍五入到小数点后两位的平均值为 1.01。对于第二条记录("id":2),有 3 个出价,价格分别为 2.5、2.4 和 0.19,平均值为 1.70。所以我想要以下输出:
+---+---------+
| id|meanvalue|
+---+---------+
| 1| 1.01|
| 2| 1.7|
+---+---------+
使用以下代码:
val exSchema = (new StructType().add("id", StringType).add("mainkey", MapType(StringType, new ArrayType(new StructType().add("price", DoubleType).add("recordtype", StringType), true))))
val exJsonDf = spark.read.schema(exSchema).json("file:///data/json_example")
var explodeExJson = exJsonDf.select($"id",explode($"mainkey")).explode($"value") {
case Row(recordValue: Seq[Row] @unchecked ) => recordValue.map{ recordValue =>
val price = recordValue(0).asInstanceOf[Double]
val recordtype = recordValue(1).asInstanceOf[String]
RecordValue(price, recordtype)
}
}.cache()
val filteredExJson = explodeExJson.filter($"recordtype"==="BID")
val aggExJson = filteredExJson.groupBy("id").agg(round(mean("price"),2).alias("meanvalue"))
问题是它使用了 "expensive" 爆炸操作,当我处理大量数据时,尤其是当地图中可能有很多键时,它就成了一个问题。
如果您能想到更简单的解决方案(使用 UDF 或其他方式),请告诉我。还请记住,我是 Spark 的初学者,因此可能错过了一些对您来说很明显的东西。
任何帮助将不胜感激。提前致谢!
如果聚合仅限于单个 Row
udf
将解决此问题:
import org.apache.spark.util.StatCounter
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.Row
val meanPrice = udf((map: Map[String, Seq[Row]]) => {
val prices = map.values
.flatMap(x => x)
.filter(_.getAs[String]("recordtype") == "BID")
.map(_.getAs[Double]("price"))
StatCounter(prices).mean
})
df.select($"id", meanPrice($"mainkey"))
我为冗长的标题道歉,但我真的想不出更好的东西。
基本上,我的数据具有以下架构:
|-- id: string (nullable = true)
|-- mainkey: map (nullable = true)
| |-- key: string
| |-- value: array (valueContainsNull = true)
| | |-- element: struct (containsNull = true)
| | | |-- price: double (nullable = true)
| | | |-- recordtype: string (nullable = true)
让我使用以下示例数据:
{"id":1, "mainkey":{"key1":[{"price":0.01,"recordtype":"BID"}],"key2":[{"price":4.3,"recordtype":"FIXED"}],"key3":[{"price":2.0,"recordtype":"BID"}]}}
{"id":2, "mainkey":{"key4":[{"price":2.50,"recordtype":"BID"}],"key5":[{"price":2.4,"recordtype":"BID"}],"key6":[{"price":0.19,"recordtype":"BID"}]}}
对于上面的两条记录,当记录类型为"BID"时,我想计算所有价格的平均值。因此,对于第一条记录("id":1),我们有 2 个这样的出价,价格分别为 0.01 和 2.0,因此四舍五入到小数点后两位的平均值为 1.01。对于第二条记录("id":2),有 3 个出价,价格分别为 2.5、2.4 和 0.19,平均值为 1.70。所以我想要以下输出:
+---+---------+
| id|meanvalue|
+---+---------+
| 1| 1.01|
| 2| 1.7|
+---+---------+
使用以下代码:
val exSchema = (new StructType().add("id", StringType).add("mainkey", MapType(StringType, new ArrayType(new StructType().add("price", DoubleType).add("recordtype", StringType), true))))
val exJsonDf = spark.read.schema(exSchema).json("file:///data/json_example")
var explodeExJson = exJsonDf.select($"id",explode($"mainkey")).explode($"value") {
case Row(recordValue: Seq[Row] @unchecked ) => recordValue.map{ recordValue =>
val price = recordValue(0).asInstanceOf[Double]
val recordtype = recordValue(1).asInstanceOf[String]
RecordValue(price, recordtype)
}
}.cache()
val filteredExJson = explodeExJson.filter($"recordtype"==="BID")
val aggExJson = filteredExJson.groupBy("id").agg(round(mean("price"),2).alias("meanvalue"))
问题是它使用了 "expensive" 爆炸操作,当我处理大量数据时,尤其是当地图中可能有很多键时,它就成了一个问题。
如果您能想到更简单的解决方案(使用 UDF 或其他方式),请告诉我。还请记住,我是 Spark 的初学者,因此可能错过了一些对您来说很明显的东西。
任何帮助将不胜感激。提前致谢!
如果聚合仅限于单个 Row
udf
将解决此问题:
import org.apache.spark.util.StatCounter
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.Row
val meanPrice = udf((map: Map[String, Seq[Row]]) => {
val prices = map.values
.flatMap(x => x)
.filter(_.getAs[String]("recordtype") == "BID")
.map(_.getAs[Double]("price"))
StatCounter(prices).mean
})
df.select($"id", meanPrice($"mainkey"))