Kotlin suspend function recursive call

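I recently noticed that recursive calls of a suspend function take considerably more time than calls of the same function without the suspend modifier. Consider the code snippet below (a basic Fibonacci calculation):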

suspend fun asyncFibonacci(n: Int): Long = when {
    n <= -2 -> asyncFibonacci(n + 2) - asyncFibonacci(n + 1)
    n == -1 -> 1
    n == 0 -> 0
    n == 1 -> 1
    n >= 2 -> asyncFibonacci(n - 1) + asyncFibonacci(n - 2)
    else -> throw IllegalArgumentException()
}

If I call this function and measure its execution time with the following code:

fun main(args: Array<String>) {
    val totalElapsedTime = measureTimeMillis {
        val nFibonacci = 40

        val deferredFirstResult: Deferred<Long> = async {
            asyncProfile("fibonacci") { asyncFibonacci(nFibonacci) } as Long
        }
        val deferredSecondResult: Deferred<Long> = async {
            asyncProfile("fibonacci") { asyncFibonacci(nFibonacci) } as Long
        }

        val firstResult: Long = runBlocking { deferredFirstResult.await() }
        val secondResult: Long = runBlocking { deferredSecondResult.await() }
        val superSum = secondResult + firstResult
        println("${thread()} - Sum of two $nFibonacci'th fibonacci numbers: $superSum")
    }
    println("${thread()} - Total elapsed time: $totalElapsedTime millis")
}

I observe the following results:

commonPool-worker-2:fibonacci - Start calculation...
commonPool-worker-1:fibonacci - Start calculation...
commonPool-worker-2:fibonacci - Finish calculation...
commonPool-worker-2:fibonacci - Elapsed time: 7704 millis
commonPool-worker-1:fibonacci - Finish calculation...
commonPool-worker-1:fibonacci - Elapsed time: 7741 millis
main - Sum of two 40'th fibonacci numbers: 204668310
main - Total elapsed time: 7816 millis

But if I remove the suspend modifier from the asyncFibonacci function, I get this result:

commonPool-worker-2:fibonacci - Start calculation...
commonPool-worker-1:fibonacci - Start calculation...
commonPool-worker-1:fibonacci - Finish calculation...
commonPool-worker-1:fibonacci - Elapsed time: 1179 millis
commonPool-worker-2:fibonacci - Finish calculation...
commonPool-worker-2:fibonacci - Elapsed time: 1201 millis
main - Sum of two 40'th fibonacci numbers: 204668310
main - Total elapsed time: 1250 millis

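I know it would be better to rewrite such a function with tailrec, which would make it roughly 100 times faster. But still, what does the suspend keyword do that slows execution down from 1 second to 8 seconds?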
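For reference, a tailrec rewrite along the lines mentioned above might look like the following sketch. The function name and the accumulator parameters a and b are illustrative and not part of the original code, and the sketch only handles non-negative n.

```kotlin
// Tail-recursive Fibonacci: `a` holds F(k) and `b` holds F(k + 1) for the
// current step, so each call does constant work and the compiler can turn
// the recursion into a loop. Names are illustrative assumptions.
tailrec fun tailrecFibonacci(n: Int, a: Long = 0L, b: Long = 1L): Long {
    require(n >= 0) { "Negative input not handled in this sketch" }
    return if (n == 0) a else tailrecFibonacci(n - 1, b, a + b)
}

fun main() {
    println(tailrecFibonacci(40)) // prints 102334155
}
```

The speedup comes mostly from replacing the exponential number of calls with a linear one; tailrec itself only removes the stack growth.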

Is it a totally stupid idea to mark recursive functions with suspend?

The problem lies in the Java bytecode generated for the suspend function. The non-suspend function generates just the bytecode we would expect:

public static final long asyncFibonacci(int n) {
  long var10000;
  if (n <= -2) {
     var10000 = asyncFibonacci(n + 2) - asyncFibonacci(n + 1);
  } else if (n == -1) {
     var10000 = 1L;
  } else if (n == 0) {
     var10000 = 0L;
  } else if (n == 1) {
     var10000 = 1L;
  } else {
     if (n < 2) {
        throw (Throwable)(new IllegalArgumentException());
     }

     var10000 = asyncFibonacci(n - 1) + asyncFibonacci(n - 2);
  }

  return var10000;
}

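When you add the suspend keyword, the decompiled Java source is 165 lines long, which is much larger. You can view the bytecode and the decompiled Java code in IntelliJ via Tools -> Kotlin -> Show Kotlin bytecode (and then click Decompile at the top of the panel). It is hard to tell exactly what the Kotlin compiler does in the function, but it appears to do a lot of coroutine state checking, which makes sense given that a coroutine can be suspended at any call to another suspend function.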

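So in conclusion, I would say that every suspend function call is a lot heavier than a non-suspend call. This applies to all functions, not just recursive ones, but recursive functions probably show the worst of it.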
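One way to act on this conclusion, sketched below, is to keep the recursion in a plain function and expose only a thin suspend entry point, so the suspend calling convention is paid once rather than on every recursive call. The names plainFibonacci, fibonacciEntryPoint and fibonacciFromPlainCode are hypothetical, and the helper relies on the assumption that nothing in this coroutine ever actually suspends.

```kotlin
import kotlin.coroutines.Continuation
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.startCoroutine

// Plain recursion: every recursive call is an ordinary JVM call.
fun plainFibonacci(n: Int): Long = when {
    n >= 2 -> plainFibonacci(n - 1) + plainFibonacci(n - 2)
    n == 1 -> 1L
    n == 0 -> 0L
    else -> throw IllegalArgumentException("Negative input not allowed")
}

// Suspendable entry point (hypothetical name): only this single outer call
// goes through the suspend calling convention.
suspend fun fibonacciEntryPoint(n: Int): Long = plainFibonacci(n)

// Demo helper (hypothetical): starts a coroutine using only the standard
// library. Because nothing here suspends, the completion callback has
// already run by the time startCoroutine returns.
fun fibonacciFromPlainCode(n: Int): Long {
    var result: Long? = null
    val block: suspend () -> Long = { fibonacciEntryPoint(n) }
    block.startCoroutine(Continuation(EmptyCoroutineContext) { result = it.getOrThrow() })
    return checkNotNull(result) { "coroutine suspended unexpectedly" }
}

fun main() {
    println(fibonacciFromPlainCode(30)) // prints 832040
}
```

This only helps when the recursive body itself has no need to suspend; if it does, the overhead described above is the price of suspendability.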

Is it a totally stupid idea to mark recursive functions with suspend?

Unless you have a very good reason to do so: yes.

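As an introductory comment, your test setup is more complex than it needs to be. This much simpler code achieves the same thing in terms of stressing suspend fun recursion: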

fun main(args: Array<String>) {
    launch(Unconfined) {
        val nFibonacci = 37
        var sum = 0L
        (1..1_000).forEach {
            val took = measureTimeMillis {
                sum += suspendFibonacci(nFibonacci)
            }
            println("Sum is $sum, took $took ms")
        }
    }
}

suspend fun suspendFibonacci(n: Int): Long {
    return when {
        n >= 2 -> suspendFibonacci(n - 1) + suspendFibonacci(n - 2)
        n == 0 -> 0
        n == 1 -> 1
        else -> throw IllegalArgumentException()
    }
}

I tried to reproduce its performance by writing a plain function that approximates the kind of work a suspend function must do to achieve suspendability:

val COROUTINE_SUSPENDED = Any()

fun fakeSuspendFibonacci(n: Int, inCont: Continuation<Unit>): Any? {
    val cont = if (inCont is MyCont && inCont.label and Integer.MIN_VALUE != 0) {
        inCont.label -= Integer.MIN_VALUE
        inCont
    } else MyCont(inCont)
    val suspended = COROUTINE_SUSPENDED
    loop@ while (true) {
        when (cont.label) {
            0 -> {
                when {
                    n >= 2 -> {
                        cont.n = n
                        cont.label = 1
                        val f1 = fakeSuspendFibonacci(n - 1, cont)!!
                        if (f1 === suspended) {
                            return f1
                        }
                        cont.data = f1
                        continue@loop
                    }
                    n == 1 || n == 0 -> return n.toLong()
                    else -> throw IllegalArgumentException("Negative input not allowed")
                }
            }
            1 -> {
                cont.label = 2
                cont.f1 = cont.data as Long
                val f2 = fakeSuspendFibonacci(cont.n - 2, cont)!!
                if (f2 === suspended) {
                    return f2
                }
                cont.data = f2
                continue@loop
            }
            2 -> {
                val f2 = cont.data as Long
                return cont.f1 + f2
            }
            else -> throw AssertionError("Invalid continuation label ${cont.label}")
        }
    }
}

class MyCont(val completion: Continuation<Unit>) : Continuation<Unit> {
    var label = 0
    var data: Any? = null
    var n: Int = 0
    var f1: Long = 0

    override val context: CoroutineContext get() = TODO("not implemented")
    override fun resumeWithException(exception: Throwable) = TODO("not implemented")
    override fun resume(value: Unit) = TODO("not implemented")
}

You have to call this with

sum += fakeSuspendFibonacci(nFibonacci, InitialCont()) as Long

where InitialCont is

class InitialCont : Continuation<Unit> {
    override val context: CoroutineContext get() = TODO("not implemented")
    override fun resumeWithException(exception: Throwable) = TODO("not implemented")
    override fun resume(value: Unit) = TODO("not implemented")
}

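Basically, to compile a suspend fun, the compiler has to turn its body into a state machine. Each invocation must also create an object to hold the machine's state. When you resume, the state object tells you which state handler to jump to. The above still isn't the whole story; the real code is even more complex.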
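The `=== suspended` checks in the hand-written state machine mirror the real calling convention: a suspend function either returns its result directly or returns the COROUTINE_SUSPENDED marker and delivers the result later through the continuation. The sketch below makes that visible using only the standard library's kotlin.coroutines.intrinsics package; startRaw, immediateValue, parkedValue and parkedContinuation are hypothetical names chosen for this illustration.

```kotlin
import kotlin.coroutines.Continuation
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
import kotlin.coroutines.intrinsics.startCoroutineUninterceptedOrReturn
import kotlin.coroutines.resume
import kotlin.coroutines.suspendCoroutine

// Holds the continuation of the parked coroutine so it can be resumed later.
var parkedContinuation: Continuation<Long>? = null

// Never suspends: the value comes straight back to the caller.
suspend fun immediateValue(): Long = 42L

// Really suspends: stores its continuation and returns without resuming it.
suspend fun parkedValue(): Long = suspendCoroutine { cont -> parkedContinuation = cont }

// Hypothetical helper: starts `block` and exposes the raw return value, which
// is either the result itself or the COROUTINE_SUSPENDED marker. `onResumed`
// is only invoked if the coroutine suspended and was resumed afterwards.
fun startRaw(block: suspend () -> Long, onResumed: (Long) -> Unit): Any? =
    block.startCoroutineUninterceptedOrReturn(
        Continuation(EmptyCoroutineContext) { onResumed(it.getOrThrow()) }
    )

fun main() {
    val direct = startRaw({ immediateValue() }) { error("not expected: nothing suspended") }
    println(direct == 42L) // prints true

    var delivered: Long? = null
    val marker = startRaw({ parkedValue() }) { delivered = it }
    println(marker === COROUTINE_SUSPENDED) // prints true

    parkedContinuation?.resume(7L) // resumes the state machine where it stopped
    println(delivered) // prints 7
}
```

Every suspend call site has to be prepared for both outcomes, which is part of why the generated code is so much larger than the plain version.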

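In interpreted mode (java -Xint), I get almost the same performance as the actual suspend fun, and with JIT enabled it is less than twice as fast as the real one. By comparison, the "direct" function implementation is about 10 times faster. This means the code shown accounts for a good part of the overhead of suspendability.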