Scala 中的延续传递风格

Continuation-passing style in Scala

我粗略地阅读了几篇关于延续传递风格的博客articles/Wikipedia。我的高级目标是找到一种系统的技术来使任何递归函数(或者,如果有限制,请注意它们)尾递归。然而,我很难表达我的想法,我不确定我的尝试是否有意义。

为了示例的目的,我将提出一个简单的问题。目标是,给定一个唯一字符的排序列表,以字母顺序输出由这些字符组成的所有可能的单词。例如,sol("op".toList, 3) 应该 return ooo,oop,opo,opp,poo,pop,ppo,ppp.

我的递归解决方案如下:

def sol(chars: List[Char], n: Int) = {
    def recSol(n: Int): List[List[Char]] = (chars, n) match {
        case (_  , 0) => List(Nil)
        case (Nil, _) => Nil
        case (_  , _) =>
            val tail = recSol(n - 1)
            chars.map(ch => tail.map(ch :: _)).fold(Nil)(_ ::: _)
    }
    recSol(n).map(_.mkString).mkString(",")
}

我确实尝试通过添加一个函数作为参数来重写它,但我没能做出我确信是尾递归的东西。我不想在问题中包括我的尝试,因为我为他们的天真感到羞耻,所以请原谅我。

因此问题基本上是:上面的函数如何用 CPS 编写?

试试看:

import scala.annotation.tailrec
def sol(chars: List[Char], n: Int) = {
  @tailrec
  def recSol(n: Int)(cont: (List[List[Char]]) => List[List[Char]]): List[List[Char]] = (chars, n) match {
    case (_  , 0) => cont(List(Nil))
    case (Nil, _) => cont(Nil)
    case (_  , _) =>
      recSol(n-1){ tail =>
        cont(chars.map(ch => tail.map(ch :: _)).fold(Nil)(_ ::: _))
      }
  }
  recSol(n)(identity).map(_.mkString).mkString(",")
}

执行 CPS 转换的首要任务是确定延续的表示形式。我们可以将延续视为具有 "hole" 的暂停计算。当用一个值填充该孔时,可以计算剩余的计算。因此,函数是表示延续的自然选择,至少对于玩具示例而言:

type Cont[Hole,Result] = Hole => Result

这里Hole表示需要填补的洞的类型,Result表示计算最终计算出的值的类型。

既然我们有了表示延续的方法,我们就可以担心 CPS 转换本身了。基本上,这涉及以下步骤:

  • 转换以递归方式应用于表达式,在 "trivial" 表达式/函数调用处停止。在此上下文中,"trivial" 包括 Scala 定义的函数(因为它们未经过 CPS 转换,因此没有延续参数)。
  • 我们需要为每个函数添加一个Cont[Return,Result]类型的参数,其中Return是未转换函数的return类型,Result是未转换函数的类型整体计算的最终结果。这个新参数表示当前的延续。转换函数的 return 类型也更改为 Result.
  • 每个函数调用都需要进行转换以适应新的延续参数。 调用之后的所有内容都需要放入延续函数中,然后将其添加到参数列表中。

例如一个函数:

def f(x : Int) : Int = x + 1

变成:

def fCps[Result](x : Int)(k : Cont[Int,Result]) : Result = k(x + 1)

def g(x : Int) : Int = 2 * f(x)

变成:

def gCps[Result](x : Int)(k : Cont[Int,Result]) : Result = {
  fCps(x)(y => k(2 * y))
}

现在 gCps(5) returns(通过柯里化)表示部分计算的函数。我们可以从这个部分计算中提取值,并通过提供一个延续函数来使用它。例如,我们可以使用恒等函数提取值不变:

gCps(5)(x => x)
// 12

或者,我们可以使用 println 来打印它:

gCps(5)(println)
// prints 12

将此应用于您的代码,我们获得:

def solCps[Result](chars : List[Char], n : Int)(k : Cont[String, Result]) : Result = {
  @scala.annotation.tailrec
  def recSol[Result](n : Int)(k : Cont[List[List[Char]], Result]) : Result = (chars, n) match {
    case (_  , 0) => k(List(Nil))
    case (Nil, _) => k(Nil)
    case (_  , _) =>
      recSol(n - 1)(tail =>
                      k(chars.map(ch => tail.map(ch :: _)).fold(Nil)(_ ::: _)))
  }

  recSol(n)(result =>
              k(result.map(_.mkString).mkString(",")))
}

如您所见,虽然 recSol 现在是尾递归的,但它伴随着在每次迭代中构建更复杂的延续的成本。所以我们真正做的就是将 JVM 控制堆栈上的 space 换成堆上的 space —— CPS 转换不会神奇地降低算法的 space 复杂性。

此外,recSol 只是尾递归,因为对 recSol 的递归调用恰好是 recSol 执行的第一个(非平凡的)表达式。不过,一般来说,递归调用将发生在延续内。在有一个递归调用的情况下,我们可以通过将对递归函数的 调用转换为 CPS 来解决这个问题。即便如此,一般来说,我们仍然只是将堆栈 space 换成堆 space.