Callback when tasks are added/removed in trio.Nursery

Question

Is there any official or better way to add callbacks for task add/remove events in a nursery, rather than wrapping the trio._core._run.GLOBAL_RUN_CONTEXT.runner.tasks set?


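Details

To study the internals of trio.Nursery (and just for fun), I was trying to track the task count of one specific nursery on every task add/remove.

I found that trio keeps a single set of tasks, stored as the tasks attribute of trio._core._run.GLOBAL_RUN_CONTEXT.runner.

So I tried wrapping that set, because replacing it outright caused errors in the exit handler.

Demo code: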

from collections.abc import MutableSet
from typing import Iterator
import random
import time

import trio


class SetWithCallback(MutableSet):
    """
    Class to wrap around existing set for adding callback support
    """
    def __init__(self, original_set: set, add_callback=None, remove_callback=None):
        self.inner_set = original_set

        self.add_callback = add_callback if add_callback else lambda _: None
        self.remove_callback = remove_callback if remove_callback else lambda _: None

    def add(self, value) -> None:
        self.inner_set.add(value)
        self.add_callback(value)

    def remove(self, value) -> None:
        self.inner_set.remove(value)
        self.remove_callback(value)

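    def discard(self, value) -> None:
        # Fire the callback only when the value was actually present,
        # so discard() and remove() report removals consistently.
        if value in self.inner_set:
            self.inner_set.discard(value)
            self.remove_callback(value)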

    def __contains__(self, x: object) -> bool:
        return x in self.inner_set

    def __len__(self) -> int:
        return len(self.inner_set)

    def __iter__(self) -> Iterator:
        return iter(self.inner_set)


async def dummy_task(task_no, lifetime):
    """
    A very meaningful and serious workload

    Args:
        task_no: Task's ID
        lifetime: Lifetime of task
    """
    print(f"  Task {task_no} started, expecting lifetime of {lifetime}s!")
    start = time.time()
    await trio.sleep(lifetime)
    print(f"  Task {task_no} finished, actual lifetime was {time.time() - start:.6}s!")


async def main():
    # Wrap original tasks set with our new class
    # noinspection PyProtectedMember
    runner = trio._core._run.GLOBAL_RUN_CONTEXT.runner
    # noinspection PyTypeChecker
    runner.tasks = SetWithCallback(runner.tasks)

    async with trio.open_nursery() as nursery:

        # Callbacks invoked on every task add/remove in the whole run.
        # Note: they do not check whether the task belongs to this nursery.
        def add_callback(task):
            # do something
            # child tasks count + 1 because given task is yet to be added.
            print(f"Task {id(task)} added, {len(nursery.child_tasks) + 1} in nursery {id(nursery)}")

        def remove_callback(task):
            # do something
            # child tasks count - 1 because given task is yet to be removed from it.
            print(f"Task {id(task)} done, {len(nursery.child_tasks) - 1} in nursery {id(nursery)}")

        # replace default callback to count the task count in nursery.
        runner.tasks.add_callback = add_callback
        runner.tasks.remove_callback = remove_callback

        # spawn tasks
        for n in range(5):
            nursery.start_soon(dummy_task, n, random.randint(2, 5))
            await trio.sleep(1)


if __name__ == '__main__':
    trio.run(main)

Output:

Task 2436520969376 added, 1 in nursery 2436520962768
  Task 0 started, expecting lifetime of 3s!
Task 2436520969536 added, 2 in nursery 2436520962768
  Task 1 started, expecting lifetime of 2s!
Task 2436520969856 added, 3 in nursery 2436520962768
  Task 2 started, expecting lifetime of 2s!
  Task 0 finished, actual lifetime was 3.00393s!
Task 2436520969376 done, 2 in nursery 2436520962768
  Task 1 finished, actual lifetime was 2.01102s!
Task 2436520969536 done, 1 in nursery 2436520962768
Task 2436520969536 added, 2 in nursery 2436520962768
  Task 3 started, expecting lifetime of 5s!
  Task 2 finished, actual lifetime was 2.01447s!
Task 2436520969856 done, 1 in nursery 2436520962768
Task 2436520969856 added, 2 in nursery 2436520962768
  Task 4 started, expecting lifetime of 3s!
  Task 4 finished, actual lifetime was 3.01358s!
Task 2436520969856 done, 1 in nursery 2436520962768
  Task 3 finished, actual lifetime was 5.01383s!
Task 2436520969536 done, 0 in nursery 2436520962768
Task 2436520969056 done, -1 in nursery 2436520962768
Task 2436520969216 done, -1 in nursery 2436520962768
Task 2436520968896 done, -1 in nursery 2436520962768

Process finished with exit code 0

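The result looked promising and everything seemed fine, until I realized that it really is a global context.

Every other nursery that is running, including trio's own core ones, triggers these callbacks as well, which produces surprising task counts such as -1.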

The best solution I could think of is to check whether the task belongs to the nursery:

        def add_callback(task):
            # do something
            if task in nursery.child_tasks:  # <---- what??
                # child tasks count + 1 because given task is yet to be added.
                print(f"Task {id(task)} added, {len(nursery.child_tasks) + 1} in nursery {id(nursery)}")

        def remove_callback(task):
            # do something
            if task in nursery.child_tasks:
                # child tasks count - 1 because given task is yet to be removed from it.
                print(f"Task {id(task)} done, {len(nursery.child_tasks) - 1} in nursery {id(nursery)}")

But this obviously cannot check a task that barely exists yet and is only about to be added to the nursery.

Have you considered trio's instrumentation API as a solution?

from collections import defaultdict

from trio import open_nursery
from trio.abc import Instrument
from trio.lowlevel import add_instrument

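# Maps each nursery to the number of tasks currently alive in it.
task_counts = defaultdict(int)

class TaskCountInstrument(Instrument):
    def task_spawned(self, task):
        # Called by trio whenever a new task is created.
        task_counts[task.parent_nursery] += 1

    def task_exited(self, task):
        # Called by trio whenever a task finishes.
        task_counts[task.parent_nursery] -= 1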


async def main():
    add_instrument(TaskCountInstrument())
    async with open_nursery() as nursery:
        ...