重组方法以修复不在尾部位置的尾递归调用
Restructuring a method to fix tail recursive call not in tail position
考虑以下递归幂法乘法:
import scala.annotation.tailrec
@tailrec def mult(x: Double, n:Int) : Double = {
n match {
case 0 => 1
case 1 => x
case _ if ((n & 0x01) != 0) => x * mult(x*x, (n-1)/2)
case _ => mult(x*x, n/2)
}
}
编译错误为:
<console>:28: error: could not optimize @tailrec annotated method mult:
it contains a recursive call not in tail position
y * mult(x*x,(n-2)/2)
^
所以 .. 鉴于递归调用 是 最后一个条目 - 我认为产品 y *(尾递归子句)有问题?如何正确构建它?
更新
这是已接受答案的修改版本 - 我很懒,只是在调用的方法中放置了第三个累加器。
@tailrec def mult(x: Double, n:Int, accum: Double = 1.0) : Double = {
n match {
case 0 => accum
case 1 => accum * x
case _ if ((n & 0x01) != 0) => mult(x*x, (n-1)/2, x * accum)
case _ => mult(x*x, n/2, accum)
}
}
mult: (x: Double, n: Int, accum: Double)Double
试试看:
scala> mult(2, 7)
res0: Double = 128.0
scala> mult(2, 8)
res1: Double = 256.0
你的函数 mult
不是尾递归的,因为在函数体中你想对递归调用的结果做一些事情,即你想将它与 y
相乘。
要使此尾递归,您应该构造函数 mult
以便它可以将值 y
作为参数以在递归调用后删除乘法。这是一个带有阶乘的简单示例:http://c2.com/cgi/wiki?TailRecursion
尾递归调用是那些最后一条语句只是函数调用本身的调用。
也就是说,您的代码的最后一条语句应该仅 mult(x*x,(n-2)/2)
。
你可以试试这个。
import scala.annotation.tailrec
@tailrec
def mult(x: Double, n:Int,res:Double=1) : Double = {
n match {
case 0 => res
case _ => mult(x,n-1,res *x)
}
}
有两种方法可以解决这类问题。第一种是在调用中移动乘法,可能是通过添加辅助方法:
import scala.annotation.tailrec
def mult(x: Double, n: Int): Double = {
@tailrec
def go(x: Double, n: Int, mult: Double): Double = n match {
case 0 => mult
case 1 => mult * x
case _ if (n & 0x01) != 0 => go(x * x, (n - 1) / 2, x * mult)
case _ => go(x * x, n / 2, mult)
}
go(x, n, 1)
}
另一个并不能真正回答您的问题,但在某些情况下它可能是一种更方便的方法。它被称为 "trampolining":
import scala.util.control.TailCalls._
def mult(x: Double, n: Int): Double = {
def go(x: Double, n: Int): TailRec[Double] = n match {
case 0 => done(1)
case 1 => done(x)
case _ if (n & 0x01) != 0 => tailcall(go(x * x, (n - 1) / 2).map(_ * x))
case _ => tailcall(go(x * x, n / 2))
}
go(x, n).result
}
这不需要您重构您的方法,并且保证不会炸毁堆栈,但它确实会引入一些额外的开销。
考虑以下递归幂法乘法:
import scala.annotation.tailrec
@tailrec def mult(x: Double, n:Int) : Double = {
n match {
case 0 => 1
case 1 => x
case _ if ((n & 0x01) != 0) => x * mult(x*x, (n-1)/2)
case _ => mult(x*x, n/2)
}
}
编译错误为:
<console>:28: error: could not optimize @tailrec annotated method mult:
it contains a recursive call not in tail position
y * mult(x*x,(n-2)/2)
^
所以 .. 鉴于递归调用 是 最后一个条目 - 我认为产品 y *(尾递归子句)有问题?如何正确构建它?
更新
这是已接受答案的修改版本 - 我很懒,只是在调用的方法中放置了第三个累加器。
@tailrec def mult(x: Double, n:Int, accum: Double = 1.0) : Double = {
n match {
case 0 => accum
case 1 => accum * x
case _ if ((n & 0x01) != 0) => mult(x*x, (n-1)/2, x * accum)
case _ => mult(x*x, n/2, accum)
}
}
mult: (x: Double, n: Int, accum: Double)Double
试试看:
scala> mult(2, 7)
res0: Double = 128.0
scala> mult(2, 8)
res1: Double = 256.0
你的函数 mult
不是尾递归的,因为在函数体中你想对递归调用的结果做一些事情,即你想将它与 y
相乘。
要使此尾递归,您应该构造函数 mult
以便它可以将值 y
作为参数以在递归调用后删除乘法。这是一个带有阶乘的简单示例:http://c2.com/cgi/wiki?TailRecursion
尾递归调用是那些最后一条语句只是函数调用本身的调用。
也就是说,您的代码的最后一条语句应该仅 mult(x*x,(n-2)/2)
。
你可以试试这个。
import scala.annotation.tailrec
@tailrec
def mult(x: Double, n:Int,res:Double=1) : Double = {
n match {
case 0 => res
case _ => mult(x,n-1,res *x)
}
}
有两种方法可以解决这类问题。第一种是在调用中移动乘法,可能是通过添加辅助方法:
import scala.annotation.tailrec
def mult(x: Double, n: Int): Double = {
@tailrec
def go(x: Double, n: Int, mult: Double): Double = n match {
case 0 => mult
case 1 => mult * x
case _ if (n & 0x01) != 0 => go(x * x, (n - 1) / 2, x * mult)
case _ => go(x * x, n / 2, mult)
}
go(x, n, 1)
}
另一个并不能真正回答您的问题,但在某些情况下它可能是一种更方便的方法。它被称为 "trampolining":
import scala.util.control.TailCalls._
def mult(x: Double, n: Int): Double = {
def go(x: Double, n: Int): TailRec[Double] = n match {
case 0 => done(1)
case 1 => done(x)
case _ if (n & 0x01) != 0 => tailcall(go(x * x, (n - 1) / 2).map(_ * x))
case _ => tailcall(go(x * x, n / 2))
}
go(x, n).result
}
这不需要您重构您的方法,并且保证不会炸毁堆栈,但它确实会引入一些额外的开销。