在 Spark Scala 中迭代数据框列 Array of Array

Iterate dataframe column Array of Array in Spark Scala

我正在尝试迭代数组数组作为 Spark 数据帧中的列。正在寻找执行此操作的最佳方法。

架构:

root
 |-- Animal: struct (nullable = true)
 |    |-- Species: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- mammal: array (nullable = true)
 |    |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |    |-- description: string (nullable = true)

目前我正在使用这个逻辑。这只获取第一个数组。

df.select(
   col("Animal.Species").getItem(0).getItem("mammal").getItem("description")
)

伪逻辑:

col("Animal.Species").getItem(0).getItem("mammal").getItem("description")
+
col("Animal.Species").getItem(1).getItem("mammal").getItem("description")
+
col("Animal.Species").getItem(2).getItem("mammal").getItem("description")
+
col("Animal.Species").getItem(...).getItem("mammal").getItem("description")

期望的示例输出(扁平元素作为字符串)

llama, sheep, rabbit, hare

您可以申请 explode 两次:第一次在 Animal.Species 上,第二次在第一次的结果上:

import org.apache.spark.sql.functions._
df.withColumn("tmp", explode(col("Animal.Species")))
  .withColumn("tmp", explode(col("tmp.mammal")))
  .select("tmp.description")
  .show()

不明显,但您可以使用 .(或 ColumngetField 方法)select“通过”结构数组。选择 Animal.Species.mammal returns 最内层结构的数组。不幸的是,这个数组数组阻止您使用 Animal.Species.mammal.description 之类的东西进一步向下钻取,因此您需要先将其展平,然后再使用 getField().

如果我正确理解您的架构,以下 JSON 应该是有效的输入:

{
  "Animal": {
    "Species": [
      {
        "mammal": [
          { "description": "llama" },
          { "description": "sheep" }
        ]
      },
      {
        "mammal": [
          { "description": "rabbit" },
          { "description": "hare" }
        ]
      }
    ]
  }
}
val df = spark.read.json("data.json")
df.printSchema
// root
//  |-- Animal: struct (nullable = true)
//  |    |-- Species: array (nullable = true)
//  |    |    |-- element: struct (containsNull = true)
//  |    |    |    |-- mammal: array (nullable = true)
//  |    |    |    |    |-- element: struct (containsNull = true)
//  |    |    |    |    |    |-- description: string (nullable = true)

df.select("Animal.Species.mammal").show(false)
// +----------------------------------------+
// |mammal                                  |
// +----------------------------------------+
// |[[{llama}, {sheep}], [{rabbit}, {hare}]]|
// +----------------------------------------+

df.select(flatten(col("Animal.Species.mammal"))).show(false)
// +------------------------------------+
// |flatten(Animal.Species.mammal)      |
// +------------------------------------+
// |[{llama}, {sheep}, {rabbit}, {hare}]|
// +------------------------------------+

现在这是一个结构数组,您可以使用 getField("description") 获取感兴趣的数组:

df.select(flatten(col("Animal.Species.mammal")).getField("description")).show(false)
// +--------------------------------------------------------+
// |flatten(Animal.Species.mammal AS mammal#173).description|
// +--------------------------------------------------------+
// |[llama, sheep, rabbit, hare]                            |
// +--------------------------------------------------------+

最后,array_join加上分隔符", "就可以得到想要的字符串:

df.select(
  array_join(
    flatten(col("Animal.Species.mammal")).getField("description"),
    ", "
  ) as "animals"
).show(false)
// +--------------------------+
// |animals                   |
// +--------------------------+
// |llama, sheep, rabbit, hare|
// +--------------------------+