How to filter a struct array in a spark dataframe?

I have the following code and output.

import org.apache.spark.sql.functions.{collect_list, struct}
import sqlContext.implicits._

val df = Seq(
  ("john", "tomato", 1.99),
  ("john", "carrot", 0.45),
  ("bill", "apple", 0.99),
  ("john", "banana", 1.29),
  ("bill", "taco", 2.59)
).toDF("name", "food", "price")

val grouped = df.groupBy($"name")
  .agg(collect_list(struct($"food", $"price")).as("foods"))

grouped.show(false)
grouped.printSchema

Output and schema:

+----+---------------------------------------------+
|name|foods                                        |
+----+---------------------------------------------+
|john|[[tomato,1.99], [carrot,0.45], [banana,1.29]]|
|bill|[[apple,0.99], [taco,2.59]]                  |
+----+---------------------------------------------+

root
 |-- name: string (nullable = true)
 |-- foods: array (nullable = false)
 |    |-- element: struct (containsNull = false)
 |    |    |-- food: string (nullable = true)
 |    |    |-- price: double (nullable = false)

I want to filter on df("foods.price") > 1.00. How can I filter to get the output below?

+----+---------------------------------------------+
|name|foods                                        |
+----+---------------------------------------------+
|john|[[banana,1.29], [tomato,1.99]]               |
|bill|[[taco,2.59]]                                |
+----+---------------------------------------------+

I've tried df.filter($"foods.food" > 1.00), but that doesn't work because I get an error. Is there anything else I can try?

You are applying the filter to an array column, which is why it throws an error: that syntax compares the whole array, not its elements. Instead, apply the filter on price before aggregating, then do the transformation as needed.

val cf = df.filter("price > 1.00")
  .groupBy($"name")
  .agg(collect_list(struct($"food", $"price")).as("foods"))

cf.show(false)
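If the array has already been built and you cannot re-aggregate, Spark 2.4+ also offers the higher-order `filter` function on arrays, usable via `expr`. A minimal sketch under that version assumption; the names `agged` and `pruned` are chosen here for illustration:

```scala
import org.apache.spark.sql.functions.{collect_list, expr, struct}

// Rebuild the aggregated frame from the question...
val agged = df.groupBy($"name")
  .agg(collect_list(struct($"food", $"price")).as("foods"))

// ...then prune the struct array in place, keeping elements with price > 1.00.
val pruned = agged.withColumn("foods", expr("filter(foods, x -> x.price > 1.00)"))

pruned.show(false)
```

This keeps the grouping logic untouched and filters each row's array independently, which is handy when the aggregation is expensive or happens upstream.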