如何定义 ContextManager 协议

How to define a ContextManager Protocol

我正在尝试使用类型提示来指定 API 在实现连接器 class(在本例中为代理)时要遵循。

我想指定这样的 class(es) 应该是上下文管理器

我该怎么做?

让我改写得更清楚:我如何定义 Broker class 以便它表明它的具体实现,例如Rabbit class,必须是上下文管理器吗?

有没有实用的方法?我是否必须指定 __enter____exit__ 并仅继承 Protocol

ContextManager继承就够了吗?

顺便问一下,我应该使用@runtime还是@runtime_checkable? (我的 VScode linter 似乎无法在 typing 中找到它们。我正在使用 python 3 7.5)

我知道如何使用 ABC 来做到这一点,但我想学习如何使用协议定义(我已经很好地使用它,但它们不是上下文管理器)。

我不知道如何使用 ContextManager 类型。到目前为止,我还没能从官方文档中找到很好的例子。

目前我想出了

from typing import Protocol, ContextManager, runtime, Dict, List


@runtime
class Broker(ContextManager):
    """
    Basic interface to a broker.
    It must be a context manager
    """

    def publish(self, data: str) -> None:
        """
        Publish data to the topic/queue
        """
        ...

    def subscribe(self) -> None:
        """
        Subscribe to the topic/queue passed to constructor
        """
        ...

    def read(self) -> str:
        """
        Read data from the topic/queue
        """
        ...

实现是

@implements(Broker)
class Rabbit:
    def __init__(self,
            url: str,
            queue: str = 'default'):
        """
        url: where to connect, i.e. where the broker is
        queue: the topic queue, one only
        """
        # self.url = url
        self.queue = queue
        self.params = pika.URLParameters(url)
        self.params.socket_timeout = 5

    def __enter__(self):
        self.connection = pika.BlockingConnection(self.params) # Connect to CloudAMQP
        self.channel = self.connection.channel() # start a channel
        self.channel.queue_declare(queue=self.queue) # Declare a queue
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.connection.close()

    def publish(self, data: str):
        pass  # TBD

    def subscribe(self):
        pass  # TBD

    def read(self):
        pass  # TBD

注意:implements 装饰器工作正常(它来自以前的项目),它检查 class 是给定协议 class 的子

简短回答 -- 您的 Rabbit 实现实际上没有问题。只需添加一些类型提示以表明 __enter__ returns 是其自身的一个实例以及 __exit__ returns None__exit__ 参数的类型实际上并不重要。


更长的答案:

每当我不确定某种类型 is/what 某种协议到底是什么时,检查 TypeShed 通常很有帮助,它是标准库(和一些第 3 方库)的类型提示集合。

例如,here is the definition of typing.ContextManager。我已将其复制到下面:

from types import TracebackType

# ...snip...

_T_co = TypeVar('_T_co', covariant=True)  # Any type covariant containers.

# ...snip...

@runtime_checkable
class ContextManager(Protocol[_T_co]):
    def __enter__(self) -> _T_co: ...
    def __exit__(self, __exc_type: Optional[Type[BaseException]],
                 __exc_value: Optional[BaseException],
                 __traceback: Optional[TracebackType]) -> Optional[bool]: ...

通过阅读本文,我们知道了一些事情:

  1. 这个类型是一个协议,这意味着任何在上面给定的签名之后恰好实现 __enter____exit__ 的类型都是 [=20= 的有效子类型] 而无需显式继承它。

  2. 这种类型是运行时可检查的,这意味着 isinstance(my_manager, ContextManager) 也可以,如果你出于任何原因想要这样做的话。

  3. __exit__的参数名都是以两个下划线为前缀。这是类型检查器用来指示那些参数仅是位置参数的约定:在 __exit__ 上使用关键字参数不会进行类型检查。实际上,这意味着您可以随意命名自己的 __exit__ 参数,同时仍然遵守协议。

因此,将它们放在一起,这是仍然进行类型检查的 ContextManager 的最小可能实现:

from typing import ContextManager, Type, Generic, TypeVar

class MyManager:
    def __enter__(self) -> str:
        return "hello"

    def __exit__(self, *args: object) -> None:
        return None

def foo(manager: ContextManager[str]) -> None:
    with manager as x:
        print(x)        # Prints "hello"
        reveal_type(x)  # Revealed type is 'str'

# Type checks!
foo(MyManager())



def bar(manager: ContextManager[int]) -> None: ...

# Does not type check, since MyManager's `__enter__` doesn't return an int
bar(MyManager())

一个不错的小技巧是,如果我们实际上不打算使用参数,我们实际上可以使用非常懒惰的 __exit__ 签名。毕竟,如果 __exit__ 基本上接受任何东西,就不存在类型安全问题。

(更正式地说,符合 PEP 484 的类型检查器将尊重函数在其参数类型方面是逆变的)。

当然,如果需要,您可以指定完整类型。例如,采用您的 Rabbit 实现:

# So I don't have to use string forward references
from __future__ import annotations
from typing import Optional, Type
from types import TracebackType

# ...snip...

@implements(Broker)
class Rabbit:
    def __init__(self,
            url: str,
            queue: str = 'default'):
        """
        url: where to connect, i.e. where the broker is
        queue: the topic queue, one only
        """
        # self.url = url
        self.queue = queue
        self.params = pika.URLParameters(url)
        self.params.socket_timeout = 5

    def __enter__(self) -> Rabbit:
        self.connection = pika.BlockingConnection(params) # Connect to CloudAMQP
        self.channel = self.connection.channel() # start a channel
        self.channel.queue_declare(queue=self.queue) # Declare a queue
        return self

    def __exit__(self,
                 exc_type: Optional[Type[BaseException]],
                 exc_value: Optional[BaseException],
                 traceback: Optional[TracebackType],
                 ) -> Optional[bool]:
        self.connection.close()

    def publish(self, data: str):
        pass  # TBD

    def subscribe(self):
        pass  # TBD

    def read(self):
        pass  # TBD

回答新编辑的问题:

How can I define the Broker class so that it indicates that its concrete implementations, e.g. the Rabbit class, must be context managers?

Is there a practical way? Do I have to specify enter and exit and just inherit from Protocol?

Is it enough to inherit from ContextManager?

有两种方法:

  1. 重新定义 __enter____exit__ 函数,从 ContextManager 复制原始定义。
  2. 创建 Broker 子类 ContextManager 和 Protocol。

如果您仅对 ContextManager 进行子类化,那么您所做的只是让 Broker 或多或少地继承 ContextManager 中碰巧具有默认实现的任何方法。

PEP 544: Protocols and structural typing goes into more details about this. The mypy docs on Protocols have a more user-friendly version of this. For example, see the section on defining subprotocols and subclassing protocols.

By the way, should I use @runtime or @runtime_checkable? (My VScode linter seems to have problems finding those in typing. I am using python 3 7.5)

应该是runtime_checkable.

也就是说,Protocol 和 runtime_checkable 实际上都是在 3.8 版中添加到 Python 中的,这可能就是您的 linter 不满意的原因。

如果您想在旧版本的 Python 中同时使用这两者,您需要 pip install typing-extensions,输入类型的官方反向端口。

安装完成后,您可以from typing_extensions import Protocol, runtime_checkable