输入一个重载的装饰器,包裹在 partial 中

Typing an overloaded decorator wrapped in partial

我正在尝试正确输入被 partial:

包装的重载装饰器
from functools import partial
from typing import Any, Callable, Optional, Union, overload


AnyCallable = Callable[..., Any]


class Wrapped:
    def __init__(self, func: AnyCallable, foo: str, bar: bool) -> None:
        pass


@overload
def create_wrapped(foo: str, func: AnyCallable) -> Wrapped:
    ...


@overload
def create_wrapped(foo: str, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
    ...


def create_wrapped(
    foo: str,
    func: Optional[AnyCallable] = None,
    *,
    bar: bool = True,
) -> Union[Wrapped, Callable[[AnyCallable], Wrapped]]:
    def wrapper(func_: AnyCallable) -> Wrapped:
        return Wrapped(func_, foo, bar)

    if func is None:
        return wrapper
    return wrapper(func)


baz = partial(create_wrapped, "baz")


@baz
def func_1() -> None:
    pass


@baz(bar=False)
def func_2() -> None:
    pass

代码是正确的,但是mypy给出了

47: error: "Wrapped" not callable

这表明在应用 partial 时实际参数类型丢失,因为 @baz(bar=False) 应该匹配第二个重载,因为它与 @create_wrapped("baz", bar=False) 相同,它确实可以正常工作.

我不确定我还能怎么注释这个,事实上我想不出任何方法让 mypy 不抱怨这个,即使我对没有合适的装饰器类型很好,因为在那种情况下,我会收到 Untyped decorator makes function untyped 错误。

mypy 当前无法正确推断部分应用函数的类型:https://github.com/python/mypy/issues/1484.

您可以通过将 partial 调用的 return 转换为适当的 Protocol.

来解决这个问题
from functools import partial
from typing import Any, Callable, Optional, Protocol, Union, overload, cast


AnyCallable = Callable[..., Any]


class Wrapped:
    def __init__(self, func: AnyCallable, foo: str, bar: bool) -> None:
        pass


@overload
def create_wrapped(foo: str, func: AnyCallable) -> Wrapped:
    ...


@overload
def create_wrapped(foo: str, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
    ...


def create_wrapped(
    foo: str,
    func: Optional[AnyCallable] = None,
    *,
    bar: bool = True,
) -> Union[Wrapped, Callable[[AnyCallable], Wrapped]]:
    def wrapper(func_: AnyCallable) -> Wrapped:
        return Wrapped(func_, foo, bar)

    if func is None:
        return wrapper

    return wrapper(func)


class partial_create_wrapped(Protocol):
    @overload
    def __call__(self, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
        ...

    @overload
    def __call__(self, func: AnyCallable) -> Wrapped:
        ...


baz = cast(partial_create_wrapped, partial(create_wrapped, "baz"))


@baz
def func_1() -> None:
    pass


@baz(bar=False)
def func_2() -> None:
    pass