Why did Scala use a while and not recursion for `find`

Just curious: why didn't the Scala authors use recursion, or even pattern matching, when implementing find on List?

Their implementation looks like this:

  override final def find(p: A => Boolean): Option[A] = {
    var these: List[A] = this
    while (!these.isEmpty) {
      if (p(these.head)) return Some(these.head)
      these = these.tail
    }
    None
  }

It uses while, head and tail. Couldn't they have done something more "scala-esque" with recursion?

  @tailrec
  final def find(p: A => Boolean): Option[A] = {
    this match {
      case Nil                  => None
      case head :: _ if p(head) => Some(head)
      // Recurse on the tail; the method must be final for @tailrec to accept this call.
      case _ :: tail            => tail.find(p)
    }
  }

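It can't be because of tail-call optimization, can it? Is the while version somehow more efficient and I'm missing it? Is it just the authors' preference and style? Is there something inflexible about the recursive version when A can be anything? Hmm.

A quick experiment (using Scala 2.13.2). The three candidate implementations are: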

  • a while loop
  • tail-recursive, but keeping the same logic as the while version
  • tail-recursive with pattern matching

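Where appropriate, I modified the logic to reduce reliance on compiler optimizations (nonEmpty instead of !isEmpty, and explicitly saving these.head so that it is not called twice).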

  import scala.annotation.tailrec
  
  object ListFindComparison {
    def whileFind[A](lst: List[A])(p: A => Boolean): Option[A] = { 
      var these: List[A] = lst 
      while (these.nonEmpty) {
        val h = these.head

        if (p(h)) return Some(h)
        else these = these.tail
      }   
      None
    }
  
    def tailrecFind[A](lst: List[A])(p: A => Boolean): Option[A] = { 
      @tailrec
      def iter(these: List[A]): Option[A] =
        if (these.nonEmpty) {
          val h = these.head
          if (p(h)) Some(h)
          else iter(these.tail)
        } else None
  
      iter(lst)
    }
  
    def tailRecPM[A](lst: List[A])(p: A => Boolean): Option[A] = { 
      @tailrec
      def iter(these: List[A]): Option[A] =
        these match {
          case Nil => None
          case head :: tail if p(head) => Some(head)
          case _ => iter(these.tail)
        }   
  
      iter(lst)
    }
  }

Examining the bytecode (using :javap ListFindComparison$), we see the following.

For whileFind, the emitted code is straightforward:

Code:
   0: aload_1
   1: astore_3
   2: aload_3
   3: invokevirtual #25                 // Method scala/collection/immutable/List.nonEmpty:()Z
   6: ifeq          50
   9: aload_3
  10: invokevirtual #29                 // Method scala/collection/immutable/List.head:()Ljava/lang/Object;
  13: astore        4
  15: aload_2
  16: aload         4
  18: invokeinterface #35,  2           // InterfaceMethod scala/Function1.apply:(Ljava/lang/Object;)Ljava/lang/Object;
  23: invokestatic  #41                 // Method scala/runtime/BoxesRunTime.unboxToBoolean:(Ljava/lang/Object;)Z
  26: ifeq          39
  29: new           #43                 // class scala/Some
  32: dup
  33: aload         4
  35: invokespecial #46                 // Method scala/Some."<init>":(Ljava/lang/Object;)V
  38: areturn
  39: aload_3
  40: invokevirtual #49                 // Method scala/collection/immutable/List.tail:()Ljava/lang/Object;
  43: checkcast     #21                 // class scala/collection/immutable/List
  46: astore_3
  47: goto          2
  50: getstatic     #54                 // Field scala/None$.MODULE$:Lscala/None$;
  53: areturn

The two tail-recursive find methods compile to essentially the same thing:

aload_0
aload_1
aload_2
invokespecial   // call the appropriate (private) iter methods
areturn

The iter in tailrecFind:

Code:
   0: aload_1
   1: invokevirtual #25                 // Method scala/collection/immutable/List.nonEmpty:()Z
   4: ifeq          53
   7: aload_1
   8: invokevirtual #29                 // Method scala/collection/immutable/List.head:()Ljava/lang/Object;
  11: astore        4
  13: aload_2
  14: aload         4
  16: invokeinterface #35,  2           // InterfaceMethod scala/Function1.apply:(Ljava/lang/Object;)Ljava/lang/Object;
  21: invokestatic  #41                 // Method scala/runtime/BoxesRunTime.unboxToBoolean:(Ljava/lang/Object;)Z
  24: ifeq          39
  27: new           #43                 // class scala/Some
  30: dup
  31: aload         4
  33: invokespecial #46                 // Method scala/Some."<init>":(Ljava/lang/Object;)V
  36: goto          50
  39: aload_1
  40: invokevirtual #49                 // Method scala/collection/immutable/List.tail:()Ljava/lang/Object;
  43: checkcast     #21                 // class scala/collection/immutable/List
  46: astore_1
  47: goto          0
  50: goto          56
  53: getstatic     #54                 // Field scala/None$.MODULE$:Lscala/None$;
  56: areturn

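There is no major difference between the core of the while loop and this iter: the JIT will most likely bring them down to the same machine code after enough calls. tailrecFind has slightly more constant overhead to get into iter than whileFind has to get into its loop. A meaningful performance difference is unlikely here. (Note that Scala 3 (dotty) dropped only do-while and kept while in the language; still, while can be expressed as a library function that tail-recursively runs a block for as long as a predicate holds.)

The iter with pattern matching is quite different: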
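To make the "while as a library function" remark concrete, here is a minimal sketch. whileLoop and WhileAsLibrary are hypothetical names of my own, not part of the standard library, and nothing here was measured.

```scala
import scala.annotation.tailrec

object WhileAsLibrary {
  // Hypothetical helper (not in the standard library): runs `body`
  // for as long as `cond` evaluates to true. Both parameters are by-name,
  // so they are re-evaluated on every iteration.
  @tailrec
  def whileLoop(cond: => Boolean)(body: => Unit): Unit =
    if (cond) {
      body
      whileLoop(cond)(body)
    }

  // find written against the helper instead of the built-in while.
  // A result variable replaces the early return of the original loop.
  def find[A](lst: List[A])(p: A => Boolean): Option[A] = {
    var these: List[A] = lst
    var result: Option[A] = None
    whileLoop(result.isEmpty && these.nonEmpty) {
      val h = these.head
      if (p(h)) result = Some(h)
      else these = these.tail
    }
    result
  }
}
```

Because the recursive call is in tail position, the compiler can turn whileLoop itself into a jump, so the stack does not grow with the number of iterations.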

Code:
   0: aload_1
   1: astore        5
   3: getstatic     #77                 // Field scala/collection/immutable/Nil$.MODULE$:Lscala/collection/immutable/Nil$;
   6: aload         5
   8: invokevirtual #80                 // Method java/lang/Object.equals:(Ljava/lang/Object;)Z
  11: ifeq          22
  14: getstatic     #54                 // Field scala/None$.MODULE$:Lscala/None$;
  17: astore        4
  19: goto          92
  22: goto          25
  25: aload         5
  27: instanceof    #82                 // class scala/collection/immutable/$colon$colon
  30: ifeq          78
  33: aload         5
  35: checkcast     #82                 // class scala/collection/immutable/$colon$colon
  38: astore        6
  40: aload         6
  42: invokevirtual #83                 // Method scala/collection/immutable/$colon$colon.head:()Ljava/lang/Object;
  45: astore        7
  47: aload_2
  48: aload         7
  50: invokeinterface #35,  2           // InterfaceMethod scala/Function1.apply:(Ljava/lang/Object;)Ljava/lang/Object;
  55: invokestatic  #41                 // Method scala/runtime/BoxesRunTime.unboxToBoolean:(Ljava/lang/Object;)Z
  58: ifeq          75
  61: new           #43                 // class scala/Some
  64: dup
  65: aload         7
  67: invokespecial #46                 // Method scala/Some."<init>":(Ljava/lang/Object;)V
  70: astore        4
  72: goto          92
  75: goto          81
  78: goto          81
  81: aload_1
  82: invokevirtual #49                 // Method scala/collection/immutable/List.tail:()Ljava/lang/Object;
  85: checkcast     #21                 // class scala/collection/immutable/List
  88: astore_1
  89: goto          0
  92: aload         4
  94: areturn

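This is unlikely to be as efficient as the version without pattern matching (though, to be fair, the branches are really easy for the predictor: not-taken (not Nil), not-taken (is a ::), not-taken (predicate fails), on every iteration except the last).

It's somewhat interesting to me that we get a call to equals when checking for Nil: it's probably still faster than isEmpty/nonEmpty, but a version without pattern matching and with an explicit eq/ne against Nil would be faster still.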
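A sketch of what that eq/ne variant could look like, under the assumption that a reference comparison against the Nil singleton is what is meant. eqFind and EqNilFind are illustrative names, and I have not benchmarked this.

```scala
object EqNilFind {
  // Same loop as whileFind, but the emptiness test is a reference
  // comparison against the Nil singleton instead of equals or nonEmpty.
  def eqFind[A](lst: List[A])(p: A => Boolean): Option[A] = {
    var these: List[A] = lst
    while (these ne Nil) {
      val h = these.head
      if (p(h)) return Some(h)
      these = these.tail
    }
    None
  }
}
```

Since Nil is a single object, every empty List is that same instance, so `ne` gives the same answer as nonEmpty here.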

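I'd also note that pattern matching against this is a bit of an antipattern, IMO: at that point you're almost certainly better off using virtual method dispatch, since you're basically implementing a slow vtable (though it does have a potential pre-JIT advantage if you put the common case first).

If you really care about performance, I'd try to avoid pattern matching.

PS: I haven't analyzed the simple foldLeft solution: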
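For illustration, here is a rough sketch of the virtual-dispatch style on a toy list. MyList, Empty and Cons are hypothetical types invented for this example; this is not how scala.List is defined.

```scala
// Hypothetical minimal list type, for illustration only (not scala.List).
sealed abstract class MyList[+A] {
  def isEmpty: Boolean
  def head: A
  def tail: MyList[A]

  // The loop asks each node about itself through virtual calls
  // instead of matching on the node's runtime class.
  final def find(p: A => Boolean): Option[A] = {
    var these: MyList[A] = this
    while (!these.isEmpty) {
      val h = these.head
      if (p(h)) return Some(h)
      these = these.tail
    }
    None
  }
}

case object Empty extends MyList[Nothing] {
  def isEmpty: Boolean = true
  def head: Nothing = throw new NoSuchElementException("head of empty list")
  def tail: MyList[Nothing] = throw new UnsupportedOperationException("tail of empty list")
}

// The constructor parameters implement the abstract head and tail.
final case class Cons[+A](head: A, tail: MyList[A]) extends MyList[A] {
  def isEmpty: Boolean = false
}
```

Each subclass answers isEmpty, head and tail for itself, so find never needs an instanceof test or a cast.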

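// Option.empty[A] fixes the accumulator type to Option[A];
// a bare None would be inferred as None.type and fail to compile.
lst.foldLeft(Option.empty[A]) { (acc, v) =>
  acc.orElse {
    if (p(v)) Some(v)
    else None
  }
}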

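But since that doesn't short-circuit, I suspect it won't consistently beat any of the candidates; even when there is no match before the last element, it may not even beat the pattern-matching version.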
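To see what "doesn't short-circuit" means in practice, this sketch counts how many elements the fold visits and how often the predicate actually runs. foldFind and the counters are illustrative additions, not part of the experiment above.

```scala
object FoldLeftFindDemo {
  // Returns (result, elements visited by foldLeft, predicate calls).
  def foldFind[A](lst: List[A])(p: A => Boolean): (Option[A], Int, Int) = {
    var visited = 0
    var predicateCalls = 0
    val result = lst.foldLeft(Option.empty[A]) { (acc, v) =>
      visited += 1
      // orElse takes its argument by name, so this block runs
      // only while acc is still empty.
      acc.orElse {
        predicateCalls += 1
        if (p(v)) Some(v) else None
      }
    }
    (result, visited, predicateCalls)
  }
}
```

For List(1, 2, 3, 4, 5) and the predicate _ == 2, the fold still visits all 5 elements, while the predicate runs only twice: the predicate stops being called after a match, but the traversal itself does not stop.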