Is there any way to define a Python function with leading optional arguments?

As we know, optional arguments must come at the end of the argument list, like this:

def func(arg1, arg2, ..., argN=default)

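I have seen some exceptions to this in the PyTorch package, for example torch.randint. As its docstring shows, it has a leading optional argument among its positional arguments! How is that possible?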

Docstring:
randint(low=0, high, size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor

How can we define a function like the one above?

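Your finding fascinated me, since leading optional arguments are indeed illegal in Python (and in every other language I know of); in our case this would certainly raise: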

SyntaxError: non-default argument follows default argument
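The restriction is enforced when the source is compiled, so it can be observed without calling anything. A minimal check using the builtin compile(); since the exact wording of the message has changed between Python versions, only the exception type is relied on here (rejects_leading_default is a hypothetical helper name):

```python
# Source text with a defaulted parameter followed by required ones.
SOURCE = "def randint(low=0, high, size):\n    pass\n"


def rejects_leading_default(source: str) -> bool:
    """Return True if compiling `source` raises SyntaxError."""
    try:
        compile(source, "<example>", "exec")
    except SyntaxError:
        return True
    return False


print(rejects_leading_default(SOURCE))                            # True
print(rejects_leading_default("def f(high, size, low=0): pass"))  # False
```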

I was a bit suspicious, so I searched the source code:

I found that in lines 566-596 of TensorFactories.cpp there are actually several (!) implementations of randint:
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randint ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Tensor randint(int64_t high, IntArrayRef size, const TensorOptions& options) {
  return native::randint(high, size, c10::nullopt, options);
}

Tensor randint(
    int64_t high,
    IntArrayRef size,
    c10::optional<Generator> generator,
    const TensorOptions& options) {
  return native::randint(0, high, size, generator, options);
}

Tensor randint(
    int64_t low,
    int64_t high,
    IntArrayRef size,
    const TensorOptions& options) {
  return native::randint(low, high, size, c10::nullopt, options);
}

Tensor randint(
    int64_t low,
    int64_t high,
    IntArrayRef size,
    c10::optional<Generator> generator,
    const TensorOptions& options) {
  auto result = at::empty(size, options);
  return result.random_(low, high, generator);
}

This pattern appears again in lines 466-471 of gen_pyi.py, which generates the type signatures for top-level functions:

        'randint': ['def randint(low: _int, high: _int, size: _size, *,'
                    ' generator: Optional[Generator]=None, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS),
                    'def randint(high: _int, size: _size, *,'
                    ' generator: Optional[Generator]=None, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS)],

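So what basically happens is that there is no "real" optional argument; instead there are several functions, one in which the argument is present and one in which it is not.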

This means that when randint is called without the low argument, low is set to 0:

Tensor randint(
    int64_t high,
    IntArrayRef size,
    c10::optional<Generator> generator,
    const TensorOptions& options) {
  return native::randint(0, high, size, generator, options);
}
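For comparison, the same structure can be mirrored in plain Python: one front-end function looks at how many positional arguments it received and forwards to the full three-argument form with low=0, much as the C++ overload above forwards to native::randint(0, high, size, ...). This is only a sketch; randint_params and _randint_full are hypothetical names, and instead of creating a tensor they return the resolved (low, high, size):

```python
from typing import Any, Tuple


def _randint_full(low: int, high: int, size: Tuple[int, ...]) -> Tuple[int, int, Tuple[int, ...]]:
    # Stand-in for the "real" implementation: just report the resolved parameters.
    return low, high, tuple(size)


def randint_params(*args: Any) -> Tuple[int, int, Tuple[int, ...]]:
    """Hypothetical front end: randint_params(high, size) or randint_params(low, high, size)."""
    if len(args) == 2:
        high, size = args
        return _randint_full(0, high, size)  # missing `low` defaults to 0
    if len(args) == 3:
        low, high, size = args
        return _randint_full(low, high, size)
    raise TypeError(f"randint_params() takes 2 or 3 positional arguments ({len(args)} given)")


print(randint_params(10, (3,)))       # (0, 10, (3,))
print(randint_params(5, 10, (2, 2)))  # (5, 10, (2, 2))
```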

Digging further into the OP's question of how there can be multiple functions with the same name and different arguments:

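Going back to gen_pyi.py, we see that these functions are collected into unsorted_function_hints, defined on line 436, which is then used on lines 509-513 to create function_hints; finally, function_hints is set into env on line 670.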

The env dictionary is used to write the .pyi stub files.

These stub files use function/method overloading, as described in PEP 484.

Function/method overloading, using the @overload decorator:

The @overload decorator allows describing functions and methods that support multiple different combinations of argument types. This pattern is used frequently in builtin modules and types.

Here is an example:

from typing import overload

class bytes:
    ...
    @overload
    def __getitem__(self, i: int) -> int: ...
    @overload
    def __getitem__(self, s: slice) -> bytes: ...

So we basically define the same function, __getitem__, with different argument types.

And another example:

from typing import Callable, Iterable, Iterator, Tuple, TypeVar, overload

T1 = TypeVar('T1')
T2 = TypeVar('T2')
S = TypeVar('S')

@overload
def map(func: Callable[[T1], S], iter1: Iterable[T1]) -> Iterator[S]: ...
@overload
def map(func: Callable[[T1, T2], S],
        iter1: Iterable[T1], iter2: Iterable[T2]) -> Iterator[S]: ...
# ... and we could add more items to support more than two iterables

Here we define the same function, map, with a different number of arguments.
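In a regular .py module (as opposed to a .pyi stub like the ones gen_pyi.py writes), the @overload-decorated signatures are only seen by type checkers; they must be followed by one undecorated implementation that handles every call pattern at runtime. A minimal sketch for a randint-like signature, assuming Python 3.8+ so the overload parameters can be declared positional-only (/), which keeps them compatible with the *args implementation. randint_sketch is a hypothetical name that returns the resolved parameters rather than a tensor:

```python
from typing import Any, Tuple, overload

Size = Tuple[int, ...]


@overload
def randint_sketch(high: int, size: Size, /) -> Tuple[int, int, Size]: ...
@overload
def randint_sketch(low: int, high: int, size: Size, /) -> Tuple[int, int, Size]: ...


def randint_sketch(*args: Any) -> Tuple[int, int, Size]:
    # The only definition that exists at runtime; it replaces the overloads above.
    if len(args) == 2:
        return (0, args[0], tuple(args[1]))
    if len(args) == 3:
        return (args[0], args[1], tuple(args[2]))
    raise TypeError("expected (high, size) or (low, high, size)")


print(randint_sketch(10, (3,)))    # (0, 10, (3,))
print(randint_sketch(1, 10, [2]))  # (1, 10, (2,))
```

A type checker will reject calls that match neither overload, while the runtime check covers callers that are not type-checked.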

A single function definition cannot have a leading optional parameter followed by required positional ones:

8.6. Function definitions

[...] If a parameter has a default value, all following parameters up until the “*” must also have a default value — this is a syntactic restriction that is not expressed by the grammar.

Note that this excludes keyword-only parameters, which never receive arguments by position.
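For example, a defaulted parameter may be followed by parameters without defaults as long as those come after a bare * and are therefore keyword-only; this is also why the parameters after * in the randint docstring are unaffected by the rule. A small sketch with a hypothetical function name:

```python
def sample(low=0, *, high, size):
    # `high` and `size` have no defaults, but they are keyword-only,
    # so following the defaulted `low` is legal.
    return low, high, size


print(sample(high=5, size=3))     # (0, 5, 3)
print(sample(2, high=5, size=3))  # (2, 5, 3)
```

The price is that high and size must now always be passed by name; sample(5, 3) raises TypeError.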


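If desired, this behaviour can be emulated by implementing the matching of arguments to parameters manually. For example, one can dispatch based on arity, or match the variadic arguments explicitly.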

def leading_default(*args):
    # match arguments to "parameters"
    *_, low, high, size = 0, *args
    print(low, high, size)

leading_default(1, 2)     # 0 1 2
leading_default(1, 2, 3)  # 1 2 3
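The builtin range() exposes the same calling convention (range(stop) as well as range(start, stop)). Another common pure-Python way to emulate it, sketched here rather than taken from any particular library, is a private sentinel default: the function is declared with a trailing default, and when the last positional argument is missing, the given arguments are shifted one slot to the right. randint_shift is a hypothetical name:

```python
_MISSING = object()  # private sentinel: distinguishes "not passed" from None


def randint_shift(low, high, size=_MISSING):
    """Hypothetical: randint_shift(high, size) or randint_shift(low, high, size)."""
    if size is _MISSING:
        # Only two arguments were given, so they were really (high, size).
        low, high, size = 0, low, high
    return low, high, size


print(randint_shift(1, 2))     # (0, 1, 2)
print(randint_shift(1, 2, 3))  # (1, 2, 3)
```

The parameter names are misleading in the two-argument form, so these arguments are best passed positionally only.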

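A simple form of dispatch implements function overloading by iterating over the candidate signatures and calling the first one that matches.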

import inspect


class MatchOverload:
    """Overload a function via explicitly matching arguments to parameters on call"""
    def __init__(self, base_case=None):
        self.cases = [base_case] if base_case is not None else []

    def overload(self, call):
        self.cases.append(call)
        return self

    def __call__(self, *args, **kwargs):
        failures = []
        for call in self.cases:
            try:
                inspect.signature(call).bind(*args, **kwargs)
            except TypeError as err:
                failures.append(str(err))
            else:
                return call(*args, **kwargs)
        raise TypeError(', '.join(failures))


@MatchOverload
def func(high, size):
    print('two', 0, high, size)


@func.overload
def func(low, high, size):
    print('three', low, high, size)


func(1, 2, size=3)    # three 1 2 3
func(1, 2)            # two 0 1 2
func(1, 2, 3, low=4)  # TypeError: too many positional arguments, multiple values for argument 'low'

My other answer is about reverse-engineering the torch library, but I want to dedicate this answer to implementing a similar mechanism in a direct, non-hacky way.

We have the multipledispatch library:

A relatively sane approach to multiple dispatch in Python. This implementation of multiple dispatch is efficient, mostly complete, performs static analysis to avoid conflicts, and provides optional namespace support. It looks good too.

So let's make use of it:

from multipledispatch import dispatch

@dispatch(int, int)
def randint(low, high):
    my_randint(low, high)

@dispatch(int)
def randint(high):
    my_randint(0, high)

def my_randint(low, high):
    print(low, high)

# 0 5
randint(5)

# 2 3
randint(2, 3)