Set up JAX sampling with GPUs in PyMC
This guide shows the steps to set up and run JAX sampling with GPU support in PyMC. The steps are as follows:
At the time of writing, the latest Ubuntu version is 22.04, but I'm a little conservative, so I decided to install version 20.04. I downloaded the 64-bit PC (AMD64) desktop image from here.
I made a bootable USB drive with Rufus using the Ubuntu desktop .iso image above; the video "How to Make Ubuntu 20.04 Bootable USB Drive" walks through this. I assume that you have an NVIDIA GPU in your local machine and know how to install Ubuntu from a bootable USB drive. If not, you can search for a tutorial on YouTube.
According to JAX's installation guidelines, to enable GPU support in JAX we first need to install CUDA and cuDNN.
To do that, I followed the "Installation of NVIDIA Drivers, CUDA and cuDNN" section of this guide (kudos to the author, Ashutosh Kumar).
Note that we may not be able to find a specific NVIDIA driver version at this step. In that case, we can go to https://download.nvidia.com/XFree86/Linux-x86_64/ and download the exact driver version we need. In my case, I downloaded the file NVIDIA-Linux-x86_64-470.82.01.run from https://download.nvidia.com/XFree86/Linux-x86_64/470.82.01/
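Judging from the 470.82.01 example, each driver version has its own folder on that server containing a file named NVIDIA-Linux-x86_64-<version>.run. A minimal sketch that builds the download URL for a given version, assuming that layout also holds for the version you need (the helper name driver_download_url is hypothetical):

```python
# Base directory for Linux x86_64 driver installers (from the guide above).
BASE_URL = "https://download.nvidia.com/XFree86/Linux-x86_64"


def driver_download_url(version: str) -> str:
    """Build the .run installer URL for a Linux x86_64 driver version.

    Assumes the layout seen for 470.82.01:
    <BASE_URL>/<version>/NVIDIA-Linux-x86_64-<version>.run
    """
    return f"{BASE_URL}/{version}/NVIDIA-Linux-x86_64-{version}.run"


print(driver_download_url("470.82.01"))
```

For 470.82.01 this reproduces the exact file used in this guide; for other versions, check that the folder actually exists before relying on the generated URL.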
After successfully following the steps in that guide, we can run the nvidia-smi and nvcc --version commands to verify the installation. In my case, the output looks something like this:
[Screenshot: output of nvidia-smi and nvcc --version]
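To run both checks from one script, here is a small Python sketch. It assumes nvidia-smi and nvcc are on PATH, and the helper names are hypothetical. The regex assumes nvcc prints a line such as "Cuda compilation tools, release 11.4, V11.4.152", and nvidia-smi's --query-gpu=driver_version --format=csv,noheader options print only the driver version.

```python
import re
import shutil
import subprocess
from typing import List, Optional


def run_version_command(cmd: List[str]) -> Optional[str]:
    """Run a command and return its stdout, or None if it is missing or fails."""
    if shutil.which(cmd[0]) is None:
        return None
    result = subprocess.run(cmd, capture_output=True, text=True, check=False)
    return result.stdout if result.returncode == 0 else None


def parse_nvcc_release(nvcc_output: str) -> Optional[str]:
    """Extract the CUDA release (e.g. '11.4') from `nvcc --version` output."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


if __name__ == "__main__":
    driver = run_version_command(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"]
    )
    nvcc = run_version_command(["nvcc", "--version"])
    print("NVIDIA driver:", driver.strip() if driver else "not found")
    print("CUDA (nvcc):", parse_nvcc_release(nvcc) if nvcc else "not found")
```

If either line prints "not found", the corresponding step of the installation guide did not complete or the tool is not on PATH.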

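Following JAX's guidelines, once CUDA and cuDNN are installed we can use pip to install JAX with GPU support. Pick the wheel that matches your setup: the CUDA 12 wheel below needs an NVIDIA driver that supports CUDA 12 (525.60.13 or newer on Linux), whereas the 470.82.01 driver used above supports at most CUDA 11.4. With that driver you would need to either upgrade the driver or use a JAX release that ships CUDA 11 wheels.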
pip install --upgrade pip
# Installs the JAX wheels built for CUDA 12.
# Note: wheels only available on linux.
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
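Because the CUDA 12 wheels only work with a sufficiently recent driver, it can help to compare the driver version reported by nvidia-smi against NVIDIA's documented minimum for CUDA 12 on Linux (525.60.13). A minimal sketch; the function names are hypothetical:

```python
from typing import Tuple

# Minimum Linux driver for CUDA 12.x, per NVIDIA's CUDA compatibility table.
CUDA12_MIN_DRIVER = (525, 60, 13)


def parse_driver_version(version: str) -> Tuple[int, ...]:
    """Turn a driver string such as '470.82.01' into (470, 82, 1)."""
    return tuple(int(part) for part in version.strip().split("."))


def driver_supports_cuda12(version: str) -> bool:
    """Return True if the driver is at least the CUDA 12 minimum."""
    # Tuples compare element by element, like version numbers.
    return parse_driver_version(version) >= CUDA12_MIN_DRIVER


for v in ["470.82.01", "525.60.13", "535.104.05"]:
    print(v, driver_supports_cuda12(v))
```

The 470.82.01 driver from this guide fails the check, which is why it needs either a driver upgrade or CUDA 11 wheels.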
We can then start IPython (or a plain python shell) and run the following commands to check:
In [1]: import jax
In [2]: jax.default_backend()
Out[2]: 'gpu'
In [3]: jax.devices()
Out[3]: [GpuDevice(id=0, process_index=0)]
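The same check as a standalone script, plus a small computation to confirm that work actually runs on the default device. This is a sketch: has_gpu is a hypothetical helper, and it relies on jax.devices("gpu") raising RuntimeError when no GPU backend is available. On a CPU-only machine the script still runs but reports the cpu backend.

```python
import jax
import jax.numpy as jnp


def has_gpu() -> bool:
    """Return True if JAX can see at least one GPU device."""
    try:
        return len(jax.devices("gpu")) > 0
    except RuntimeError:
        # Raised when the GPU backend is missing or failed to initialize.
        return False


print("Default backend:", jax.default_backend())  # 'gpu' when CUDA is picked up
print("Devices:", jax.devices())
print("GPU available:", has_gpu())

# A small matrix multiply; JAX places it on the default device.
x = jnp.ones((1000, 1000))
y = (x @ x).block_until_ready()
print("y[0, 0] =", float(y[0, 0]))  # each entry is a sum of 1000 ones
```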
That's it: JAX is now installed with GPU support, and we can run JAX-based sampling in PyMC on the GPU with pm.sample(nuts_sampler="numpyro", ...).