使用 jax 数组索引到 numpy 数组:错误消息
indexing into numpy array with jax array: faulty error messages
下面的 numpy 代码完全没问题:
arr = np.arange(50)
print(arr.shape) # (50,)
indices = np.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
迁移到 jax 后它也有效:
import jax.numpy as jnp
arr = jnp.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
现在让我们尝试混合使用 numpy 和 jax:
arr = np.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
这会产生以下错误:
IndexError: too many indices for array: array is 1-dimensional, but 30 were indexed
如果不支持使用 jax 数组索引到 numpy 数组,我没问题。但是错误信息似乎不对。事情变得更加混乱。如果稍微更改一下形状,代码就可以正常工作。在下面的示例中,我只编辑了从 (30,) 到 (40,) 的索引形状。没有更多的错误消息:
arr = np.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((40,), dtype=int)
print(indices.shape) # (40,)
arr[indices]
我是 运行 jax 版本“0.2.12”,在 cpu 上。
这里发生了什么?
这是一个长期存在的已知问题(参见 https://github.com/google/jax/issues/620);这不是 JAX 本身可以轻松修复的错误,但需要更改 NumPy 处理非 ndarray
索引的方式。好消息是修复即将到来:上面有问题的代码伴随着以下警告,该警告来自 NumPy:
FutureWarning: Using a non-tuple sequence for multidimensional indexing is
deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this
will be interpreted as an array index, `arr[np.array(seq)]`, which will result
either in an error or a different result.
此弃用周期完成后,JAX 数组将在 NumPy 索引中正常工作。
在那之前,您可以通过在使用 JAX 数组索引到 NumPy 数组时显式调用 np.asarray
来解决这个问题。
下面的 numpy 代码完全没问题:
arr = np.arange(50)
print(arr.shape) # (50,)
indices = np.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
迁移到 jax 后它也有效:
import jax.numpy as jnp
arr = jnp.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
现在让我们尝试混合使用 numpy 和 jax:
arr = np.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
这会产生以下错误:
IndexError: too many indices for array: array is 1-dimensional, but 30 were indexed
如果不支持使用 jax 数组索引到 numpy 数组,我没问题。但是错误信息似乎不对。事情变得更加混乱。如果稍微更改一下形状,代码就可以正常工作。在下面的示例中,我只编辑了从 (30,) 到 (40,) 的索引形状。没有更多的错误消息:
arr = np.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((40,), dtype=int)
print(indices.shape) # (40,)
arr[indices]
我是 运行 jax 版本“0.2.12”,在 cpu 上。 这里发生了什么?
这是一个长期存在的已知问题(参见 https://github.com/google/jax/issues/620);这不是 JAX 本身可以轻松修复的错误,但需要更改 NumPy 处理非 ndarray
索引的方式。好消息是修复即将到来:上面有问题的代码伴随着以下警告,该警告来自 NumPy:
FutureWarning: Using a non-tuple sequence for multidimensional indexing is
deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this
will be interpreted as an array index, `arr[np.array(seq)]`, which will result
either in an error or a different result.
此弃用周期完成后,JAX 数组将在 NumPy 索引中正常工作。
在那之前,您可以通过在使用 JAX 数组索引到 NumPy 数组时显式调用 np.asarray
来解决这个问题。