[Notes] JAX fails to use CUDA properly during computation
Configuration:
Python: 3.10.16
OS: Debian GNU/Linux 12 (bookworm) x86_64
GPU: RTX 3090
Driver Version: 535.216.01, CUDA Version: 12.2 (from nvidia-smi)
CUDA Toolkit (nvcc): 11.8 (from nvcc --version)
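The two CUDA versions above measure different things: the "CUDA Version" printed by nvidia-smi is the newest CUDA runtime the installed driver supports, while nvcc --version reports the toolkit installed on the system. Below is a minimal sketch that pulls both numbers out of the command output and compares them. The sample strings are abridged, illustrative lines in the usual output format (not captured from this machine), and the comparison is a simplified rule that ignores finer CUDA compatibility details.

```python
import re


def driver_cuda_version(nvidia_smi_text: str) -> tuple[int, int]:
    """Extract the 'CUDA Version: X.Y' field from nvidia-smi output.

    This is the newest CUDA runtime the driver supports, not an installed toolkit.
    """
    m = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_text)
    if m is None:
        raise ValueError("no 'CUDA Version' field found")
    return int(m.group(1)), int(m.group(2))


def toolkit_cuda_version(nvcc_text: str) -> tuple[int, int]:
    """Extract 'release X.Y' from `nvcc --version` output (the local toolkit)."""
    m = re.search(r"release\s+(\d+)\.(\d+)", nvcc_text)
    if m is None:
        raise ValueError("no 'release' field found")
    return int(m.group(1)), int(m.group(2))


def driver_supports(driver: tuple[int, int], toolkit: tuple[int, int]) -> bool:
    # Simplified rule: the driver's CUDA version must be at least the toolkit's.
    return driver >= toolkit


# Illustrative, abridged sample lines in the usual output format.
smi = "| NVIDIA-SMI 535.216.01   Driver Version: 535.216.01   CUDA Version: 12.2 |"
nvcc = "Cuda compilation tools, release 11.8"

drv = driver_cuda_version(smi)
tk = toolkit_cuda_version(nvcc)
print(drv, tk, driver_supports(drv, tk))  # (12, 2) (11, 8) True
```

With this configuration the driver (CUDA 12.2) is new enough for an 11.8 toolkit, so the version gap alone does not explain the error.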
ERROR:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[6], line 20
18 outputs = {}
19 for k in tqdm(B_dict):
---> 20 outputs[k] = train_model(network_size, learning_rate, iters, B_dict[k], train_data, test_data)
Cell In[4], line 41
39 xs = []
40 for i in tqdm(range(iters), desc='train iter', leave=False):
---> 41 opt_state = opt_update(i, model_grad_loss(get_params(opt_state), *train_data), opt_state)
43 if i % 25 == 0:
44 train_psnrs.append(model_psnr(get_params(opt_state), *train_data))
ValueError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.graph.launch' failed: Failed to update gpu graph: Graph update result=kNodeTypeChanged: Failed to update CUDA graph: CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE: the graph update was not performed because it included changes which violated constraints specific to instantiated graph update; current profiling annotation: XlaModule:#hlo_module=jit__lambda_,program_id=42#.
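Fix: uninstall the existing jax and jaxlib (pip uninstall jax jaxlib). The latest jax release did not automatically install a CUDA-enabled jaxlib.
Installing an older release is recommended instead: pip install jax[cuda11_pip]==0.4.23 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html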