从 Scala 中的 ArrayType 列中提取值并重塑为 long

Extract value from ArrayType column in Scala and reshape to long

我有一个 DataFrame,它由 ArrayType 的 Column 组成,并且该数组在每一行数据中的长度可能不同。我在下面提供了一些示例代码,可以创建一些具有类似结构的模拟数据。

您会看到,对于一笔交易,我有一个交易 ID 以及一些额外的数据,每个数据都存储在一个“段”中。在这里,我们看到一个存储客户信息的段(总是一个长度为 2 的数组),并且我们有一个额外的段用于每个项目购买。购买商品本身的信息是一个长度不一的数组;数组的前两个元素始终是购买商品的 ID 和名称;颜色等可能存在其他数组元素,但在此用例中我们可以忽略它们。

val dfschema = new StructType()
  .add("transaction",
    new StructType()
      .add(
        "transaction_id",
        StringType
      )
      .add(
        "segments",
        ArrayType(
          new StructType()
            .add("segment_id",StringType)
            .add("segment_fields",ArrayType(
              StringType,
              false
            )
          ), false
        )
      )
    )


val mockdata = Seq(
  Row(
    Row(
      "2e6d57769e49ae8cb0c4105548c4389d",
      List(
        Row(
          "CustomerInformation",
          List(
            "SomeCustomerName",
            "SomeCustomerEmail"
          )
        ),
        Row(
          "ItemPurchased",
          List(
            "SomeItemID",
            "SomeItemName"
          )
        ),
        Row(
          "ItemPurchased",
          List(
            "AnotherItemID",
            "AnotherItemName",
            "ItemColor"
          )
        ),
        Row(
          "ItemPurchased",
          List(
            "YetAnotherItemID",
            "YetAnotherItemName",
            "ItemColor"
          )
        )
      )
    )
  )
)

val df = spark.createDataFrame(
  spark.sparkContext.parallelize(mockdata),
  dfschema)

我想要完成的是将上面的内容转换为另一个包含两列的数据框,一列用于客户名称,一列用于项目名称。对于上面的例子,它会是:

customer.name item.name
SomeCustomerName SomeItemName
SomeCustomerName AnotherItemName
SomeCustomerName YetAnotherItemName

但是,我不想对要检索的数据字段进行硬编码;相反,我想写几个函数,你可以 运行 作为 select 命令的一部分,像这样:

df(
  select(
    get_single_subsegment("CustomerInformation", 0),
    get_repeated_subsements("ItemPurchased", 1)
  )
)

这样,如果我选择检索客户邮箱而不是姓名,我只需要将上面的更改0修改为1即可。我什至可以将索引号作为变量传递。

这可以做到吗?

从 Spark 3.0 开始,您可以使用 spark 的 built-in 函数来定义您的两个函数 get_single_subsegmentget_repeated_subsegments

对于get_single_subsegment,你可以先用segment_id和filter过滤你的段数组,然后用getItem得到这个过滤数组的第一个元素,然后使用 getFieldgetItem:

检索此段对象中所需索引处的元素
import org.apache.spark.sql.functions.{col, filter}
import org.apache.spark.sql.Column

def get_single_subsegment(segmentId: String, index: Int): Column = {
  filter(col("transaction.segments"), c => c.getField("segment_id") === segmentId)
    .getItem(0)
    .getField("segment_fields")
    .getItem(index)
}

对于 get_repeated_subsegments,您首先像 get_single_subsegment 中那样进行过滤,然后使用 transform to extract right segment fields index for each elements of filtered array, and then explode 这个数组以便逐行过滤数组的元素:

import org.apache.spark.sql.functions.{col, explode, filter, transform}
import org.apache.spark.sql.Column

def get_repeated_subsegments(segmentId: String, index: Int): Column = {
  explode(
    transform(
      filter(col("transaction.segments"), c => c.getField("segment_id") === segmentId)
        .getField("segment_fields"),
      c => c.getItem(index)
    )
  )
}

如果我们在您的示例中应用上面定义的两个函数,我们会得到以下结果:

df.select(
  get_single_subsegment("CustomerInformation", 0).as("customer_name"),
  get_repeated_subsegments("ItemPurchased", 1).as("item_name")
).show(false)

// +----------------+------------------+
// |customer_name   |item_name         |
// +----------------+------------------+
// |SomeCustomerName|SomeItemName      |
// |SomeCustomerName|AnotherItemName   |
// |SomeCustomerName|YetAnotherItemName|
// +----------------+------------------+

奖金 - 提取多列

如果像您评论的那样,您想使用 get_repeated_subsegments 提取多个列,则需要修改 get_repeated_subsegments 以便不在其中执行 explode 但是当您执行select。然后,您可以在通过应用 get_repeated_subsegments 获取的数组上使用 arrays_zip 添加多个列,如下所示:

import org.apache.spark.sql.functions.{arrays_zip, col, explode, filter, transform}
import org.apache.spark.sql.Column

def get_repeated_subsegments(segmentId: String, index: Int): Column = {
  transform(
    filter(col("transaction.segments"), c => c.getField("segment_id") === segmentId)
      .getField("segment_fields"),
    c => c.getItem(index)
  )
}

df.select(
    get_single_subsegment("CustomerInformation", 0).as("customer_name"),
    explode(
      arrays_zip(
        get_repeated_subsegments("ItemPurchased", 0).as("item_id"),
        get_repeated_subsegments("ItemPurchased", 1).as("item_name")
      )
    ).alias("items")
  )
  .select("customer_name", "items.*")
  .show(false)

// +----------------+----------------+------------------+
// |customer_name   |item_id         |item_name         |
// +----------------+----------------+------------------+
// |SomeCustomerName|SomeItemID      |SomeItemName      |
// |SomeCustomerName|AnotherItemID   |AnotherItemName   |
// |SomeCustomerName|YetAnotherItemID|YetAnotherItemName|
// +----------------+----------------+------------------+