使用 Kotlin 协程构建挂起资源池

Building a suspending resource pool using Kotlin Coroutines

我正在研究如何构建一个在资源可用之前暂停的资源池。 这个池可以是任何东西,从允许 API 调用的插槽到图像处理器的套接字连接,或者在这个示例代码中,可以是可用服务计数器的整数:

val pool = Pool(1, 2, 3)


repeat(100) {
    launch {
        pool.borrow {
            println(it)
            delay(1000)
        }
    }
}

借用功能从资源池中获取一项,并在完成后returns它。

到目前为止我所做的是:

class Pool(vararg resource: Int) {
    private val mutex = Mutex()
    private val list = mutableListOf(*resource.toTypedArray())
    private var available: CompletableDeferred<Boolean>? = null

    suspend fun add(value: Int) {
        mutex.withLock {
            list.add(value)
            available?.complete(true)
        }
    }

    suspend fun rem(): Int {
        mutex.withLock {
            if (list.size == 1) {
                available = CompletableDeferred()
            } else {
                available?.await()
            }
            return list.removeLast()
        }
    }

    suspend fun borrow(handler: suspend (Int) -> Unit) {
        val borrowed = rem()
        try {
            handler(borrowed)
        } finally {
            add(borrowed)
        }
    }
}

addremove 都在 mutex.withLock { ... } 中运行,以确保当两个线程试图修改同一个列表时我们不会得到并发修改异常。

最初,available 为空,因此任何 available?.await() 都将由于空检查而被跳过。 从列表 (list.size == 1) 中删除最后一项后,可用设置为 CompletableDeferred,这意味着 available?.await() 现在将暂停,直到调用 available?.complete(true)

添加更多项目后,将再次调用 available?.complete(true),这将阻止 available?.await() 挂起。如果您尝试删除一个项目,此代码将死锁,available?.await() 将挂起,这意味着互斥体永远不会退出,从而阻止 add 再次被调用以允许 rem 再次取消挂起。

如果我将 available?.await() 移到互斥量之前,两个线程将尝试从可能只有一个项目的列表中删除一个项目,第二个线程将遇到 List is empty 错误.

这样的挂起资源池的正确实现方式是什么?

我认为使用频道构建起来会容易得多。您甚至不需要 add 挂起,因为当 Channel 具有无限容量时,您可以从任何线程安全地使用 trySend

我还建议使 borrow 内联并从其函数参数中删除 suspend。这避免了在使用时分配函数包装器。由于它是内联的,即使它不是暂停 lambda,您仍然可以在传递给它的 lambda 中调用暂停函数。

class Pool<T>(vararg initialResources: T) {
    private val channel = Channel<T>(Channel.UNLIMITED)
    
    init {
        for (res in initialResources) {
            channel.trySend(res)
        }
    }

    fun add(value: T) {
       channel.trySend(value)
    }

    suspend fun rem(): T = channel.receive()

    suspend inline fun borrow(handler: (T) -> Unit) {
        val borrowed = rem()
        try {
            handler(borrowed)
        } finally {
            add(borrowed)
        }
    }
}

还有另一种方法,使用信号量

class ResourcePool<T>(vararg initialResources: T) {

    private val mutex = Mutex()
    private val semaphore = Semaphore(permits = initialResources.size)
    private val resources = initialResources.toMutableList()

    suspend operator fun invoke(handler: suspend (T) -> Unit) {
        semaphore.withPermit {
            val borrowed = mutex.withLock { resources.removeLast() }
            try {
                handler(borrowed)
            } finally {
                mutex.withLock { resources.add(borrowed) }
            }
        }
    }
}

运行 针对它的一些基准:

 @ExperimentalTime
    fun main() = runBlocking {
        val r = ResourcePool(*(1L..5000L).toList().toTypedArray())
        val totals = mutableListOf<Duration>()
        repeat(500) {
            val l = mutableListOf<Job>()
            totals.add(measureTime {
                repeat(100_000) { l.add(launch { r {  } }) }
                l.forEach {
                    it.join()
                }
            })
            println(totals.last())
        }
        println("${totals.subList(totals.size / 2, totals.lastIndex).average()} (average)")
    }

42.284184ms (average)

然后我们采用等效的 Channel 实现,

class ResourcePoolWithChannel<T>(vararg initialResources: T) {

    private val channel = Channel<T>(Channel.UNLIMITED)
    init {
        initialResources.forEach {
            channel.trySend(it)
        }
    }

    suspend operator fun invoke(handler: suspend (T) -> Unit) {
        val borrowed = channel.receive()
        try {
            handler(borrowed)
        } finally {
            channel.send(borrowed)
        }
    }
}

和运行针对它的相同基准:

@ExperimentalTime
fun main() = runBlocking {
    val r = ResourcePoolWithChannel(*(1L..5000L).toList().toTypedArray())
    val totals = mutableListOf<Duration>()
    repeat(500) {
        val l = mutableListOf<Job>()
        totals.add(measureTime {
            repeat(100_000) { l.add(launch { r {  } }) }
            l.forEach {
                it.join()
            }
        })
        println(totals.last())
    }
    println("${totals.subList(totals.size / 2, totals.lastIndex).average()} (average)")
}

56.524879ms (average)

Semaphore + Mutex 实现比 Channel 实现稍快。