Scala foldLeft 当某些条件为真时

Scala foldLeft while some conditions are true

如何在 Scala 中模拟以下行为?即在满足累加器上的某些特定条件时继续折叠。

def foldLeftWhile[B](z: B, p: B => Boolean)(op: (B, A) => B): B

例如

scala> val seq = Seq(1, 2, 3, 4)
seq: Seq[Int] = List(1, 2, 3, 4)
scala> seq.foldLeftWhile(0, _ < 3) { (acc, e) => acc + e }
res0: Int = 1
scala> seq.foldLeftWhile(0, _ < 7) { (acc, e) => acc + e }
res1: Int = 6

更新:

根据@Dima 的回答,我意识到我的意图有点副作用。所以我让它与 takeWhile 同步,即如果谓词不匹配,则不会前进。并添加更多示例以使其更清楚。 (注意:这不适用于 Iterators)

简单地在累加器上使用分支条件:

seq.foldLeft(0, _ < 3) { (acc, e) => if (acc < 3) acc + e else acc}

但是,您将 运行 序列的每个条目。

首先,请注意您的示例似乎有误。如果我正确理解您的描述,结果应该是 1(满足谓词 _ < 3 的最后一个值),而不是 6

执行此操作的最简单方法是使用 return 语句,这在 Scala 中非常不受欢迎,但我想,为了完整起见,我会提到它。

def foldLeftWhile[A, B](seq: Seq[A], z: B, p: B => Boolean)(op: (B, A) => B): B = foldLeft(z) { case (b, a) => 
   val result = op(b, a) 
   if(!p(result)) return b
   result
}

因为我们想避免使用 return,scanLeft 可能是一种可能性:

seq.toStream.scanLeft(z)(op).takeWhile(p).last

这有点浪费,因为它累积了所有(匹配的)结果。 您可以使用 iterator 而不是 toStream 来避免这种情况,但是 Iterator 出于某种原因没有 .last,因此,您必须额外扫描它明确地:

 seq.iterator.scanLeft(z)(op).takeWhile(p).foldLeft(z) { case (_, b) => b }

在 Scala 中定义您想要的内容非常简单。您可以定义一个隐式 class ,它将您的函数添加到任何 TraversableOnce (包括 Seq)。

implicit class FoldLeftWhile[A](trav: TraversableOnce[A]) {
  def foldLeftWhile[B](init: B)(where: B => Boolean)(op: (B, A) => B): B = {
    trav.foldLeft(init)((acc, next) => if (where(acc)) op(acc, next) else acc)
  }
}
Seq(1,2,3,4).foldLeftWhile(0)(_ < 3)((acc, e) => acc + e)

更新,问题已修改:

implicit class FoldLeftWhile[A](trav: TraversableOnce[A]) {
  def foldLeftWhile[B](init: B)(where: B => Boolean)(op: (B, A) => B): B = {
    trav.foldLeft((init, false))((a,b) => if (a._2) a else {
      val r = op(a._1, b)
      if (where(r)) (op(a._1, b), false) else (a._1, true)
    })._1
  }
}

请注意,我将您的 (z: B, p: B => Boolean) 拆分为两个高阶函数。这只是个人对 Scala 风格的偏好。

这个怎么样:

def foldLeftWhile[A, B](z: B, xs: Seq[A], p: B => Boolean)(op: (B, A) => B): B = {
  def go(acc: B, l: Seq[A]): B = l match {
    case h +: t => 
        val nacc = op(acc, h)
        if(p(nacc)) go(op(nacc, h), t) else nacc
    case _ => acc
  }
  go(z, xs)
}

val a = Seq(1,2,3,4,5,6)
val r = foldLeftWhile(0, a, (x: Int) => x <= 3)(_ + _)
println(s"$r")

在谓词为真时递归迭代集合,然后return累加器。

你可以在scalafiddle

上试试

一段时间后,我收到了很多好看的答案。所以,我把它们组合成这个单曲 post

@Dima 的一个非常简洁的解决方案

implicit class FoldLeftWhile[A](seq: Seq[A]) {

  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    seq.toStream.scanLeft(z)(op).takeWhile(p).lastOption.getOrElse(z)
  }
}

来自@ElBaulP(我做了一些修改以匹配@Dima 的评论)

implicit class FoldLeftWhile[A](seq: Seq[A]) {

  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    @tailrec
    def foldLeftInternal(acc: B, seq: Seq[A]): B = seq match {
      case x :: _ =>
        val newAcc = op(acc, x)
        if (p(newAcc))
          foldLeftInternal(newAcc, seq.tail)
        else
          acc
      case _ => acc
    }

    foldLeftInternal(z, seq)
  }
}

由我来回答(涉及副作用)

implicit class FoldLeftWhile[A](seq: Seq[A]) {

  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    var accumulator = z
    seq
      .map { e =>
        accumulator = op(accumulator, e)
        accumulator -> e
      }
      .takeWhile { case (acc, _) =>
        p(acc)
      }
      .lastOption
      .map { case (acc, _) =>
        acc
      }
      .getOrElse(z)
  }
}

第一个例子:每个元素的谓词

首先可以使用内尾递归函数

implicit class TravExt[A](seq: TraversableOnce[A]) {
  def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = {
    @tailrec
    def rec(trav: TraversableOnce[A], z: B): B = trav match {
      case head :: tail if f(head) => rec(tail, op(head, z))
      case _ => z
    }
    rec(seq, z)
  }
}

或短版

implicit class TravExt[A](seq: TraversableOnce[A]) {
  @tailrec
  final def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = seq match {
    case head :: tail if f(head) => tail.foldLeftWhile(op(head, z), f)(op)
    case _ => z
  }
}

那就用吧

val a = List(1, 2, 3, 4, 5, 6).foldLeftWhile(0, _ < 3)(_ + _) //a == 3

第二个例子:对于累加器值:

implicit class TravExt[A](seq: TraversableOnce[A]) {
  def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = {
    @tailrec
    def rec(trav: TraversableOnce[A], z: B): B = trav match {
      case _ if !f(z) => z
      case head :: tail => rec(tail, op(head, z))
      case _ => z
    }
    rec(seq, z)
  }
}

或短版

implicit class TravExt[A](seq: TraversableOnce[A]) {
  @tailrec
  final def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = seq match {
    case _ if !f(z) => z
    case head :: tail => tail.foldLeftWhile(op(head, z), f)(op)
    case _ => z
  }
}