免费 ~> Trampoline:递归程序因 OutOfMemoryError 而崩溃

Free ~> Trampoline : recursive program crashes with OutOfMemoryError

假设我正在尝试实现一种非常简单的领域特定语言,只有一个操作:

printLine(line)

然后我想写一个程序,它接受一个整数 n 作为输入,如果 n 可以被 10k 整除则打印一些东西,然后用 n + 1 调用它自己,直到 n 达到某个最大值 N.

省略所有由 for-comprehensions 引起的语法噪音,我想要的是:

@annotation.tailrec def p(n: Int): Unit = {
  if (n % 10000 == 0) printLine("line")
  if (n > N) () else p(n + 1)
}

本质上,这将是一种 "fizzbuzz"。

以下是使用 Scalaz 7.3.0-M7 中的 Free monad 实现此功能的一些尝试:

import scalaz._

object Demo1 {

  // define operations of a little domain specific language
  sealed trait Lang[X]
  case class PrintLine(line: String) extends Lang[Unit]

  // define the domain specific language as the free monad of operations
  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}

  // lift operations into the free monad
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  // write a program that is just a loop that prints current index 
  // after every few iteration steps
  val mod =  100000
  val N =   1000000

  // straightforward syntax: deadly slow, exits with OutOfMemoryError
  def p0(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- (if (i > N) ret else p0(i + 1))
  } yield ()

  // Same as above, but written out without `for`
  def p1(i: Int): Prog[Unit] = 
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
    }

  // Same as above, with `map` attached to recursive call
  def p2(i: Int): Prog[Unit] = 
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p2(i + 1).map{ ignore2 => () })
    }

  // Same as above, but without the `map`; performs ok.
  def p3(i: Int): Prog[Unit] = {
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{ 
      ignore1 =>
      if (i > N) ret else p3(i + 1)
    }
  }

  // Variation of the above; Ok.
  def p4(i: Int): Prog[Unit] = (for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
  } yield ()).flatMap{ ignored2 => 
    if (i > N) ret else p4(i + 1) 
  }

  // try to use the variable returned by the last generator after yield,
  // hope that the final `map` is optimized away (it's not optimized away...)
  def p5(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    stopHere <- (if (i > N) ret else p5(i + 1))
  } yield stopHere

  // define an interpreter that translates the programs into Trampoline
  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]  
  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case PrintLine(l) => Trampoline.delay(println(l))
    }
  }

  // try it out
  def main(args: Array[String]): Unit = {
    println("\n p0")
    p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p1")
    p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p2")
    p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p3")
    p3(0).foldMap(interpreter).run // ok 
    println("\n p4")
    p4(0).foldMap(interpreter).run // ok
    println("\n p5")
    p5(0).foldMap(interpreter).run // OutOfMemory
  }
}

不幸的是,直截了当的翻译 (p0) 似乎 运行 具有某种 O(N^2) 的开销,并因 OutOfMemoryError 而崩溃。问题似乎是 for-comprehension 在对 p0 的递归调用之后附加了一个 map{x => ()},这会强制 Free monad 填充整个内存并提醒 "finish 'p0' and then do nothing"。 如果我手动 "unroll" 理解 for,并显式写出最后的 flatMap(如 p3p4),那么问题就消失了,并且一切 运行 都很顺利。然而,这是一个非常脆弱的解决方法:如果我们简单地向它附加一个 map(id) ,程序的行为会发生巨大变化,并且这个 map(id) 在代码中甚至不可见,因为它是生成的自动被for-理解。

在这个较旧的post这里:https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/ 人们反复建议将递归调用包装到 suspend 中。这是 Applicative 实例和 suspend:

的尝试
import scalaz._

// Essentially same as in `Demo1`, but this time with 
// an `Applicative` and an explicit `Suspend` in the 
// `for`-comprehension
object Demo2 {

  sealed trait Lang[H]

  case class Const[H](h: H) extends Lang[H]
  case class PrintLine[H](line: String) extends Lang[H]

  implicit object Lang extends Applicative[Lang] {
    def point[A](a: => A): Lang[A] = Const(a)
    def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
      case Const(x) => {
        f match {
          case Const(ab) => Const(ab(x))
          case _ => throw new Error
        }
      }
      case PrintLine(l) => PrintLine(l)
    }
  }

  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  val mod = 100000
  val N = 2000000

  // try to suspend the entire second generator
  def p7(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- Free.suspend(if (i > N) ret else p7(i + 1))
  } yield ()

  // try to suspend the recursive call
  def p8(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- if (i > N) ret else Free.suspend(p8(i + 1))
  } yield ()

  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]

  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case Const(x) => Trampoline.done(x)
      case PrintLine(l) => 
        (Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
    }
  }

  def main(args: Array[String]): Unit = {
    p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    p8(0).foldMap(interpreter).run // same...
  }
}

插入 suspend 并没有真正帮助:它仍然很慢,并且在 OutOfMemoryError 时崩溃。

我应该以某种不同的方式使用 suspend 吗?

也许有一些纯粹的语法补救措施可以使用 for-comprehensions 而不会在最后生成 map

如果有人能指出我在这里做错了什么,以及如何修复它,我将不胜感激。

Scala 编译器添加的多余 map 将递归从尾部位置移动到 非尾部 位置。免费的 monad 仍然使这个堆栈安全,但是 space 复杂度变为 O(N) 而不是 O(1)。 (具体还是不O(N2).)

是否可以让 scalac 优化 map 是一个单独的问题(我不知道答案)。

我将尝试说明在解释 p1p3 时发生了什么。 (我将忽略对 Trampoline 的翻译,这是多余的(见下文)。)

p3(即没有额外的map

让我使用以下 shorthand:

def cont(i: Int): Unit => Prg[Unit] =
  ignore1 => if (i > N) ret else p3(i + 1)

p3(0)解读如下

p3(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p3(1)
ret flatMap cont(1)
cont(1)
p3(2)
ret flatMap cont(2)
cont(2)

等等...您会看到任何时候所需的内存量都不会超过某个常量上限。

p1(即有额外的 map

我将使用以下 shorthands:

def cont(i: Int): Unit => Prg[Unit] =
  ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }

def cpu: Unit => Prg[Unit] = // constant pure unit
  ignore => Free.pure(())

现在p1(0)解释如下:

p1(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p1(1) map { ignore2 => () }
// Free.map is implemented via flatMap
p1(1) flatMap cpu
(ret flatMap cont(1)) flatMap cpu
cont(1) flatMap cpu
(p1(2) map { ignore2 => () }) flatMap cpu
(p1(2) flatMap cpu) flatMap cpu
((ret flatMap cont(2)) flatMap cpu) flatMap cpu
(cont(2) flatMap cpu) flatMap cpu
((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
(((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu

等等...您会看到内存消耗与 N 线性相关。我们只是将评估从堆栈移到了堆。

带走:为了保持Free内存友好,将递归保持在"tail position",也就是在[=的右边29=](或map)。

旁白: 不需要翻译成 Trampoline,因为 Free 已经被 trampolined。您可以直接解释为 Id 并使用 foldMapRec 进行堆栈安全解释:

val idInterpreter = new (Lang ~> Id) {
  def apply[A](cmd: Lang[A]): Id[A] = cmd match {
    case PrintLine(l) => println(l)
  }
}

p0(0).foldMapRec(idInterpreter)

这将重新获得一些内存(但不会使问题消失)。