通过超时取消异步迭代器

Cancel asynchronous iterator by timeout

我有一个基于 asyncio 运行的进程,它应该一直运行下去。

我可以使用 ProcessIterator 与该进程交互,ProcessIterator 可以(此处省略)将数据发送到标准输入并从标准输出中获取数据。

我可以使用 async for fd, data in ProcessIterator(...): 访问数据。

现在的问题是,这个异步迭代器的执行必须有时间限制。时间耗尽时会调用 timeout() 函数,但它抛出的异常并不是从 __anext__ 函数中抛出的,因此迭代器的使用者无法得知发生了超时。

如何让这个异常从异步迭代器中抛出?我找不到类似 awaitable.throw(something) 这样的方法。

class ProcessIterator:
    """Asynchronous iterator over the output of a running process.

    Use like ``async for fd, data in ProcessIterator(process, loop, t):``.
    A global timer of ``run_timeout`` seconds is armed on construction and
    calls :meth:`timeout` when it expires.
    """

    def __init__(self, process, loop, run_timeout):
        self.process = process
        self.loop = loop

        self.run_timeout = run_timeout

        # set the global timer
        self.overall_timer = self.loop.call_later(
            self.run_timeout, self.timeout)

    def timeout(self, was_global=True):
        """Timer callback, fired when ``run_timeout`` expires.

        NOTE: the exception raised here is not delivered to the
        ``async for`` consumer, which is the problem this question is about.
        ``was_global`` defaults to True because the only scheduler is the
        global timer in ``__init__``.
        """
        # XXX: how do i pass this exception into the iterator?
        raise ProcTimeoutError(
            self.process.args,
            self.run_timeout,
            was_global,
        )

    def __aiter__(self):
        # Must be a plain method that returns the iterator itself: since
        # Python 3.7 `async for` rejects an __aiter__ returning a coroutine.
        return self

    async def __anext__(self):
        if self.process.exited:
            self._stop_timer()
            raise StopAsyncIteration()

        # fetch output from the process asyncio.Queue()
        entry = await self.process.output_queue.get()
        if entry == StopIteration:
            self._stop_timer()
            raise StopAsyncIteration()

        return entry

    def _stop_timer(self):
        """Cancel the global timer so it cannot fire after iteration ended."""
        if self.overall_timer is not None:
            self.overall_timer.cancel()

现在异步迭代器的用法大致是:

async def test_coro(loop):
    """Run a short child Python script under ProcessIterator with a 1 s budget.

    The child prints one line, sleeps 5 s, then prints another, so the
    1 s run timeout is meant to fire in between.
    """
    # The child script must import `time` itself; otherwise it dies with a
    # NameError instead of sleeping.
    code = 'import time; print("rofl"); time.sleep(5); print("lol")'

    proc = Process([sys.executable, '-u', '-c', code])

    await proc.create()

    try:
        async for fd, line in ProcessIterator(proc, loop, run_timeout=1):
            print("%d: %s" % (fd, line))

    except ProcTimeoutError as exc:
        # ProcessIterator.timeout() builds a ProcTimeoutError, so that is
        # the exception type to catch here.
        print("timeout: %s" % exc)

    await proc.wait()

tl;dr:如何让一个由定时器触发的异常从异步迭代器中抛出?

编辑:添加了解决方案 2

解决方案一:

timeout()回调能否将ProcTimeoutError异常存储在实例变量中?然后 __anext__() 可以检查实例变量并在设置时引发异常。

