模块 'jaxlib.xla_extension.jax_jit' 没有属性 'set_enable_x64_cpp_flag'

Module 'jaxlib.xla_extension.jax_jit' has no attribute 'set_enable_x64_cpp_flag'

我尝试在 win10 上安装 jax.lib。似乎安装了 jax.lib 但是当我 运行 spyder 并写 'import jax' 时,它说

module 'jaxlib.xla_extension.jax_jit' has no attribute 'set_enable_x64_cpp_flag'

我有 python 3.10 和 cuda 版本 11.6。

你能帮我解决一下吗?

我刚刚通过编写升级了 JAX

pip install --upgrade jax jaxlib

在 anaconda 命令提示符下,问题已解决。