如何在 Python 中使用默认参数重载函数

How to overload on a function with default paramters in Python

所以我有以下(简化的)代码

from typing import Iterable, List, Optional, overload, Literal, Union, Tuple, Any
import sqlite3

@overload
def query_db(
    query: str, params: Optional[Iterable], as_tuple: Literal[False]
) -> List[sqlite3.Row]:
    ...


@overload
def query_db(
    query: str, params: Optional[Iterable], as_tuple: Literal[True]
) -> List[Tuple[Any, ...]]:
    ...


def query_db(
    query: str, params: Optional[Iterable] = None, as_tuple: bool = False
) -> Union[List[sqlite3.Row], List[Tuple[Any, ...]]]:
    """Run a query against the given db.

    If params is not None, securely construct a query from the given
    query string and params.
    """
    with sqlite3.connect("/dummy.sqlite") as con:
        if not as_tuple:
            con.row_factory = sqlite3.Row
        if params is None:
            rows = con.execute(query).fetchall()
        else:
            rows = con.execute(query, params).fetchall()
    return rows



a = query_db("SELECT test_column FROM test_table")
a[0]["test_column"]

我不知道如何进行类型检查。

如果我不添加重载,mypy 会抱怨我可能正在索引一个具有 str 索引的元组。

as_tuple 参数默认为 false,因此 mypy 应该能够推断出我在不向函数提供第二个和第三个参数时使用第一个重载(因为实际实现有默认值)参数)。

然而实际发生的是 mypy 抱怨 none 提供的重载匹配,因为它认为我也需要提供最后两个参数。

当我只是将默认参数复制粘贴到每个重载时,mypy 抱怨我无法将 False 分配给 as_tuple: Literal[True]

是否有一个选项可以让它在运行时对它的工作方式进行类型检查? 我真的不想修改实际的签名,因为该函数在我们的测试中被广泛使用。

好的,所以我找到了 open issue for this on mypy

当前的解决方案显然是注释显式参数的所有可能组合,在我的例子中导致:

@overload
def query_db(
    query: str, params: Optional[Iterable], as_tuple: Literal[False]
) -> List[sqlite3.Row]:
    ...


@overload
def query_db(
    query: str, params: Optional[Iterable], as_tuple: Literal[True]
) -> List[Tuple[Any, ...]]:
    ...

@overload
def query_db(
    query: str, params: Optional[Iterable]
) -> List[sqlite3.Row]:
    ...

@overload
def query_db(
    query: str, * , as_tuple: Literal[True]
) -> List[Tuple[Any, ...]]:
    ...

@overload
def query_db(
    query: str
) -> List[sqlite3.Row]:
    ...

def query_db(
    query: str, params: Optional[Iterable] = None, as_tuple: bool = False
) -> Union[List[sqlite3.Row], List[Tuple[Any, ...]]]:
    ...

如果让某些重载中的参数采用默认值,则不需要那么多重载。当您将布尔值传递给 as_tuple:

时,您可能还需要额外的重载
from typing import Iterable, List, Optional, overload, Literal, Union, Tuple, Any
import sqlite3


@overload
def query_db(
    query: str, params: Optional[Iterable]=..., as_tuple: Literal[False]=...
) -> List[sqlite3.Row]:
    ...


@overload
def query_db(
    query: str, params: Optional[Iterable], as_tuple: Literal[True]
) -> List[Tuple[Any, ...]]:
    ...


@overload
def query_db(
    query: str, * , as_tuple: Literal[True]
) -> List[Tuple[Any, ...]]:
    ...
    

@overload
def query_db(
    query: str, params: Optional[Iterable]=..., as_tuple: bool=...
) -> Union[List[sqlite3.Row], List[Tuple[Any, ...]]]:
    ...


def query_db(
    query: str, params: Optional[Iterable] = None, as_tuple: bool = False
) -> Union[List[sqlite3.Row], List[Tuple[Any, ...]]]:
    """Run a query against the given db.

    If params is not None, securely construct a query from the given
    query string and params.
    """
    with sqlite3.connect("/dummy.sqlite") as con:
        if not as_tuple:
            con.row_factory = sqlite3.Row
        if params is None:
            rows = con.execute(query).fetchall()
        else:
            rows = con.execute(query, params).fetchall()
    return rows

query: str
params: Optional[Iterable]
as_tuple: bool

reveal_type(query_db(query, params, as_tuple=True))
reveal_type(query_db(query, as_tuple=True))
reveal_type(query_db(query, params))
reveal_type(query_db(query))
reveal_type(query_db(query, params, as_tuple=False))
reveal_type(query_db(query, as_tuple=False))
reveal_type(query_db(query, params, as_tuple=as_tuple))
reveal_type(query_db(query, as_tuple=as_tuple))

运行 这给出:

main.py:51: note: Revealed type is 'builtins.list[builtins.tuple[Any]]'
main.py:52: note: Revealed type is 'builtins.list[builtins.tuple[Any]]'
main.py:53: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:54: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:55: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:56: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:57: note: Revealed type is 'Union[builtins.list[sqlite3.dbapi2.Row], builtins.list[builtins.tuple[Any]]]'
main.py:58: note: Revealed type is 'Union[builtins.list[sqlite3.dbapi2.Row], builtins.list[builtins.tuple[Any]]]'