我们应该如何键入具有附加属性的可调用对象?

How should we type a callable with additional properties?

作为玩具示例,让我们使用斐波那契数列:

def fib(n: int) -> int:
  if n < 2:
    return 1
  return fib(n - 2) + fib(n - 1)

当然,如果我们尝试这样做,这将挂起计算机:

print(fib(100))

所以我们决定添加记忆功能。为了保持 fib 的逻辑清晰,我们决定不更改 fib 而是通过装饰器添加记忆:

from typing import Callable
from functools import wraps


def remember(f: Callable[[int], int]) -> Callable[[int], int]:
    @wraps(f)
    def wrapper(n: int) -> int:
        if n not in wrapper.memory:
            wrapper.memory[n] = f(n)
        return wrapper.memory[n]

    wrapper.memory = dict[int, int]()
    return wrapper


@remember
def fib(n: int) -> int:
    if n < 2:
        return 1
    return fib(n - 2) + fib(n - 1)

现在没有问题了,如果我们:

print(fib(100))
573147844013817084101

然而,mypy 抱怨 "Callable[[int], int]" has no attribute "memory",这是有道理的,如果我试图访问不属于声明类型的 属性,通常我会想要这个抱怨...

那么,我们应该如何用typing来表示wrapper,而一个Callable,还有属性memory呢?

要将某物描述为 “具有内存属性的可调用对象”,您可以定义 protocol (Python 3.8+, or earlier versions with typing_extensions):

from typing import Protocol


class Wrapper(Protocol):
    memory: dict[int, int]
    def __call__(self, n: int) -> int: ...

在使用中,类型检查器知道 WrapperCallable[[int], int] 一样有效,并允许 return wrapper 以及对 wrapper.memory 的赋值:

from functools import wraps
from typing import Callable, cast


def remember(f: Callable[[int], int]) -> Callable[[int], int]:
    @wraps(f)
    def _wrapper(n: int) -> int:
        if n not in wrapper.memory:
            wrapper.memory[n] = f(n)
        return wrapper.memory[n]
    wrapper = cast(Wrapper, _wrapper)
    wrapper.memory = dict()
    return wrapper

Playground

不幸的是,这需要 wrapper = cast(Wrapper, _wrapper),它 不是 类型安全的 - wrapper = cast(Wrapper, "foo") 也可以检查。

基于 jonrsharpe 的回答(有效,建议如下,我已经接受),我们可以避免对非类型安全转换的需要,如下所示:

from typing import Callable
from functools import wraps


class Remember:
    def __init__(self) -> None:
        self.memory = dict[int, int]()

    def __call__(self, f: Callable[[int], int]) -> Callable[[int], int]:
        @wraps(f)
        def wrapper(n: int) -> int:
            if n not in self.memory:
                self.memory[n] = f(n)
            return self.memory[n]

        return wrapper


@Remember()
def fib(n: int) -> int:
    if n < 2:
        return 1
    return fib(n - 2) + fib(n - 1)


print(fib(100))

不要使用函数属性来存储缓存,就不会出现这个问题。您已经在定义闭包(wrapper 保留对原始可调用对象的引用),因此也将缓存存储在闭包中。

from typing import Callable
from functools import wraps


def remember(f: Callable[[int], int]) -> Callable[[int], int]:
    cache: dict[int, int] = {}

    @wraps(f)
    def wrapper(n: int) -> int:
        if n not in cache:
            cache[n] = f(n)
        return cache[n]

    return wrapper


@remember
def fib(n: int) -> int:
    if n < 2:
        return 1
    return fib(n - 2) + fib(n - 1)