数据框上的通用迭代器 (Spark/scala)

Generic iterator over dataframe (Spark/scala)

我需要按特定顺序遍历数据框并应用一些复杂的逻辑来计算新列。

在下面的示例中,我将使用简单的表达式,其中 s 的当前值是所有先前值的乘积,因此看起来这可以使用 UDF 甚至分析函数来完成。然而,实际上逻辑要复杂得多。

下面的代码完成了所需的工作

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.encoders.RowEncoder

val q = """
select 10 x, 1 y
union all select 10, 2
union all select 10, 3
union all select 20, 6
union all select 20, 4
union all select 20, 5
"""
val df = spark.sql(q)
def f_row(iter: Iterator[Row]) : Iterator[Row] = {
  iter.scanLeft(Row(0,0,1)) {
    case (r1, r2) => {
      val (x1, y1, s1) = r1 match {case Row(x: Int, y: Int, s: Int) => (x, y, s)}
      val (x2, y2)     = r2 match {case Row(x: Int, y: Int) => (x, y)}
      Row(x2, y2, s1 * y2)
    }
  }.drop(1)
}
val schema = new StructType().
             add(StructField("x", IntegerType, true)).
             add(StructField("y", IntegerType, true)).
             add(StructField("s", IntegerType, true))
val encoder = RowEncoder(schema)
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

输出

scala> df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
+---+---+---+
|  x|  y|  s|
+---+---+---+
| 20|  4|  4|
| 20|  5| 20|
| 20|  6|120|
| 10|  1|  1|
| 10|  2|  2|
| 10|  3|  6|
+---+---+---+

我不喜欢的是

1) 尽管 Spark 可以推断数据框的名称和类型,但我明确定义了模式

scala> df
res1: org.apache.spark.sql.DataFrame = [x: int, y: int]

2) 如果我向数据框添加任何新列,则我必须再次声明架构,更烦人的是 - 重新定义函数!

假设数据框中有新列 z。在这种情况下,我必须更改 f_row.

中的几乎每一行
def f_row(iter: Iterator[Row]) : Iterator[Row] = {
  iter.scanLeft(Row(0,0,"",1)) {
    case (r1, r2) => {
      val (x1, y1, z1, s1) = r1 match {case Row(x: Int, y: Int, z: String, s: Int) => (x, y, z, s)}
      val (x2, y2, z2)     = r2 match {case Row(x: Int, y: Int, z: String) => (x, y, z)}
      Row(x2, y2, z2, s1 * y2)
    }
  }.drop(1)
}
val schema = new StructType().
             add(StructField("x", IntegerType, true)).
             add(StructField("y", IntegerType, true)).
             add(StructField("z", StringType, true)).
             add(StructField("s", IntegerType, true))
val encoder = RowEncoder(schema)
df.withColumn("z", lit("dummy")).repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

输出

scala> df.withColumn("z", lit("dummy")).repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
+---+---+-----+---+
|  x|  y|    z|  s|
+---+---+-----+---+
| 20|  4|dummy|  4|
| 20|  5|dummy| 20|
| 20|  6|dummy|120|
| 10|  1|dummy|  1|
| 10|  2|dummy|  2|
| 10|  3|dummy|  6|
+---+---+-----+---+

有没有办法以更通用的方式实现逻辑,这样我就不需要创建函数来遍历每个特定的数据帧? 或者至少避免在将新列添加到计算逻辑中未使用的数据框后更改代码。

请参阅下面更新的问题。

更新

下面是两个以更通用的方式进行迭代的选项,但仍然存在一些缺点。

// option 1
def f_row(iter: Iterator[Row]): Iterator[Row] = {
  val r = Row.fromSeq(Row(0, 0).toSeq :+ 1)
  iter.scanLeft(r)((r1, r2) => 
    Row.fromSeq(r2.toSeq :+ r1.getInt(r1.size - 1) * r2.getInt(r2.fieldIndex("y")))
  ).drop(1)
}
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

// option 2
def f_row(iter: Iterator[Row]): Iterator[Row] = {
  iter.map{
    var s = 1
    r => {
      s = s * r.getInt(r.fieldIndex("y"))
      Row.fromSeq(r.toSeq :+ s)
    }
  }
}
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

如果将新列添加到数据框,则必须在选项 1 中更改 iter.scanLeft 的初始值。此外,我不太喜欢选项 2,因为它使用可变变量。

有没有办法改进代码,使其成为纯粹的功能,并且在将新列添加到数据框时不需要进行任何更改?

嗯,下面是充分的解决方案

def f_row(iter: Iterator[Row]): Iterator[Row] = {
  if (iter.hasNext) {
    val head = iter.next
    val r = Row.fromSeq(head.toSeq :+ head.getInt(head.fieldIndex("y")))
    iter.scanLeft(r)((r1, r2) => 
      Row.fromSeq(r2.toSeq :+ r1.getInt(r1.size - 1) * r2.getInt(r2.fieldIndex("y"))))
  } else iter
}
val encoder = 
  RowEncoder(StructType(df.schema.fields :+ StructField("s", IntegerType, false)))
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

更新

可以避免像 getInt 这样的函数,而使用更通用的函数 getAs

此外,为了能够按名称访问 r1 的行,我们可以生成 GenericRowWithSchema,它是 Row.

的子类

隐式参数已添加到 f_row 以便函数可以使用数据框的当前模式,同时它可以用作 mapPartitions.[=19 的参数=]

import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.encoders.RowEncoder

implicit val schema = StructType(df.schema.fields :+ StructField("result", IntegerType))
implicit val encoder = RowEncoder(schema)

def mul(x1: Int, x2: Int) = x1 * x2;

def f_row(iter: Iterator[Row])(implicit currentSchema : StructType) : Iterator[Row] = {
  if (iter.hasNext) {
    val head = iter.next
    val r =
      new GenericRowWithSchema((head.toSeq :+ (head.getAs("y"))).toArray, currentSchema)

    iter.scanLeft(r)((r1, r2) =>
      new GenericRowWithSchema((r2.toSeq :+ mul(r1.getAs("result"), r2.getAs("y"))).toArray, currentSchema))
  } else iter
}

df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row).show

终于可以用尾递归的方式实现逻辑了

import scala.annotation.tailrec

def f_row(iter: Iterator[Row]) = {
  @tailrec
  def f_row_(iter: Iterator[Row], tmp: Int, result: Iterator[Row]): Iterator[Row] = {
    if (iter.hasNext) {
      val r = iter.next
      f_row_(iter, mul(tmp, r.getAs("y")),
        result ++ Iterator(Row.fromSeq(r.toSeq :+ mul(tmp, r.getAs("y")))))
    } else result
  }
  f_row_(iter, 1, Iterator[Row]())
}

df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row).show