Configuration:
Python: 3.10.16
OS: Debian GNU/Linux 12 (bookworm) x86_64
GPU: NVIDIA GeForce RTX 3090
Driver Version: 535.216.01, CUDA Version: 12.2 (from nvidia-smi)
CUDA Toolkit: 11.8 (from nvcc --version)
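
A listing like the one above can be gathered with a short script. This is only a sketch: the helper names run_tool and collect_environment are hypothetical, and it assumes nvidia-smi and nvcc are on PATH (a missing tool is reported as None rather than raising).

```python
import platform
import shutil
import subprocess


def run_tool(command):
    """Return the stdout of a command, or None if the tool is missing or fails."""
    if shutil.which(command[0]) is None:
        return None
    try:
        result = subprocess.run(command, capture_output=True, text=True, timeout=30)
    except (OSError, subprocess.TimeoutExpired):
        return None
    return result.stdout.strip() if result.returncode == 0 else None


def collect_environment():
    """Collect the Python, OS, driver, and CUDA toolkit details as a dict."""
    return {
        "python": platform.python_version(),
        "os": platform.platform(),
        "nvidia_smi": run_tool(["nvidia-smi"]),    # driver version and the CUDA version it supports
        "nvcc": run_tool(["nvcc", "--version"]),   # locally installed CUDA toolkit version
    }


if __name__ == "__main__":
    for key, value in collect_environment().items():
        print(f"{key}: {value}")
```

Note that nvidia-smi reports the highest CUDA version the driver supports, while nvcc reports the installed toolkit, which is why the two numbers above differ.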

ERROR:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[6], line 20
     18 outputs = {}
     19 for k in tqdm(B_dict):
---> 20   outputs[k] = train_model(network_size, learning_rate, iters, B_dict[k], train_data, test_data)

Cell In[4], line 41
     39 xs = []
     40 for i in tqdm(range(iters), desc='train iter', leave=False):
---> 41     opt_state = opt_update(i, model_grad_loss(get_params(opt_state), *train_data), opt_state)
     43     if i % 25 == 0:
     44         train_psnrs.append(model_psnr(get_params(opt_state), *train_data))

ValueError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.graph.launch' failed: Failed to update gpu graph: Graph update result=kNodeTypeChanged: Failed to update CUDA graph: CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE: the graph update was not performed because it included changes which violated constraints specific to instantiated graph update; current profiling annotation: XlaModule:#hlo_module=jit__lambda_,program_id=42#.

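Fix: uninstall the existing jax and jaxlib (pip uninstall jax jaxlib). The latest jax release does not automatically install a CUDA-enabled jaxlib, so an older version is recommended:
pip install "jax[cuda11_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html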
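
After reinstalling, it may be worth confirming which versions are actually installed and whether JAX sees the GPU. The following is a minimal sketch and not part of the original fix. The helper names installed_version and describe_jax_install are hypothetical. It relies only on importlib.metadata and jax.devices().

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of a package, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def describe_jax_install():
    """Report the jax/jaxlib versions and, if both are installed, the visible devices."""
    versions = {name: installed_version(name) for name in ("jax", "jaxlib")}
    devices = None
    if all(v is not None for v in versions.values()):
        try:
            import jax
            # With a working CUDA-enabled jaxlib, this list should contain a GPU/CUDA device.
            devices = [str(d) for d in jax.devices()]
        except (ImportError, RuntimeError):
            devices = None  # jax could not be imported or no backend could be initialized
    return versions, devices


if __name__ == "__main__":
    versions, devices = describe_jax_install()
    print("versions:", versions)
    print("devices:", devices)
```

If the device list shows only a CPU device, the installed jaxlib is most likely a CPU-only build and the install command above should be repeated in a clean environment.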
