如何在协议上定义具有协变 return 类型的可调用属性?

How to define callable attribute with covariant return type on protocol?

通常理解为 return 类型的可调用对象是 co 变体。当用可调用的 属性 定义类型时,我确实可以使 return 类型通用且协变:

from typing import TypeVar, Callable, Generic, Sequence
from dataclasses import dataclass

R = TypeVar("R", covariant=True)

@dataclass
class Works(Generic[R]):
    call: Callable[[], R]  # returns an R *or subtype*

w: Works[Sequence] = Works(lambda: [])  # okay: list is subtype of Sequence

然而,这对 Protocol 不起作用。当我以相同的方式为类型定义 Protocol 时,MyPy 拒绝了这一点——它坚持认为 return 类型必须是 invariant.

from typing import TypeVar, Callable, Protocol

R = TypeVar("R", covariant=True)

class Fails(Protocol[R]):
    attribute: Callable[[], R]
$ python -m mypy so_testbed.py --pretty
so_testbed.py:5: error: Covariant type variable "R" used in protocol where invariant one is expected
    class Fails(Protocol[R]):
    ^
Found 1 error in 1 file (checked 1 source file)

如何为尊重 R 协方差的具体类型正确定义 Protocol

Protocol 显然无法实现您的尝试 - 请参阅 PEP 544 中的以下内容:


Covariant subtyping of mutable attributes

Rejected because covariant subtyping of mutable attributes is not safe. Consider this example:

class P(Protocol):
    x: float

def f(arg: P) -> None:
    arg.x = 0.42

class C:
    x: int

c = C()
f(c)  # Would typecheck if covariant subtyping
      # of mutable attributes were allowed.
c.x >> 1  # But this fails at runtime

It was initially proposed to allow this for practical reasons, but it was subsequently rejected, since this may mask some hard to spot bugs.


由于您的 attribute 是可变成员 - 您不能让它与 R 相关。

一种可能的替代方法是将 attribute 替换为以下方法:

class Passes(Protocol[R]):
    @property
    def attribute(self) -> Callable[[], R]:
        pass

它通过了类型检查 - 但它是一个不灵活的解决方案。

如果您需要可变的协变成员,Protocol 不是正确的选择。

正如@Daniel Kleinstein 所指出的,您不能通过协变变量来参数化协议类型,因为它用于可变属性。

另一种方法是将变量分成两个(协变和不变)并在两个协议中使用它们(replace CallableProtocol)。

from typing import TypeVar, Callable, Protocol

R_cov = TypeVar("R_cov", covariant=True)
R_inv = TypeVar("R_inv")

class CallProto(Protocol[R_cov]):
    def __call__(self) -> R_cov: ...
    
class Fails(Protocol[R_inv]):
    attribute: CallProto[R_inv]