Kotlin 协程进度计数器

Kotlin coroutines progress counter

我正在使用 async/await 发出数千个 HTTP 请求,并希望有一个进度指示器。我以一种天真的方式添加了一个,但注意到当所有请求都完成后,计数器值永远不会达到总数。所以我创建了一个简单的测试,果然,它没有按预期工作:

fun main(args: Array<String>) {
    var i = 0
    val range = (1..100000)
    range.map {
        launch {
            ++i
        }
    }
    println("$i ${range.count()}")
}

输出是这样的,第一个数字总是变化的:

98800 100000

我可能在 JVM/Kotlin 中遗漏了一些关于 concurrency/synchronization 的重要细节,但不知道从哪里开始。有什么建议吗?

更新:我最终按照 Marko 的建议使用了频道:

/**
 * Asynchronously fetches stats for all symbols and sends a total number of requests
 * to the `counter` channel each time a request completes. For example:
 *
 *     val counterActor = actor<Int>(UI) {
 *         var counter = 0
 *         for (total in channel) {
 *             progressLabel.text = "${++counter} / $total"
 *         }
 *     }
 */
suspend fun getAssetStatsWithProgress(counter: SendChannel<Int>): Map<String, AssetStats> {
    val symbolMap = getSymbols()?.let { it.map { it.symbol to it }.toMap() } ?: emptyMap()
    val total = symbolMap.size
    return symbolMap.map { async { getAssetStats(it.key) } }
        .mapNotNull { it.await().also { counter.send(total) } }
        .map { it.symbol to it }
        .toMap()
}

您正在丢失写入,因为 i++ 不是原子操作 - 值必须被读取、递增,然后写回 - 并且您有多个线程读取和写入 i同一时间。 (如果您不为 launch 提供上下文,它默认使用线程池。)

每当两个线程读取相同的值时,您的计数就会减 1,因为它们都会写入该值加一。

以某种方式同步,例如使用 AtomicInteger 解决了这个问题:

fun main(args: Array<String>) {
    val i = AtomicInteger(0)
    val range = (1..100000)
    range.map {
        launch {
            i.incrementAndGet()
        }
    }
    println("$i ${range.count()}") // 100000 100000
}

也不能保证这些后台线程会在您打印结果和程序结束时完成它们的工作 - 您可以通过在 launch 中添加一个非常小的延迟来轻松测试它,几毫秒。有了这个,最好将这一切包装在一个 runBlocking 调用中,这将使主线程保持活动状态,然后等待协程全部完成:

fun main(args: Array<String>) = runBlocking {
    val i = AtomicInteger(0)
    val range = (1..100000)
    val jobs: List<Job> = range.map {
        launch {
            i.incrementAndGet()
        }
    }
    jobs.forEach { it.join() }
    println("$i ${range.count()}") // 100000 100000
}

你读过Coroutines basics了吗?存在与您完全相同的问题:

val c = AtomicInteger()

for (i in 1..1_000_000)
    launch {
        c.addAndGet(i)
    }

println(c.get())

This example completes in less than a second for me, but it prints some arbitrary number, because some coroutines don't finish before main() prints the result.

因为 launch 没有阻塞,所以不能保证所有协程都会在 println 之前完成。您需要使用 async,存储 Deferred 个对象,并 await 让它们完成。

究竟是什么导致你的错误方法失败的解释是次要的:主要的是修复方法。

而不是 async-awaitlaunch,对于此通信模式,您应该有一个 actor,所有 HTTP 作业都将其状态发送到该地址。这将自动处理您所有的并发问题。

这是一些示例代码,取自您在评论中提供的 link,并根据您的用例进行了调整。 Actor 在 UI 上下文中运行并更新 GUI 本身,而不是某些第三方要求它提供计数器值并用它更新 GUI:

import kotlinx.coroutines.experimental.*
import kotlinx.coroutines.experimental.channels.*
import kotlin.system.*
import kotlin.coroutines.experimental.*

object IncCounter

fun counterActor() = actor<IncCounter>(UI) {
    var counter = 0
    for (msg in channel) {
        updateView(++counter)
    }
}

fun main(args: Array<String>) = runBlocking {
    val counter = counterActor()
    massiveRun(CommonPool) {
        counter.send(IncCounter)
    }
    counter.close()
    println("View state: $viewState")
}


// Everything below is mock code that supports the example
// code above:

val UI = newSingleThreadContext("UI")

fun updateView(newVal: Int) {
    viewState = newVal
}

var viewState = 0

suspend fun massiveRun(context: CoroutineContext, action: suspend () -> Unit) {
    val numCoroutines = 1000
    val repeatActionCount = 1000
    val time = measureTimeMillis {
        val jobs = List(numCoroutines) {
            launch(context) {
                repeat(repeatActionCount) { action() }
            }
        }
        jobs.forEach { it.join() }
    }
    println("Completed ${numCoroutines * repeatActionCount} actions in $time ms")
}

运行 它打印

Completed 1000000 actions in 2189 ms
View state: 1000000