How to asyncio.gather tasks in chunks + use semaphore with TCP connections limit?

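I have a large (1M) database result set, and I want to call a REST API for each row.

The API can accept batch requests, but I am not sure how to slice the rows generator so that each task processes a list of rows, say 10. I would rather not read all the rows upfront, and stick to a generator.

Adapting my_function to send a list in one HTTP request is easy enough, but what about asyncio.gather? Maybe one of the itertools functions can help.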


import asyncio
import aiohttp

async def main(rows):
    async with aiohttp.ClientSession() as session:
        # One coroutine per row, all created before any request is sent.
        tasks = [my_function(row, session) for row in rows]
        return await asyncio.gather(*tasks)

rows = ...  # placeholder: generator of database rows
results = asyncio.run(main(rows))



aiohttp/asyncio are great but hard to follow. I agree with this post:

asyncio keeps changing all the time, so be wary of old Stack Overflow answers. Many of them are not up to date with the current best practices


I have just tried using asyncio.BoundedSemaphore(100) and memory usage is about the same (45GB). I am not sure it has any benefit over the connections limit.
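A minimal sketch of why a semaphore on its own may not lower memory use: it caps how many calls run at the same time, but with the gather pattern above every coroutine object is still created before the first one runs. Here `fake_call`, `demo`, and the `stats` dictionary are hypothetical stand-ins for illustration, not part of the original code.

```python
import asyncio

async def fake_call(row, sem, stats):
    # Hypothetical stand-in for my_function: no network, just a yield point.
    async with sem:
        stats["running"] += 1
        stats["peak"] = max(stats["peak"], stats["running"])
        await asyncio.sleep(0)
        stats["running"] -= 1
        return row

async def demo(n_rows=1000, limit=10):
    sem = asyncio.Semaphore(limit)
    stats = {"running": 0, "peak": 0}
    # All n_rows coroutine objects are built here, before anything runs.
    coros = [fake_call(row, sem, stats) for row in range(n_rows)]
    created = len(coros)
    await asyncio.gather(*coros)
    return created, stats["peak"]

created, peak = asyncio.run(demo())
print(created, peak)  # created equals n_rows; peak never exceeds limit
```

The concurrency stays bounded, but the number of live coroutines and tasks still grows with the number of rows.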



async def worker(queue, session, results):
    while True:
        row = await queue.get()
        results.append(await my_function(row, session))
        # Mark the item as processed, allowing queue.join() to keep
        # track of remaining work and know when everything is done.
        queue.task_done()

async def main(rows):
    N_WORKERS = 50
    queue = asyncio.Queue(N_WORKERS)
    results = []
    async with aiohttp.ClientSession() as session:
        # create 50 workers and feed them tasks
        workers = [asyncio.create_task(worker(queue, session, results))
                   for _ in range(N_WORKERS)]
        # Feed the database rows to the workers. The fixed-capacity of the
        # queue ensures that we never hold all rows in the memory at the
        # same time. (When the queue reaches full capacity, this will block
        # until a worker dequeues an item.)
        async for row in rows:
            await queue.put(row)
        # Wait for all enqueued items to be processed.
        await queue.join()
    # The workers are now idly waiting for the next queue item and we
    # no longer need them.
    for w in workers:
        w.cancel()
    return results

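Note that rows should be an async generator. If it is an ordinary generator, it will probably block the event loop and become the bottleneck. If your database doesn't support an async interface, see the linked answer for a way to convert a blocking generator to async by running it in a dedicated thread.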
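As a rough illustration of the dedicated-thread idea (not necessarily the approach from the linked answer), the sketch below runs a blocking iterator in its own thread and hands items to the event loop through a bounded asyncio.Queue. The names `iter_in_thread`, `make_iter`, and `maxsize` are assumptions made for this example. Errors raised by the blocking iterator are not propagated to the consumer here; the stream simply ends.

```python
import asyncio
import threading

async def iter_in_thread(make_iter, maxsize=100):
    """Expose a blocking iterator as an async generator (sketch)."""
    loop = asyncio.get_running_loop()
    queue = asyncio.Queue(maxsize)
    done = object()  # sentinel marking the end of the stream

    def producer():
        try:
            for item in make_iter():
                # Blocks this thread, not the event loop, while the queue is full.
                asyncio.run_coroutine_threadsafe(queue.put(item), loop).result()
        finally:
            asyncio.run_coroutine_threadsafe(queue.put(done), loop).result()

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = await queue.get()
        if item is done:
            break
        yield item

async def collect():
    # range(5) stands in for a blocking database cursor.
    return [row async for row in iter_in_thread(lambda: iter(range(5)), maxsize=2)]

print(asyncio.run(collect()))  # [0, 1, 2, 3, 4]
```

The bounded queue gives the same back-pressure as in the worker example: the producer thread pauses whenever the consumer falls behind.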

To batch items, you can build an intermediate list and dispatch it. Or you can use the excellent aiostream library, which comes with the chunks operator that does just that:

async with aiostream.stream.chunks(rows, 10).stream() as chunks:
    async for batch in chunks:
        await queue.put(batch)  # enqueue a batch of 10 rows
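
The "intermediate list" option mentioned above could look roughly like the following sketch, which needs no extra dependency. The helper name `batched` and the toy `fake_rows` generator are assumptions for illustration only.

```python
import asyncio

async def batched(rows, size):
    """Group items from an async iterator into lists of up to `size` items."""
    batch = []
    async for row in rows:
        batch.append(row)
        if len(batch) == size:
            yield batch
            batch = []
    if batch:  # emit the final, possibly shorter, batch
        yield batch

async def fake_rows(n):
    # Hypothetical stand-in for an async database cursor.
    for i in range(n):
        yield i

async def collect_batches():
    return [batch async for batch in batched(fake_rows(7), 3)]

print(asyncio.run(collect_batches()))  # [[0, 1, 2], [3, 4, 5], [6]]
```

In the worker setup, the feeding loop would then be `async for batch in batched(rows, 10): await queue.put(batch)`.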

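Many thanks to @user4815162342 for pointing me in the right direction.

Here is a complete working example that implements batching, connection limiting, and queuing, provided you start with an async generator. Update: if you are not starting from an async generator, see the previous answer for a sync-to-async converter.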


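import asyncio
import json

import aiohttp
import aiostream

# as per previous comment, match with connections so that each worker feeds one connection
N_WORKERS = 400
TCP_CONNECTIONS = N_WORKERS  # assumption: equal to N_WORKERS, per the comment above
BATCH_SIZE = 10              # assumption: batch size taken from the question's example

async def my_function(batch, session):
    # my_url is the REST endpoint, defined elsewhere.
    # Serialize here so default=str can handle non-JSON types such as dates;
    # passing the dumped string via json= would encode it a second time.
    async with session.post(my_url,
                            data=json.dumps(batch, default=str),
                            headers={"Content-Type": "application/json"}) as response:
        return await response.json()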

async def worker(queue, session, results):
    while True:
        batch = await queue.get()
        results.append(await my_function(batch, session))
        queue.task_done()  # lets queue.join() know this batch is finished

async def main(rows):
    results = []  # better here than global
    queue = asyncio.Queue(N_WORKERS)

    async with aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(limit=TCP_CONNECTIONS)) as session:

        workers = [asyncio.create_task(worker(queue, session, results))
                   for _ in range(N_WORKERS)]

        async with aiostream.stream.chunks(rows, BATCH_SIZE).stream() as chunks:
            async for batch in chunks:
                await queue.put(batch)

        await queue.join()

    for w in workers:
        w.cancel()

    return results

results = asyncio.run(main(rows))


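from itertools import chain, islice

def chunks(iterator, n):
    # Lazily split an iterator into chunks of up to n items. Each chunk must
    # be consumed before the next one is requested.
    return (chain([first], islice(iterator, 0, n - 1))
            for first in iterator)

async def main(rows):
    async with aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(limit=TCP_CONNECTIONS)) as session:

        batches = [my_function(list(batch), session)
                   for batch in chunks(rows, BATCH_SIZE)]

        # Await inside the block so the session is still open during the requests.
        return await asyncio.gather(*batches)

results = asyncio.run(main(rows))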
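
To see what the `chunks` helper yields, here is a small standalone check. The `range`-based `rows` is a hypothetical stand-in for the database generator. Note that the argument has to be an iterator (hence `iter(...)`), and each chunk has to be consumed, here with `list`, before moving on to the next one.

```python
from itertools import chain, islice

def chunks(iterator, n):
    return (chain([first], islice(iterator, 0, n - 1))
            for first in iterator)

rows = iter(range(7))  # stand-in for the database rows generator
batches = [list(batch) for batch in chunks(rows, 3)]
print(batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```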