Scala 中二叉树的尾递归折叠

Tail recursive fold on a binary tree in Scala

我正在尝试为二叉树找到尾递归折叠函数。给定以下定义:

// From the book "Functional Programming in Scala", page 45
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

实现非尾递归函数非常简单:

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B =
  t match {
    case Leaf(v)      => map(v)
    case Branch(l, r) => 
      red(fold(l)(map)(red), fold(r)(map)(red))
  }

但是我现在正在苦苦寻找尾递归折叠函数,以便可以使用注解@annotation.tailrec

在我的研究过程中,我发现了几个例子,其中树上的尾递归函数可以,例如使用自己的堆栈计算所有叶子的总和,该堆栈基本上是 List[Tree[Int]]。但据我所知,在这种情况下,它只适用于加法,因为先评估运算符的左侧还是右侧并不重要。但对于广义折叠来说,这是非常相关的。为了展示我的意图,这里有一些示例树:

val leafs = Branch(Leaf(1), Leaf(2))
val left = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb = Branch(right, Branch(bal, Branch(leafs, left)))
val trees = List(leafs, left, right, bal, cmb)

基于这些树,我想使用给定的折叠方法创建一个深层副本,例如:

val oldNewPairs = 
  trees.map(t => (t, fold(t)(Leaf(_): Tree[Int])(Branch(_, _))))

然后证明相等条件对所有创建的副本都成立:

val conditionHolds = oldNewPairs.forall(p => {
  if (p._1 == p._2) true
  else {
    println(s"Original:\n${p._1}\nNew:\n${p._2}")
    false
  }
})
println("Condition holds: " + conditionHolds)

有人能给我一些指点吗?

您可以在 ScalaFiddle 找到本题中使用的代码:https://scalafiddle.io/sf/eSKJyp2/15

如果您停止使用函数调用堆栈并开始使用由您的代码和累加器管理的堆栈,您可能会找到尾递归解决方案:

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {

  case object BranchStub extends Tree[Nothing]

  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: Vector[B]): Vector[B] =
    if(toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          val leafRes = map(v)
          foldImp(
            toVisit.tail,
            acc :+ leafRes
          )
        case Branch(l, r) =>
          foldImp(l :: r :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          foldImp(toVisit.tail, acc.dropRight(2) ++   Vector(acc.takeRight(2).reduce(red)))
      }
    }

  foldImp(t::Nil, Vector.empty).head

}

想法是从左到右累加值,通过引入存根节点跟踪亲子关系,并使用累加器的最后两个元素使用 red 函数减少结果在探索中发现存根节点。

这个解决方案可以优化,但它已经是一个尾递归函数实现。

编辑:

可以通过将累加器数据结构更改为视为堆栈的列表来稍微简化:

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {

  case object BranchStub extends Tree[Nothing]

  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: List[B]): List[B] =
    if(toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          foldImp(
            toVisit.tail,
            map(v)::acc 
          )
        case Branch(l, r) =>
          foldImp(r :: l :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          foldImp(toVisit.tail, acc.take(2).reduce(red) :: acc.drop(2))
      }
    }

  foldImp(t::Nil, Nil).head

}