class ProcessIterator:
    """Asynchronous iterator over process output with an overall timeout.

    The timer callback cannot raise into the ``async for`` consumer directly,
    so it stores the ProcTimeoutError in ``self.error`` and ``__anext__``
    re-raises it.
    """

    def __init__(self, process, loop, run_timeout):
        self.process = process
        self.loop = loop
        # Exception to raise from __anext__ once the timer has fired.
        self.error = None
        # Resolved by timeout(); lets a __anext__ that is blocked on an
        # empty queue wake up instead of waiting for the next item.
        self._timed_out = self.loop.create_future()

        self.run_timeout = run_timeout

        # set the global timer
        self.overall_timer = self.loop.call_later(
            self.run_timeout, self.timeout)

    def timeout(self, was_global=True):
        """Timer callback: record the timeout and wake a pending __anext__.

        ``was_global`` defaults to True because the only scheduler is the
        global timer in ``__init__``.
        """
        self.error = ProcTimeoutError(
            self.process.args,
            self.run_timeout,
            was_global,
        )
        if not self._timed_out.done():
            self._timed_out.set_result(None)

    def __aiter__(self):
        # Plain method: `async for` requires __aiter__ to return the
        # iterator itself, not a coroutine (enforced since Python 3.7).
        return self

    async def __anext__(self):
        # if error is set, then raise the exception
        if self.error is not None:
            raise self.error

        if self.process.exited:
            raise StopAsyncIteration()

        # fetch output from the process asyncio.Queue(), but stop waiting
        # as soon as the timer fires.
        getter = asyncio.ensure_future(self.process.output_queue.get())
        try:
            await asyncio.wait([getter, self._timed_out],
                               return_when=asyncio.FIRST_COMPLETED)
        finally:
            # An unfinished read would otherwise consume a later item.
            if not getter.done():
                getter.cancel()

        if self.error is not None:
            raise self.error

        entry = getter.result()
        if entry == StopIteration:
            raise StopAsyncIteration()

        return entry

方案二:

将异常放入 process.output_queue 中。

....
def timeout(self, was_global=True):
    """Timer callback: push a ProcTimeoutError into the output queue.

    ``__anext__`` then receives it like any other entry and raises it.
    """
    # put_nowait(), not put(): put() is a coroutine, and calling it from a
    # plain callback without awaiting it would never enqueue anything.
    self.process.output_queue.put_nowait(ProcTimeoutError(
        self.process.args,
        self.run_timeout,
        was_global,
    ))

....

# fetch output from the process asyncio.Queue()
entry = await self.process.output_queue.get()
if entry == StopIteration:
    raise StopAsyncIteration()

elif entry = ProcTimeoutError:
    raise entry
....

如果队列中可能已经有其他条目,请改用优先级队列(asyncio.PriorityQueue),并为 ProcTimeoutError 分配比其他条目更高的优先级,例如 (0, ProcTimeoutError) 对比 (1, other_entry),这样超时异常会先于已排队的数据被取出。

您可以使用 get_nowait,它会立即返回一个条目,或者立即抛出 QueueEmpty。把它放进一个以 self.error 为条件的 while 循环中,并在队列为空时进行短暂的异步睡眠,应该就能解决问题。类似于:

async def __anext__(self):
    # Polling variant: get_nowait() never blocks, so self.error is
    # re-examined before every attempt; an empty queue just means a short
    # sleep that hands control back to the event loop.
    if self.process.exited:
        raise StopAsyncIteration()

    while True:
        if self.error is not None:
            raise self.error
        try:
            entry = self.process.output_queue.get_nowait()
        except asyncio.QueueEmpty:
            await asyncio.sleep(0.1)
            continue
        if entry == StopIteration:
            raise StopAsyncIteration()
        return entry

另外,可以参考 Tornado 的 Queue.get 中实现超时的方式:

def get(self, timeout=None):
    """Take the next item off the queue without blocking the caller.

    Always returns a Future. If an item is already waiting, the Future is
    resolved with it right away; otherwise the Future is registered as a
    getter and fails with `tornado.gen.TimeoutError` once *timeout* elapses.
    """
    result = Future()
    try:
        item = self.get_nowait()
    except QueueEmpty:
        # Nothing queued yet: park the future until a put() or the timeout.
        self._getters.append(result)
        _set_timeout(result, timeout)
    else:
        result.set_result(item)
    return result

请查看 asyncio 中的 timeout 上下文管理器:

with asyncio.timeout(10):
    async for i in get_iter():
        process(i)

它当时尚未发布,但您可以从 asyncio 的 master 分支复制粘贴其实现。(注:自 Python 3.11 起,标准库已提供 asyncio.timeout(),需配合 async with 使用。)

这是我现在想到的解决方案。

上游版本见 https://github.com/SFTtech/kevin 仓库中的 kevin/process.py。

它还有行计数和输出超时功能,我从这个例子中去掉了。

