How to implement takeUntil with Scala lazy collections

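I have an expensive function that I want to run as few times as possible: stop as soon as a result falls below a threshold, and otherwise take the input with the lowest result.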

I couldn't find a good solution using Iterator's takeWhile/dropWhile, because I want to include the first matching element. I ended up with the following solution:

import scala.collection.mutable
import scala.util.control.Breaks.{break, breakable}

val pseudoResult = Map("a" -> 0.6,"b" -> 0.2, "c" -> 1.0)

def expensiveFunc(s:String) : Double = {
  pseudoResult(s)
}

val inputsToTry = Seq("a","b","c")

val inputIt = inputsToTry.iterator
val results = mutable.ArrayBuffer.empty[(String, Double)]

val earlyAbort = 0.5 // threshold

breakable {
  while (inputIt.hasNext) {
    val name = inputIt.next()
    val res = expensiveFunc(name)
    results += Tuple2(name,res)
    if (res<earlyAbort) break()
  }
}

println(results) // ArrayBuffer((a,0.6), (b,0.2))

val (name, bestResult) = results.minBy(_._2) // (b, 0.2)

If I set val earlyAbort = 0.1, the result should still be (b, 0.2), without evaluating all the cases again.

Fold over your input list and use an early return to stop once a result falls below the threshold. Try the following:

  val pseudoResult = Map("a" -> 0.6, "b" -> 0.2, "c" -> 1.0)

  def expensiveFunc(s: String): Double = {
    println(s"executed for ${s}")
    pseudoResult(s)
  }

  val inputsToTry = Seq("a", "b", "c")
  val earlyAbort = 0.5 // threshold

  def doIt(): List[(String, Double)] = {
    inputsToTry.foldLeft(List[(String, Double)]()) {
      case (acc, name) =>
        val res = expensiveFunc(name)
        // Non-local return: exits doIt() immediately, so the remaining inputs are skipped.
        if (res < earlyAbort) {
          return acc :+ (name -> res)
        }
        acc :+ (name -> res)
    }
  }

  val (name, bestResult) = doIt().minBy(_._2)
  println(name)
  println(bestResult)

Output:

executed for a
executed for b
b
0.2

As you can see, only a and b are evaluated, not c.
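
The early exit above relies on a non-local return from inside the closure, which is deprecated in Scala 3. As an alternative, here is a minimal sketch of a takeUntil helper over an Iterator. The name takeUntil and the calls counter are illustrative assumptions, not part of the standard library. The helper yields elements up to and including the first one that satisfies the predicate, so the mapped function is never called for later inputs.

```scala
// Hypothetical helper: the standard library has no takeUntil on Iterator.
// Yields elements up to and including the first one that satisfies `p`.
def takeUntil[A](it: Iterator[A])(p: A => Boolean): Iterator[A] =
  new Iterator[A] {
    private var done = false
    def hasNext: Boolean = !done && it.hasNext
    def next(): A = {
      if (!hasNext) throw new NoSuchElementException("next on empty iterator")
      val a = it.next()
      if (p(a)) done = true // stop after handing out the first match
      a
    }
  }

val pseudoResult = Map("a" -> 0.6, "b" -> 0.2, "c" -> 1.0)
var calls = 0 // illustration only: counts how often the expensive function runs

def expensiveFunc(s: String): Double = {
  calls += 1
  pseudoResult(s)
}

val inputsToTry = Seq("a", "b", "c")
val earlyAbort = 0.5 // threshold

// Iterator.map is lazy, so expensiveFunc runs only for elements that are actually pulled.
val evaluated =
  takeUntil(inputsToTry.iterator.map(n => n -> expensiveFunc(n)))(_._2 < earlyAbort).toList

val (name, bestResult) = evaluated.minBy(_._2)
```

With earlyAbort = 0.1 no element matches, so all three inputs are evaluated exactly once and minBy still yields (b, 0.2). Like the original, minBy fails on an empty input.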

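You can use a Stream to achieve what you want. Remember that a Stream is a lazy collection: it evaluates operations on demand.

See the Scala Stream documentation.

You only need to do this: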

val pseudoResult = Map("a" -> 0.6,"b" -> 0.2, "c" -> 1.0)
val earlyAbort = 0.5

def expensiveFunc(s: String): Double = {
  println(s"Evaluating for $s")
  pseudoResult(s)
}

val inputsToTry = Seq("a","b","c")

val results = inputsToTry.toStream.map(input => input -> expensiveFunc(input))
val finalResult = results.find { case (k, res) => res < earlyAbort }.getOrElse(results.minBy(_._2))

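If find does not return a value, you can use the same stream to find the minimum, and the function will not be computed again, thanks to memoization: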

The Stream class also employs memoization such that previously computed values are converted from Stream elements to concrete values of type A

Note that this code fails if the original collection is empty. To support empty collections, replace minBy with sortBy(_._2).headOption and getOrElse with orElse:

val finalResultOpt = results.find { case (k, res) => res < earlyAbort }.orElse(results.sortBy(_._2).headOption)

The output is:

Evaluating for a

Evaluating for b

finalResult: (String, Double) = (b,0.2)

finalResultOpt: Option[(String, Double)] = Some((b,0.2))
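
On Scala 2.13 and later, Stream is deprecated in favor of LazyList, which memoizes in the same way but also computes its head lazily. Here is a minimal sketch of the same approach, assuming Scala 2.13+ (the calls counter is added only for illustration, and minByOption replaces sortBy(_._2).headOption):

```scala
val pseudoResult = Map("a" -> 0.6, "b" -> 0.2, "c" -> 1.0)
val earlyAbort = 0.5
var calls = 0 // illustration only: counts evaluations of expensiveFunc

def expensiveFunc(s: String): Double = {
  calls += 1
  pseudoResult(s)
}

val inputsToTry = Seq("a", "b", "c")

// Assumes Scala 2.13+: LazyList elements are computed on demand and then cached.
val results = inputsToTry.to(LazyList).map(input => input -> expensiveFunc(input))

// find stops at the first match; orElse takes its argument by name, so
// minByOption runs only when nothing matched, and it reuses the cached values.
val finalResultOpt =
  results
    .find { case (_, res) => res < earlyAbort }
    .orElse(results.minByOption(_._2))
```

Because minByOption returns None for an empty collection, this variant also handles empty input.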

This is one of the use cases for tail recursion:

  import scala.annotation.tailrec
  val pseudoResult = Map("a" -> 0.6,"b" -> 0.2, "c" -> 1.0)

  def expensiveFunc(s:String) : Double = {
    pseudoResult(s)
  }

  val inputsToTry = Seq("a","b","c")

  val earlyAbort = 0.5 // threshold

  @tailrec
  def f(s: Seq[String], result: Map[String, Double] = Map()): Map[String, Double] = s match {
    // Seq() and +: match any Seq; Nil and :: would only match a List.
    case Seq() => result
    case h +: t =>
      val expensiveCalculation = expensiveFunc(h)
      val intermediateResult = result + (h -> expensiveCalculation)
      if(expensiveCalculation < earlyAbort) {
        intermediateResult
      } else {
        f(t, intermediateResult)
      }
  }
  val result = f(inputsToTry)

  println(result) // Map(a -> 0.6, b -> 0.2)

  // Reuse the computed result instead of calling f again, which would re-run expensiveFunc.
  val (name, bestResult) = result.minBy(_._2) // ("b", 0.2)

The clearest and simplest approach is to fold over the input, passing along only the current best result.

val inputIt :Iterator[String] = inputsToTry.iterator
val earlyAbort = 0.5 // threshold

inputIt.foldLeft(("",Double.MaxValue)){ case (low,name) =>
  if (low._2 < earlyAbort) low
  else Seq(low, (name, expensiveFunc(name))).minBy(_._2)
}
//res0: (String, Double) = (b,0.2)

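It calls expensiveFunc() only as many times as needed, but it still walks through the entire input iterator. If that is still too costly (lots of inputs), then I would go with a tail-recursive method.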

val inputIt :Iterator[String] = inputsToTry.iterator
val earlyAbort = 0.5 // threshold

def bestMin(low :(String,Double) = ("",Double.MaxValue)) :(String,Double) = {
  if (inputIt.hasNext) {
    val name = inputIt.next()
    val res = expensiveFunc(name)
    if (res < earlyAbort) (name, res)
    else if (res < low._2) bestMin((name,res))
    else bestMin(low)
  } else low
}
bestMin()  //res0: (String, Double) = (b,0.2)

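If you implement takeUntil and use it, you would still need to go through the list a second time to get the lowest element if you don't find what you are looking for. A better approach is probably a function that combines find with reduceOption: it returns early if something is found, and otherwise returns the result of reducing the collection to a single item (in your case, the smallest one).

The result is comparable to what you can achieve with a Stream, as highlighted in a previous reply, but it avoids relying on memoization, which can be cumbersome for very large collections.

A possible implementation is the following: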

import scala.annotation.tailrec

def findOrElse[A](it: Iterator[A])(predicate: A => Boolean,
                                   orElse: (A, A) => A): Option[A] = {
  @tailrec
  def loop(elseValue: Option[A]): Option[A] = {
    if (!it.hasNext) elseValue
    else {
      val next = it.next()
      if (predicate(next)) Some(next)
      else loop(Option(elseValue.fold(next)(orElse(_, next))))
    }
  }
  loop(None)
}

Let's add our inputs to test this:

def f1(in: String): Double = {
  println("calling f1")
  Map("a" -> 0.6, "b" -> 0.2, "c" -> 1.0, "d" -> 0.8)(in)
}

def f2(in: String): Double = {
  println("calling f2")
  Map("a" -> 0.7, "b" -> 0.6, "c" -> 1.0, "d" -> 0.8)(in)
}

val inputs = Seq("a", "b", "c", "d")

And our helpers:

def apply[IN, OUT](in: IN, f: IN => OUT): (IN, OUT) =
  in -> f(in)

def threshold[A](a: (A, Double)): Boolean =
  a._2 < 0.5

def compare[A](a: (A, Double), b: (A, Double)): (A, Double) =
  if (a._2 < b._2) a else b

We can now run it and see how it goes:

val r1 = findOrElse(inputs.iterator.map(apply(_, f1)))(threshold, compare)
val r2 = findOrElse(inputs.iterator.map(apply(_, f2)))(threshold, compare)
val r3 = findOrElse(Map.empty[String, Double].iterator)(threshold, compare)

r1 is Some(b, 0.2), r2 is Some(b, 0.6), and r3 is (reasonably) None. In the first case, because we use a lazy iterator and terminate early, we only invoke f1 twice.

You can see the results and play with this code here on Scastie.