How to flatten columns of type array of structs (as returned by Spark ML API)?

Maybe it's just because I'm still new to the API, but I find the DataFrames that Spark ML methods typically return hard to work with.

This time it's the ALS model that's tripping me up, specifically the recommendForAllUsers method. Let's reconstruct the type of DataFrame it returns:

scala> import org.apache.spark.sql.types._

scala> val arrayType = ArrayType(new StructType().add("itemId", IntegerType).add("rating", FloatType))

scala> val recs = Seq((1, Array((1, .7), (2, .5))), (2, Array((0, .9), (4, .1)))).
  toDF("userId", "recommendations").
  select($"userId", $"recommendations".cast(arrayType))

scala> recs.show()
+------+------------------+
|userId|   recommendations|
+------+------------------+
|     1|[[1,0.7], [2,0.5]]|
|     2|[[0,0.9], [4,0.1]]|
+------+------------------+
scala> recs.printSchema
root
 |-- userId: integer (nullable = false)
 |-- recommendations: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- itemId: integer (nullable = true)
 |    |    |-- rating: float (nullable = true)

Now, I only care about the itemIds in the recommendations column. After all, the method is recommendForAllUsers, not recommendAndScoreForAllUsers (ok, ok, I'll stop griping...).

How can I do this?

I thought I had it when I created this UDF:

scala> val itemIds = udf((arr: Array[(Int, Float)]) => arr.map(_._1))

But that produces an error:

scala> recs.withColumn("items", itemIds($"recommendations"))
org.apache.spark.sql.AnalysisException: cannot resolve 'UDF(recommendations)' due to data type mismatch: argument 1 requires array<struct<_1:int,_2:float>> type, however, '`recommendations`' is of array<struct<itemId:int,rating:float>> type.;;
'Project [userId#87, recommendations#92, UDF(recommendations#92) AS items#238]
+- Project [userId#87, cast(recommendations#88 as array<struct<itemId:int,rating:float>>) AS recommendations#92]
   +- Project [_1#84 AS userId#87, _2#85 AS recommendations#88]
      +- LocalRelation [_1#84, _2#85]

Any ideas? Thanks!

With an array as the type of a column, e.g. recommendations, you'll be quite productive using the explode function (or the more advanced flatMap operator; a sketch of that follows the explode example below).

explode(e: Column): Column Creates a new row for each element in the given array or map column.

That gives you bare structs to work with.

import org.apache.spark.sql.types._
val structType = new StructType().
  add($"itemId".int).
  add($"rating".float)
val arrayType = ArrayType(structType)
val recs = Seq((1, Array((1, .7), (2, .5))), (2, Array((0, .9), (4, .1)))).
  toDF("userId", "recommendations").
  select($"userId", $"recommendations" cast arrayType)

val exploded = recs.withColumn("recs", explode($"recommendations"))
scala> exploded.show
+------+------------------+-------+
|userId|   recommendations|   recs|
+------+------------------+-------+
|     1|[[1,0.7], [2,0.5]]|[1,0.7]|
|     1|[[1,0.7], [2,0.5]]|[2,0.5]|
|     2|[[0,0.9], [4,0.1]]|[0,0.9]|
|     2|[[0,0.9], [4,0.1]]|[4,0.1]|
+------+------------------+-------+

Structs are nice to use in a select operator with * (star), which flattens them into one column per struct field.

You could do select($"element.*").

scala> exploded.select("userId", "recs.*").show
+------+------+------+
|userId|itemId|rating|
+------+------+------+
|     1|     1|   0.7|
|     1|     2|   0.5|
|     2|     0|   0.9|
|     2|     4|   0.1|
+------+------+------+

I think that does exactly what you're after.
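
For completeness, here's a minimal sketch of the flatMap alternative mentioned above, assuming the same recs Dataset and that spark.implicits._ is in scope (as it is in spark-shell); the field positions follow the schema shown earlier.

import org.apache.spark.sql.Row

// each row is (userId, recommendations); the array of structs arrives as Seq[Row]
val flattened = recs.flatMap { row =>
  val userId = row.getInt(0)
  row.getSeq[Row](1).map(rec => (userId, rec.getInt(0), rec.getFloat(1)))
}.toDF("userId", "itemId", "rating")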


p.s. Stay away from UDFs for as long as possible, since they "trigger" a conversion of rows from the internal format (InternalRow) to JVM objects, which can lead to excessive GC.
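
(If you really do need a UDF over an array-of-structs column, note that the struct elements arrive as Row rather than as Scala tuples, which is exactly what the AnalysisException in the question is complaining about. A minimal sketch under that assumption:)

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

// field 0 of each struct is itemId in this schema
val itemIds = udf((recs: Seq[Row]) => recs.map(_.getInt(0)))
recs.withColumn("items", itemIds($"recommendations"))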

Wow, a coworker of mine came up with an extremely elegant solution:

scala> recs.select($"userId", $"recommendations.itemId").show
+------+------+
|userId|itemId|
+------+------+
|     1|[1, 2]|
|     2|[0, 4]|
+------+------+
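
Applied to the real ALS output, that would look roughly like this (a sketch, assuming a fitted model and that the ALS estimator was configured with userCol "userId" and itemCol "itemId", matching the schema in the question):

// model: a fitted org.apache.spark.ml.recommendation.ALSModel
val topItemsPerUser = model.recommendForAllUsers(10).
  select($"userId", $"recommendations.itemId")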

So maybe the Spark ML API isn't that hard after all :)