使用 mypy 处理条件逻辑 + 标记值

Handling conditional logic + sentinel value with mypy

我有一个大致如下所示的函数:

import datetime
from typing import Union

class Sentinel(object): pass
sentinel = Sentinel()

def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Sentinel] = sentinel,
) -> str:

    if as_tz is not sentinel:
        # Never reached if as_tz has wrong type (Sentinel)
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"

这里使用sentinel值是因为None已经是.astimezone()的有效参数,所以目的是正确识别用户不想调用的情况.astimezone() 完全没有。

但是,mypy 抱怨这种模式:

error: Argument 1 to "astimezone" of "datetime" has incompatible type "Union[tzinfo, None, Sentinel]"; expected "Optional[tzinfo]"

这似乎是因为 datetime stub(理所当然)使用:

def astimezone(self, tz: Optional[_tzinfo] = ...) -> datetime: ...

但是,有没有办法让 mypy 知道 sentinel 值永远不会因为 if 检查而传递给 .astimezone()?或者这是否只需要一个 # type: ignore 而没有更清洁的方法?


另一个例子:

from typing import Optional
import requests


def func(session: Optional[requests.Session] = None):
    new_session_made = session is None
    if new_session_made:
        session = requests.Session()
    try:
        session.request("GET", "https://a.b.c.d.com/foo")
        # ...
    finally:
        if new_session_made:
            session.close()

第二个,与第一个一样,是 "runtime safe"(因为没有更好的术语):调用 None.request()None.close()AttributeError 将无法到达或评价。然而,mypy 仍然抱怨:

mypytest.py:9: error: Item "None" of "Optional[Session]" has no attribute "request"
mypytest.py:13: error: Item "None" of "Optional[Session]" has no attribute "close"

我应该在这里做些不同的事情吗?

绕过此问题的一种方法是执行以下操作:

from typing import Optional
import requests


def func(session: Optional[requests.Session] = None) -> None:
    new_session = session is None
    if not session:
        session = requests.Session()
    try:
        session.request("GET", "https://a.b.c.d.com/foo")
        # other stuff
    finally:
        if not new_session:
            session.close()

此外,我们可以检查 mypy 是否可以处理我们使用不同参数类型的情况:

func('a')  # mypy_typing.py:14: error: Argument 1 to "func" has incompatible type "str"; expected "Optional[Session]"
func(1)  # mypy_typing.py:14: error: Argument 1 to "func" has incompatible type "int"; expected "Optional[Session]"
...
# PS:  The test will break for any kind of types except for None and requests.Session
...

但是,如果我们使用 Nonerequest.Session() 对象作为参数,测试将顺利通过:

func(None)  # No errors
func(requests.Session())  # No errors

有关更多信息,您可以从 mypy 的官方文档中阅读此 example

您可以使用显式 cast:

    from typing import cast
    ... 
    if as_tz is not sentinel:
        # Never reached if as_tz has wrong type (Sentinel)
        as_tz = cast(datetime.tzinfo, as_tz)
        dt = dt.astimezone(as_tz)

    new_session_made = session is None
    session = cast(requests.Session, session)

您可以交替使用 assert(尽管这是实际的运行时检查,而 cast 更明确地说是无操作):

        assert isinstance(as_tz, datetime.tzinfo)
        dt = dt.astimezone(as_tz)

    new_session_made = session is None
    assert session is not None

Mypy 对 isinstance 有特殊处理。除了检查身份之外,还可以这样做:

if not isinstance(as_tz, Sentinel):
    dt = dt.astimezone(as_tz)

...

通过此更改,您的示例似乎可以进行类型检查。

根据我的经验,最好的解决方案是使用 enum.Enum.

要求

一个好的哨兵模式有 3 个属性:

  1. 具有明确的 type/value ,不会被误认为是其他值。例如object()
  2. 可以使用描述性常量来引用
  3. 可以简洁地测试,使用isis not

解决方案

enum.Enum 被 mypy 特别对待,因此它是我发现的唯一可以满足所有这三个要求并在 mypy 中正确验证的解决方案。

import datetime
import enum
from typing import Union

class Sentinel(enum.Enum):
    SKIP_TZ = object()

def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Sentinel] = Sentinel.SKIP_TZ,
) -> str:

    if as_tz is not Sentinel.SKIP_TZ:
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"

变化

此解决方案还有其他一些有趣的特性。

可重复使用的 Sentinel 对象

sentinel.py

import enum
class Sentinel(enum.Enum):
    sentinel = object()

main.py

import datetime
from sentinel import Sentinel
from typing import Union

SKIP_TZ = Sentinel.sentinel

def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Sentinel] = SKIP_TZ,
) -> str:

    if as_tz is not SKIP_TZ:
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"

请注意,由于 Sentinel.sentinel 始终提供相同的 object 实例,因此两个可重用哨兵永远不应在相同的上下文中使用。

使用 Literal

限制 Sentinel 值

Sentinel 替换为 Literal[Sentinel.SKIP_TZ]] 使您的函数签名更加清晰,尽管它确实是多余的,因为只有一个枚举值。

import datetime
import enum
from typing import Union
from typing_extensions import Literal

class Sentinel(enum.Enum):
    SKIP_TZ = object()

def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Literal[Sentinel.SKIP_TZ]] = Sentinel.SKIP_TZ,
) -> str:

    if as_tz is not Sentinel.SKIP_TZ:
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"

func(datetime.datetime.now(), as_tz=Sentinel.SKIP_TZ)

不符合我要求的方案

自定义哨兵class

import datetime
from typing import Union

class SentinelType:
    pass

SKIP_TZ = SentinelType()


def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, SentinelType] = SKIP_TZ,
) -> str:

    if not isinstance(dt, SentinelType):
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"

虽然这有效,但使用 isinstance(dt, SentinelType) 不符合要求 3(“使用 is”),因此也不符合要求 2(“使用命名常量”)。为清楚起见,我希望能够使用 if dt is not SKIP_TZ.

对象Literal

Literal 不适用于任意值(尽管它确实适用于枚举。见上文。)

import datetime
from typing import Union
from typing_extensions import Literal

SKIP_TZ = object()

def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Literal[SKIP_TZ]] = SKIP_TZ,
) -> str:

    if dt is SKIP_TZ:
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"

产生以下 mypy 错误:

error: Parameter 1 of Literal[...] is invalid
error: Variable "sentinel.SKIP_TZ" is not valid as a type

字符串Literal

在这次尝试中,我使用了字符串文字而不是对象:

import datetime
from typing import Union
from typing_extensions import Literal


def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Literal['SKIP_TZ']] = 'SKIP_TZ',
) -> str:

    if as_tz is not 'SKIP_TZ':
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"

func(datetime.datetime.now(), as_tz='SKIP_TZ')

即使这行得通,它在要求 1 上也很薄弱。

但是在mypy中没有通过。它产生错误:

error: Argument 1 to "astimezone" of "datetime" has incompatible type "Union[tzinfo, None, Literal['SKIP_TZ']]"; expected "Optional[tzinfo]"