jax 问题(在 NVIDIA DGX 机器上,同样如此)

jax woes (on an NVDIA DGX box, no less)

我正在尝试 运行 在 nvidia dgx 盒子上进行 jax,但失败得很惨,因此:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.arange(10)
2021-10-25 13:00:05.863667: W 
external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:80] Couldn't 
get ptxas version string: INTERNAL: Couldn't invoke ptxas --version
2021-10-25 13:00:05.864713: F 
external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:435] 
ptxas returned an error during compilation of ptx to sass: 'INTERNAL: Failed to 
launch ptxas'  If the error message indicates that a file could not be written, 
please verify that sufficient filesystem space is provided.
Aborted (core dumped)

如有任何建议,我们将不胜感激。

这意味着您的 CUDA 安装配置不正确,通常可以通过确保 CUDA 工具包二进制文件(包括 ptxas)存在于您的 $PATH 中来修复。有关对报告类似问题的用户的回复,请参阅 https://github.com/google/jax/discussions/6843 and https://github.com/google/jax/issues/7239