class Process:
    """Wrapper around an asyncio subprocess whose output is consumed through
    a ProcessIterator (see :meth:`communicate`)."""

    def __init__(self, command, loop=None):

        self.loop = loop or asyncio.get_event_loop()
        # Keep the command line around: ProcessIterator.timeout() reports it
        # via ``self.process.args``.
        self.args = command

        self.created = False
        # Resolved when the process is finished/aborted. Created on
        # self.loop so it belongs to the same loop as the subprocess, even
        # when an explicit `loop` is passed.
        self.killed = self.loop.create_future()

        # Coroutine; it is awaited in create().
        self.proc = self.loop.subprocess_exec(
            lambda: WorkerInteraction(self),  # see upstream repo
            *command)

        self.transport = None
        self.protocol = None

    async def create(self):
        """Start the subprocess and store its transport and protocol."""
        self.transport, self.protocol = await self.proc

    def communicate(self, timeout):
        """Return a ProcessIterator over the output, bounded by *timeout*.

        Raises:
            Exception: if ``killed`` has already been resolved.
        """
        if self.killed.done():
            raise Exception("process was already killed "
                            "and no output is waiting")

        return ProcessIterator(self, self.loop, timeout)

class ProcessIterator:
    """
    Asynchronous iterator for the process output.
    Use like `async for (fd, data) in ProcessIterator(...):`

    Each step races the next item of ``process.protocol.queue`` against
    ``process.killed``. A timeout (or any other failure) is delivered by
    failing ``killed``, which aborts a pending queue read.
    """

    def __init__(self, process, loop, run_timeout):
        self.process = process
        self.loop = loop
        self.run_timeout = run_timeout

        self.overall_timer = None

        if self.run_timeout < INF:
            # set the global timer
            self.overall_timer = self.loop.call_later(
                self.run_timeout,
                functools.partial(self.timeout, was_global=True))

    def timeout(self, was_global=False):
        """Timer callback: fail ``process.killed`` with a ProcTimeoutError.

        ``was_global`` must be accepted because ``__init__`` schedules this
        callback with ``was_global=True``. It is not used in this
        trimmed-down version.
        """
        if not self.process.killed.done():
            self.process.killed.set_exception(ProcTimeoutError(
                self.process.args,
                self.run_timeout,
            ))

    def __aiter__(self):
        # Plain method: `async for` requires __aiter__ to return the
        # iterator itself, not a coroutine (enforced since Python 3.7).
        return self

    async def __anext__(self):
        killed = self.process.killed
        # asyncio.wait() needs futures/tasks, not bare coroutines (passing a
        # coroutine is a TypeError since Python 3.11), so wrap the read.
        getter = asyncio.ensure_future(self.process.protocol.queue.get())
        try:
            # either the process exits,
            # there's an exception (process killed, timeout, ...)
            # or the queue gives us the next data item.
            # wait for the first of those events.
            await asyncio.wait([getter, killed],
                               return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Drop an unfinished queue read; left alone it would later
            # consume (and lose) the next item.
            if not getter.done():
                getter.cancel()

        # A failure (e.g. the timeout) wins over queued data so the process
        # is aborted promptly. `killed` is shared with the Process object,
        # so it is never cancelled here.
        for future in (killed, getter):
            if future.done() and future.exception() is not None:
                # kill the process before throwing the error!
                await self.process.pwn()
                raise future.exception()

        # fetch output from the process (or the value `killed` resolved to)
        entry = getter.result() if getter.done() else killed.result()

        # StopIteration marks the last data chunk: the process exited on
        # its own. End the iteration every time, even if `killed` was
        # already resolved by an earlier call.
        if entry is StopIteration:
            if not killed.done():
                killed.set_result(entry)
            await self.stop_iter()

        return entry

    async def stop_iter(self):
        """Cancel the overall timer and end the iteration.

        Raises:
            StopAsyncIteration: always.
        """
        # stop the timer
        if self.overall_timer:
            self.overall_timer.cancel()

        raise StopAsyncIteration()

其中起关键作用的是这段代码:

done, pending = await asyncio.wait(
    [self.process.protocol.queue.get(), self.process.killed],
    return_when=asyncio.FIRST_COMPLETED)

它能在发生超时时可靠地中止对队列的读取。