alpha state...
A community-supported Windows build for jax.
Currently, only CPU and CUDA 11.1 are supported. For CUDA 11.x, please install the cuda/cuda11_cudnn82 package.
Each jax build pins a concrete jaxlib package version in its setup.py. To install an unstable build, you must first ensure that the required jaxlib package exists in the package index. You can check the index at https://whls.blob.core.windows.net/unstable/index.html
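The check described above can also be scripted. The following is a minimal sketch, assuming the index page is plain HTML whose links point at wheel files named in the standard `name-version-pytag-abitag-platform.whl` form; the helper name `find_wheels` is hypothetical and the sample filenames in the usage note are illustrative only, not taken from the real index.

```python
from html.parser import HTMLParser
from urllib.parse import unquote


class _LinkCollector(HTMLParser):
    """Collect the href of every <a> tag in an HTML page."""

    def __init__(self):
        super().__init__()
        self.links = []

    def handle_starttag(self, tag, attrs):
        if tag == "a":
            for name, value in attrs:
                if name == "href" and value:
                    self.links.append(value)


def find_wheels(index_html, package, version):
    """Return hrefs of wheels for `package` whose public version equals `version`.

    Hypothetical helper: it assumes standard wheel filenames and ignores any
    local version suffix (the part after "+").
    """
    collector = _LinkCollector()
    collector.feed(index_html)
    matches = []
    for href in collector.links:
        # Links may be percent-encoded and may include a directory prefix.
        filename = unquote(href).rsplit("/", 1)[-1]
        if not filename.endswith(".whl"):
            continue
        parts = filename[: -len(".whl")].split("-")
        if len(parts) < 5:
            continue  # not a well-formed wheel filename
        name, full_version = parts[0], parts[1]
        public_version = full_version.split("+", 1)[0]
        if name == package and public_version == version:
            matches.append(href)
    return matches
```

To use it against the live index, the page text could be fetched with `urllib.request.urlopen(url).read().decode()` and passed in as `index_html`. An empty result suggests the pinned jaxlib version has not been published yet.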
You can either install jax via pip (CPU only or CUDA), install jax from source, or download the desired wheel manually.
**Starting from 0.4.13, the CPU build was removed. Please use the official CPU build from PyPI directly.**
pip install "jax[cpu]===0.3.14" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
pip install jax[cuda111] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
pip install -e .[cuda111] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
Select the version of jaxlib that you want to install, then install jax manually.
# download jaxlib from https://whls.blob.core.windows.net/unstable/index.html
pip install <jaxlib_whl>
pip install jax
For background on --use-deprecated legacy-resolver, refer to pip #9186 and pip #9203.
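After any of the install routes above, it can help to confirm which jax and jaxlib versions actually ended up in the environment, and which jaxlib requirement the installed jax declares. The sketch below uses only the standard library's `importlib.metadata` (Python 3.8+); the helper names `installed_version` and `pinned_requirements` are hypothetical, introduced here for illustration.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def pinned_requirements(package, dependency):
    """Return the requirement strings of `package` that start with `dependency`."""
    try:
        # requires() returns None when the package declares no requirements.
        requirements = metadata.requires(package) or []
    except metadata.PackageNotFoundError:
        return []
    prefix = dependency.lower()
    return [req for req in requirements if req.lower().startswith(prefix)]


if __name__ == "__main__":
    for name in ("jax", "jaxlib"):
        print(name, installed_version(name) or "not installed")
    print("jax requires:", pinned_requirements("jax", "jaxlib"))
```

If the reported jaxlib version does not satisfy the requirement that jax declares, reinstalling with the matching wheel from the index is the likely fix.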