在 Scala 中使用尾调用递归获取子问题的结果

Get results for sub problems using tail call recursion in Scala

我正在尝试使用 @tailrec 计算每个子问题的结果 类似于正常的递归解决方案如何为每个子问题生成解决方案。 以下是我处理的示例。

@tailrec
  def collatz(
      n: BigInt,
      acc: BigInt,
      fn: (BigInt, BigInt) => Unit
  ): BigInt = {
    fn(n, acc)
    if (n == 1) {
      acc
    } else if (n % 2 == 0) {
      collatz(n / 2, acc + 1, fn)
    } else {
      collatz(3 * n + 1, acc + 1, fn)
    }
  }

这里我使用 Collatz Conjecture 计算一个数达到 1 时的计数。举个例子,让我们假设它用于数字 32

val n = BigInt("32")
    val c = collatz(n, 0, (num, acc) => {
      println("Num -> " + num + " " + " " + "Acc -> " + acc)
    })

我得到以下输出。

Num -> 32  Acc -> 0
Num -> 16  Acc -> 1
Num -> 8  Acc -> 2
Num -> 4  Acc -> 3
Num -> 2  Acc -> 4
Num -> 1  Acc -> 5

正常的递归解决方案将return精确计数每个数字。例如,数字 21 步中达到 1。因此,每个子问题都有精确的解决方案,但在 tailrec 方法中,只有最终结果才能正确计算。变量 acc 的行为与预期的循环变量完全一样。

如何更改尾调用优化的代码,同时我可以获得每个子问题的准确值。简而言之,我怎样才能为 acc 变量获得 Stack 类型的行为。

此外,一个相关的问题是,如果不使用 println 语句,对于 n 的大值,lambda 函数 fn 的开销有多大。

我正在添加一个可以为子问题生成正确解决方案的递归解决方案。

def collatz2(
      n: BigInt,
      fn: (BigInt, BigInt) => Unit
  ): BigInt = {

    val c: BigInt = if (n == 1) {
      0
    } else if (n % 2 == 0) {
      collatz2(n / 2, fn) + 1
    } else {
      collatz2(3 * n + 1, fn) + 1
    }
    fn(n, c)
    c
  }

它产生以下输出。

Num -> 1  Acc -> 0
Num -> 2  Acc -> 1
Num -> 4  Acc -> 2
Num -> 8  Acc -> 3
Num -> 16  Acc -> 4
Num -> 32  Acc -> 5

您不能 "attain Stack type of behavior" 使用尾递归(不使用显式堆栈)。 @tailrec 注释表示您没有使用调用堆栈,它可以被优化掉。您必须决定是要尾递归还是递归子问题求解。一些问题(例如二分搜索)非常适合尾递归,而其他问题(例如您的 collat​​z 代码)需要更多思考,还有一些问题(例如 DFS)过于依赖调用堆栈而无法从尾递归中获益.

我不确定我是否正确理解了你的问题。听起来你是在要求我们编写 collat​​z2 以便它是尾递归的。我用两种方式重写了它。

虽然我提供了两种解决方案,但它们实际上是一回事。一种使用 List 作为堆栈,其中 List 的头部是堆栈的顶部。另一个使用 mutable.Stack 数据结构。研究这两个解决方案,直到您明白为什么它们都与原始问题中的 collat​​z2 相同。

要使程序尾递归,我们要做的就是模拟将值压入栈中,然后一个一个弹出的效果。在弹出阶段,我们为 Acc 赋值。 (对于那些不记得的人,Hariharan 的说法中的 Acc 是每个术语的索引。)

import scala.collection.mutable

object CollatzCount {

  def main(args: Array[String]) = {
    val start = 32

    collatzFinalList(start, printer)

    collatzFinalStack(start, printer)

  }

  def collatzInnerList(n: Int, acc: List[Int]): List[Int] = {
    if (n == 1) n :: acc
    else if (n % 2 == 0) collatzInnerList(n/2, n :: acc )
    else collatzInnerList(3*n + 1, n :: acc )
  }

  def collatzFinalList(n: Int, fun: (Int, Int)=>Unit): Unit = {
    val acc = collatzInnerList(n, List())
    acc.foldLeft(0){ (ctr, e) =>
      fun(e, ctr)
      ctr + 1
    }
  }

  def collatzInnerStack(n: Int, stack: mutable.Stack[Int]): mutable.Stack[Int] = {
    if (n == 1) {
      stack.push(n)
      stack
    } else if (n % 2 == 0) {
      stack.push(n)
      collatzInnerStack(n/2, stack)
    } else {
      stack.push(n)
      collatzInnerStack(3*n + 1, stack)
    }
  }

  def popStack(ctr: Int, stack: mutable.Stack[Int], fun: (Int, Int)=>Unit): Unit = {
    if (stack.nonEmpty) {
      val popped = stack.pop
      fun(popped, ctr)
      popStack(ctr + 1, stack, fun)
    } else ()
  }


  def collatzFinalStack(n: Int, fun: (Int, Int) => Unit): Unit = {
    val stack = collatzInnerStack(n, mutable.Stack())
    popStack(0, stack, fun)
  }


  val printer = (x: Int, y: Int) => println("Num ->" + x + " " + " " + "Acc -> " + y)

}