为什么 Mypy 认为添加两个 Jax 数组 returns 一个 numpy 数组?

Why does Mypy think adding two Jax arrays returns a numpy array?

考虑以下文件:

import jax.numpy as jnp

def test(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    return a + b

运行mypy mypytest.pyreturns出现如下错误:

mypytest.py:4: error: Incompatible return value type (got "numpy.ndarray[Any, dtype[bool_]]", expected "jax._src.numpy.lax_numpy.ndarray")

出于某种原因,它认为添加两个 jax.numpy.ndarrays returns 一个 bools 的 NumPy 数组。难道我做错了什么?或者这是 MyPy 或 Jax 类型注释中的错误?

至少在静态上,jnp.ndarraynp.ndarray 的子类,修改非常少

class ndarray(np.ndarray, metaclass=_ArrayMeta):
  dtype: np.dtype
  shape: Tuple[int, ...]
  size: int

  def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")

因此,它继承了 np.ndarray 的方法类型签名。

我想运行时行为是通过 jnp.array 函数实现的。除非我错过了一些存根文件或类型欺骗,否则 jnp.array 的结果匹配 jnp.ndarray 仅仅是因为 jnp.array 是未类型化的。您可以使用

进行测试
def foo(_: str) -> None:
   pass

foo(jnp.array(0))

它通过了 mypy。

所以回答你的问题,我不认为你做错了什么。从某种意义上说,这是一个错误,它可能不是他们的意思,但它实际上并不是不正确的,因为当你添加 jnp.ndarrays 时你确实得到了 np.ndarray 因为 jnp.ndarraynp.ndarray.

至于为什么 bools,这可能是因为您的 jnp.arrays 缺少通用参数并且 np.ndarray__add__ 的第一个有效重载是

    @overload
    def __add__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ...  # type: ignore[misc]

所以它只是默认为 bool

总的来说,JAX 与 mypy 的兼容性很差,因为很难用 JAX 的转换模型来满足 mypy 的约束,JAX 的转换模型经常调用具有特定于转换的跟踪值的函数,这些跟踪值充当数组的替代品(参见 How To Think in JAX: JIT Mechanics 简要讨论此机制)。

使用跟踪器类型作为数组的替代品意味着 mypy 在转换严格类型的 JAX 函数时会引发错误,因此在整个 JAX 代码库中我们倾向于将 Array 别名为 [=11] =],并将其用作 return 数组的 JAX 函数的 return 类型注释。

最好对此进行改进,因为 Any return 类型对于有效的类型检查不是很有用,但这只是使 mypy 很好地运行的众多挑战中的第一个贾克斯。如果你想阅读过去几年围绕这个问题的一些有价值的讨论,我会从这里开始:https://github.com/google/jax/issues/943

同时,我的建议是使用 Any 作为 JAX 数组的类型注释。