为什么 Mypy 认为添加两个 Jax 数组 returns 一个 numpy 数组?
Why does Mypy think adding two Jax arrays returns a numpy array?
考虑以下文件:
import jax.numpy as jnp
def test(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
return a + b
运行mypy mypytest.py
returns出现如下错误:
mypytest.py:4: error: Incompatible return value type (got "numpy.ndarray[Any, dtype[bool_]]", expected "jax._src.numpy.lax_numpy.ndarray")
出于某种原因,它认为添加两个 jax.numpy.ndarray
s returns 一个 bools 的 NumPy 数组。难道我做错了什么?或者这是 MyPy 或 Jax 类型注释中的错误?
至少在静态上,jnp.ndarray
是 np.ndarray
的子类,修改非常少
class ndarray(np.ndarray, metaclass=_ArrayMeta):
dtype: np.dtype
shape: Tuple[int, ...]
size: int
def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
" Use jax.numpy.array, or jax.numpy.zeros instead.")
因此,它继承了 np.ndarray
的方法类型签名。
我想运行时行为是通过 jnp.array
函数实现的。除非我错过了一些存根文件或类型欺骗,否则 jnp.array
的结果匹配 jnp.ndarray
仅仅是因为 jnp.array
是未类型化的。您可以使用
进行测试
def foo(_: str) -> None:
pass
foo(jnp.array(0))
它通过了 mypy。
所以回答你的问题,我不认为你做错了什么。从某种意义上说,这是一个错误,它可能不是他们的意思,但它实际上并不是不正确的,因为当你添加 jnp.ndarray
s 时你确实得到了 np.ndarray
因为 jnp.ndarray
是 np.ndarray
.
至于为什么 bool
s,这可能是因为您的 jnp.array
s 缺少通用参数并且 np.ndarray
上 __add__
的第一个有效重载是
@overload
def __add__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
所以它只是默认为 bool
。
总的来说,JAX 与 mypy 的兼容性很差,因为很难用 JAX 的转换模型来满足 mypy 的约束,JAX 的转换模型经常调用具有特定于转换的跟踪值的函数,这些跟踪值充当数组的替代品(参见 How To Think in JAX: JIT Mechanics 简要讨论此机制)。
使用跟踪器类型作为数组的替代品意味着 mypy 在转换严格类型的 JAX 函数时会引发错误,因此在整个 JAX 代码库中我们倾向于将 Array
别名为 [=11] =],并将其用作 return 数组的 JAX 函数的 return 类型注释。
最好对此进行改进,因为 Any
return 类型对于有效的类型检查不是很有用,但这只是使 mypy 很好地运行的众多挑战中的第一个贾克斯。如果你想阅读过去几年围绕这个问题的一些有价值的讨论,我会从这里开始:https://github.com/google/jax/issues/943
同时,我的建议是使用 Any
作为 JAX 数组的类型注释。
考虑以下文件:
import jax.numpy as jnp
def test(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
return a + b
运行mypy mypytest.py
returns出现如下错误:
mypytest.py:4: error: Incompatible return value type (got "numpy.ndarray[Any, dtype[bool_]]", expected "jax._src.numpy.lax_numpy.ndarray")
出于某种原因,它认为添加两个 jax.numpy.ndarray
s returns 一个 bools 的 NumPy 数组。难道我做错了什么?或者这是 MyPy 或 Jax 类型注释中的错误?
至少在静态上,jnp.ndarray
是 np.ndarray
的子类,修改非常少
class ndarray(np.ndarray, metaclass=_ArrayMeta):
dtype: np.dtype
shape: Tuple[int, ...]
size: int
def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
" Use jax.numpy.array, or jax.numpy.zeros instead.")
因此,它继承了 np.ndarray
的方法类型签名。
我想运行时行为是通过 jnp.array
函数实现的。除非我错过了一些存根文件或类型欺骗,否则 jnp.array
的结果匹配 jnp.ndarray
仅仅是因为 jnp.array
是未类型化的。您可以使用
def foo(_: str) -> None:
pass
foo(jnp.array(0))
它通过了 mypy。
所以回答你的问题,我不认为你做错了什么。从某种意义上说,这是一个错误,它可能不是他们的意思,但它实际上并不是不正确的,因为当你添加 jnp.ndarray
s 时你确实得到了 np.ndarray
因为 jnp.ndarray
是 np.ndarray
.
至于为什么 bool
s,这可能是因为您的 jnp.array
s 缺少通用参数并且 np.ndarray
上 __add__
的第一个有效重载是
@overload
def __add__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
所以它只是默认为 bool
。
总的来说,JAX 与 mypy 的兼容性很差,因为很难用 JAX 的转换模型来满足 mypy 的约束,JAX 的转换模型经常调用具有特定于转换的跟踪值的函数,这些跟踪值充当数组的替代品(参见 How To Think in JAX: JIT Mechanics 简要讨论此机制)。
使用跟踪器类型作为数组的替代品意味着 mypy 在转换严格类型的 JAX 函数时会引发错误,因此在整个 JAX 代码库中我们倾向于将 Array
别名为 [=11] =],并将其用作 return 数组的 JAX 函数的 return 类型注释。
最好对此进行改进,因为 Any
return 类型对于有效的类型检查不是很有用,但这只是使 mypy 很好地运行的众多挑战中的第一个贾克斯。如果你想阅读过去几年围绕这个问题的一些有价值的讨论,我会从这里开始:https://github.com/google/jax/issues/943
同时,我的建议是使用 Any
作为 JAX 数组的类型注释。