免费 ~> Trampoline:递归程序因 OutOfMemoryError 而崩溃
Free ~> Trampoline : recursive program crashes with OutOfMemoryError
假设我正在尝试实现一种非常简单的领域特定语言,只有一个操作:
printLine(line)
然后我想写一个程序,它接受一个整数 n
作为输入,如果 n
可以被 10k 整除则打印一些东西,然后用 n + 1
调用它自己,直到 n
达到某个最大值 N
.
省略所有由 for-comprehensions 引起的语法噪音,我想要的是:
@annotation.tailrec def p(n: Int): Unit = {
if (n % 10000 == 0) printLine("line")
if (n > N) () else p(n + 1)
}
本质上,这将是一种 "fizzbuzz"。
以下是使用 Scalaz 7.3.0-M7 中的 Free monad 实现此功能的一些尝试:
import scalaz._
object Demo1 {
// define operations of a little domain specific language
sealed trait Lang[X]
case class PrintLine(line: String) extends Lang[Unit]
// define the domain specific language as the free monad of operations
type Prog[X] = Free[Lang, X]
import Free.{liftF, pure}
// lift operations into the free monad
def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
def ret: Prog[Unit] = Free.pure(())
// write a program that is just a loop that prints current index
// after every few iteration steps
val mod = 100000
val N = 1000000
// straightforward syntax: deadly slow, exits with OutOfMemoryError
def p0(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- (if (i > N) ret else p0(i + 1))
} yield ()
// Same as above, but written out without `for`
def p1(i: Int): Prog[Unit] =
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
(if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
}
// Same as above, with `map` attached to recursive call
def p2(i: Int): Prog[Unit] =
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
(if (i > N) ret else p2(i + 1).map{ ignore2 => () })
}
// Same as above, but without the `map`; performs ok.
def p3(i: Int): Prog[Unit] = {
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
if (i > N) ret else p3(i + 1)
}
}
// Variation of the above; Ok.
def p4(i: Int): Prog[Unit] = (for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
} yield ()).flatMap{ ignored2 =>
if (i > N) ret else p4(i + 1)
}
// try to use the variable returned by the last generator after yield,
// hope that the final `map` is optimized away (it's not optimized away...)
def p5(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
stopHere <- (if (i > N) ret else p5(i + 1))
} yield stopHere
// define an interpreter that translates the programs into Trampoline
import scalaz.Trampoline
type Exec[X] = Free.Trampoline[X]
val interpreter = new (Lang ~> Exec) {
def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
case PrintLine(l) => Trampoline.delay(println(l))
}
}
// try it out
def main(args: Array[String]): Unit = {
println("\n p0")
p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p1")
p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p2")
p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p3")
p3(0).foldMap(interpreter).run // ok
println("\n p4")
p4(0).foldMap(interpreter).run // ok
println("\n p5")
p5(0).foldMap(interpreter).run // OutOfMemory
}
}
不幸的是,直截了当的翻译 (p0
) 似乎 运行 具有某种 O(N^2) 的开销,并因 OutOfMemoryError 而崩溃。问题似乎是 for
-comprehension 在对 p0
的递归调用之后附加了一个 map{x => ()}
,这会强制 Free
monad 填充整个内存并提醒 "finish 'p0' and then do nothing"。
如果我手动 "unroll" 理解 for
,并显式写出最后的 flatMap
(如 p3
和 p4
),那么问题就消失了,并且一切 运行 都很顺利。然而,这是一个非常脆弱的解决方法:如果我们简单地向它附加一个 map(id)
,程序的行为会发生巨大变化,并且这个 map(id)
在代码中甚至不可见,因为它是生成的自动被for
-理解。
在这个较旧的post这里:https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/
人们反复建议将递归调用包装到 suspend
中。这是 Applicative
实例和 suspend
:
的尝试
import scalaz._
// Essentially same as in `Demo1`, but this time with
// an `Applicative` and an explicit `Suspend` in the
// `for`-comprehension
object Demo2 {
sealed trait Lang[H]
case class Const[H](h: H) extends Lang[H]
case class PrintLine[H](line: String) extends Lang[H]
implicit object Lang extends Applicative[Lang] {
def point[A](a: => A): Lang[A] = Const(a)
def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
case Const(x) => {
f match {
case Const(ab) => Const(ab(x))
case _ => throw new Error
}
}
case PrintLine(l) => PrintLine(l)
}
}
type Prog[X] = Free[Lang, X]
import Free.{liftF, pure}
def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
def ret: Prog[Unit] = Free.pure(())
val mod = 100000
val N = 2000000
// try to suspend the entire second generator
def p7(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- Free.suspend(if (i > N) ret else p7(i + 1))
} yield ()
// try to suspend the recursive call
def p8(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- if (i > N) ret else Free.suspend(p8(i + 1))
} yield ()
import scalaz.Trampoline
type Exec[X] = Free.Trampoline[X]
val interpreter = new (Lang ~> Exec) {
def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
case Const(x) => Trampoline.done(x)
case PrintLine(l) =>
(Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
}
}
def main(args: Array[String]): Unit = {
p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
p8(0).foldMap(interpreter).run // same...
}
}
插入 suspend
并没有真正帮助:它仍然很慢,并且在 OutOfMemoryError
时崩溃。
我应该以某种不同的方式使用 suspend
吗?
也许有一些纯粹的语法补救措施可以使用 for-comprehensions 而不会在最后生成 map
?
如果有人能指出我在这里做错了什么,以及如何修复它,我将不胜感激。
Scala 编译器添加的多余 map
将递归从尾部位置移动到 非尾部 位置。免费的 monad 仍然使这个堆栈安全,但是 space 复杂度变为 O(N) 而不是 O(1)。 (具体还是不O(N2).)
是否可以让 scalac
优化 map
是一个单独的问题(我不知道答案)。
我将尝试说明在解释 p1
与 p3
时发生了什么。 (我将忽略对 Trampoline
的翻译,这是多余的(见下文)。)
p3
(即没有额外的map
)
让我使用以下 shorthand:
def cont(i: Int): Unit => Prg[Unit] =
ignore1 => if (i > N) ret else p3(i + 1)
现p3(0)
解读如下
p3(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p3(1)
ret flatMap cont(1)
cont(1)
p3(2)
ret flatMap cont(2)
cont(2)
等等...您会看到任何时候所需的内存量都不会超过某个常量上限。
p1
(即有额外的 map
)
我将使用以下 shorthands:
def cont(i: Int): Unit => Prg[Unit] =
ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
def cpu: Unit => Prg[Unit] = // constant pure unit
ignore => Free.pure(())
现在p1(0)
解释如下:
p1(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p1(1) map { ignore2 => () }
// Free.map is implemented via flatMap
p1(1) flatMap cpu
(ret flatMap cont(1)) flatMap cpu
cont(1) flatMap cpu
(p1(2) map { ignore2 => () }) flatMap cpu
(p1(2) flatMap cpu) flatMap cpu
((ret flatMap cont(2)) flatMap cpu) flatMap cpu
(cont(2) flatMap cpu) flatMap cpu
((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
(((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu
等等...您会看到内存消耗与 N
线性相关。我们只是将评估从堆栈移到了堆。
带走:为了保持Free
内存友好,将递归保持在"tail position",也就是在[=的右边29=](或map
)。
旁白: 不需要翻译成 Trampoline
,因为 Free
已经被 trampolined。您可以直接解释为 Id
并使用 foldMapRec
进行堆栈安全解释:
val idInterpreter = new (Lang ~> Id) {
def apply[A](cmd: Lang[A]): Id[A] = cmd match {
case PrintLine(l) => println(l)
}
}
p0(0).foldMapRec(idInterpreter)
这将重新获得一些内存(但不会使问题消失)。
假设我正在尝试实现一种非常简单的领域特定语言,只有一个操作:
printLine(line)
然后我想写一个程序,它接受一个整数 n
作为输入,如果 n
可以被 10k 整除则打印一些东西,然后用 n + 1
调用它自己,直到 n
达到某个最大值 N
.
省略所有由 for-comprehensions 引起的语法噪音,我想要的是:
@annotation.tailrec def p(n: Int): Unit = {
if (n % 10000 == 0) printLine("line")
if (n > N) () else p(n + 1)
}
本质上,这将是一种 "fizzbuzz"。
以下是使用 Scalaz 7.3.0-M7 中的 Free monad 实现此功能的一些尝试:
import scalaz._
object Demo1 {
// define operations of a little domain specific language
sealed trait Lang[X]
case class PrintLine(line: String) extends Lang[Unit]
// define the domain specific language as the free monad of operations
type Prog[X] = Free[Lang, X]
import Free.{liftF, pure}
// lift operations into the free monad
def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
def ret: Prog[Unit] = Free.pure(())
// write a program that is just a loop that prints current index
// after every few iteration steps
val mod = 100000
val N = 1000000
// straightforward syntax: deadly slow, exits with OutOfMemoryError
def p0(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- (if (i > N) ret else p0(i + 1))
} yield ()
// Same as above, but written out without `for`
def p1(i: Int): Prog[Unit] =
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
(if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
}
// Same as above, with `map` attached to recursive call
def p2(i: Int): Prog[Unit] =
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
(if (i > N) ret else p2(i + 1).map{ ignore2 => () })
}
// Same as above, but without the `map`; performs ok.
def p3(i: Int): Prog[Unit] = {
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
if (i > N) ret else p3(i + 1)
}
}
// Variation of the above; Ok.
def p4(i: Int): Prog[Unit] = (for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
} yield ()).flatMap{ ignored2 =>
if (i > N) ret else p4(i + 1)
}
// try to use the variable returned by the last generator after yield,
// hope that the final `map` is optimized away (it's not optimized away...)
def p5(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
stopHere <- (if (i > N) ret else p5(i + 1))
} yield stopHere
// define an interpreter that translates the programs into Trampoline
import scalaz.Trampoline
type Exec[X] = Free.Trampoline[X]
val interpreter = new (Lang ~> Exec) {
def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
case PrintLine(l) => Trampoline.delay(println(l))
}
}
// try it out
def main(args: Array[String]): Unit = {
println("\n p0")
p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p1")
p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p2")
p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p3")
p3(0).foldMap(interpreter).run // ok
println("\n p4")
p4(0).foldMap(interpreter).run // ok
println("\n p5")
p5(0).foldMap(interpreter).run // OutOfMemory
}
}
不幸的是,直截了当的翻译 (p0
) 似乎 运行 具有某种 O(N^2) 的开销,并因 OutOfMemoryError 而崩溃。问题似乎是 for
-comprehension 在对 p0
的递归调用之后附加了一个 map{x => ()}
,这会强制 Free
monad 填充整个内存并提醒 "finish 'p0' and then do nothing"。
如果我手动 "unroll" 理解 for
,并显式写出最后的 flatMap
(如 p3
和 p4
),那么问题就消失了,并且一切 运行 都很顺利。然而,这是一个非常脆弱的解决方法:如果我们简单地向它附加一个 map(id)
,程序的行为会发生巨大变化,并且这个 map(id)
在代码中甚至不可见,因为它是生成的自动被for
-理解。
在这个较旧的post这里:https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/
人们反复建议将递归调用包装到 suspend
中。这是 Applicative
实例和 suspend
:
import scalaz._
// Essentially same as in `Demo1`, but this time with
// an `Applicative` and an explicit `Suspend` in the
// `for`-comprehension
object Demo2 {
sealed trait Lang[H]
case class Const[H](h: H) extends Lang[H]
case class PrintLine[H](line: String) extends Lang[H]
implicit object Lang extends Applicative[Lang] {
def point[A](a: => A): Lang[A] = Const(a)
def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
case Const(x) => {
f match {
case Const(ab) => Const(ab(x))
case _ => throw new Error
}
}
case PrintLine(l) => PrintLine(l)
}
}
type Prog[X] = Free[Lang, X]
import Free.{liftF, pure}
def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
def ret: Prog[Unit] = Free.pure(())
val mod = 100000
val N = 2000000
// try to suspend the entire second generator
def p7(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- Free.suspend(if (i > N) ret else p7(i + 1))
} yield ()
// try to suspend the recursive call
def p8(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- if (i > N) ret else Free.suspend(p8(i + 1))
} yield ()
import scalaz.Trampoline
type Exec[X] = Free.Trampoline[X]
val interpreter = new (Lang ~> Exec) {
def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
case Const(x) => Trampoline.done(x)
case PrintLine(l) =>
(Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
}
}
def main(args: Array[String]): Unit = {
p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
p8(0).foldMap(interpreter).run // same...
}
}
插入 suspend
并没有真正帮助:它仍然很慢,并且在 OutOfMemoryError
时崩溃。
我应该以某种不同的方式使用 suspend
吗?
也许有一些纯粹的语法补救措施可以使用 for-comprehensions 而不会在最后生成 map
?
如果有人能指出我在这里做错了什么,以及如何修复它,我将不胜感激。
Scala 编译器添加的多余 map
将递归从尾部位置移动到 非尾部 位置。免费的 monad 仍然使这个堆栈安全,但是 space 复杂度变为 O(N) 而不是 O(1)。 (具体还是不O(N2).)
是否可以让 scalac
优化 map
是一个单独的问题(我不知道答案)。
我将尝试说明在解释 p1
与 p3
时发生了什么。 (我将忽略对 Trampoline
的翻译,这是多余的(见下文)。)
p3
(即没有额外的map
)
让我使用以下 shorthand:
def cont(i: Int): Unit => Prg[Unit] =
ignore1 => if (i > N) ret else p3(i + 1)
现p3(0)
解读如下
p3(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p3(1)
ret flatMap cont(1)
cont(1)
p3(2)
ret flatMap cont(2)
cont(2)
等等...您会看到任何时候所需的内存量都不会超过某个常量上限。
p1
(即有额外的 map
)
我将使用以下 shorthands:
def cont(i: Int): Unit => Prg[Unit] =
ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
def cpu: Unit => Prg[Unit] = // constant pure unit
ignore => Free.pure(())
现在p1(0)
解释如下:
p1(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p1(1) map { ignore2 => () }
// Free.map is implemented via flatMap
p1(1) flatMap cpu
(ret flatMap cont(1)) flatMap cpu
cont(1) flatMap cpu
(p1(2) map { ignore2 => () }) flatMap cpu
(p1(2) flatMap cpu) flatMap cpu
((ret flatMap cont(2)) flatMap cpu) flatMap cpu
(cont(2) flatMap cpu) flatMap cpu
((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
(((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu
等等...您会看到内存消耗与 N
线性相关。我们只是将评估从堆栈移到了堆。
带走:为了保持Free
内存友好,将递归保持在"tail position",也就是在[=的右边29=](或map
)。
旁白: 不需要翻译成 Trampoline
,因为 Free
已经被 trampolined。您可以直接解释为 Id
并使用 foldMapRec
进行堆栈安全解释:
val idInterpreter = new (Lang ~> Id) {
def apply[A](cmd: Lang[A]): Id[A] = cmd match {
case PrintLine(l) => println(l)
}
}
p0(0).foldMapRec(idInterpreter)
这将重新获得一些内存(但不会使问题消失)。