The rocm-jax repository is deprecated for JAX wheel development, build, and
test workflows. For JAX 0.9.2 and earlier versions, use the release branches
of this repository to build or test JAX.

| JAX Version | Repo used for building | Build jaxlib wheel? |
|---|---|---|
| 0.8.x | rocm/rocm-jax | Yes |
| 0.9.x | rocm/rocm-jax | Yes |
| 0.10.x | rocm/jax | No |
| latest | rocm/jax or jax-ml/jax | (only when jaxlib changes) |
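
For CI scripts that need to pick a source repository from a JAX version string, the table above can be encoded as a small shell function. This is a sketch under stated assumptions: rocm_jax_repo_for is a hypothetical helper that exists in neither repository, it treats every 0.9.x release as coming from rocm/rocm-jax (per the table), and for latest it returns rocm/jax even though the table also allows jax-ml/jax.

```bash
# Hypothetical helper (not part of rocm/rocm-jax or rocm/jax): print the
# repository used to build a given JAX version, following the table above.
# Accepts "MAJOR.MINOR[.PATCH]" or "latest".
rocm_jax_repo_for() {
  local version="$1" major rest minor
  if [ "$version" = latest ]; then
    # The table also allows jax-ml/jax for latest; this sketch picks rocm/jax.
    echo "rocm/jax"
    return 0
  fi
  if ! printf '%s\n' "$version" | grep -Eq '^[0-9]+\.[0-9]+(\.[0-9]+)?$'; then
    echo "unrecognized JAX version: $version" >&2
    return 1
  fi
  major="${version%%.*}"   # text before the first dot
  rest="${version#*.}"     # text after the first dot
  minor="${rest%%.*}"      # second component
  # 0.9.x and earlier: release branches of rocm/rocm-jax.
  # 0.10.0 and later: the ROCm JAX fork.
  if [ "$major" -eq 0 ] && [ "$minor" -le 9 ]; then
    echo "rocm/rocm-jax"
  else
    echo "rocm/jax"
  fi
}
```

For example, rocm_jax_repo_for 0.9.2 prints rocm/rocm-jax and rocm_jax_repo_for 0.10.0 prints rocm/jax.
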

Starting with the JAX 0.10.0 release, teams that build or test JAX wheels must use the ROCm JAX fork:

```bash
git clone https://github.com/ROCm/jax.git
```

Use the build and test scripts from that repository, including build/build.py
and the ROCm artifact workflow. The legacy stack.py entrypoint in this
repository now exits with a deprecation notice. The retained build/ci_build
script remains available only for Docker image infrastructure actions; wheel
build and test actions exit with guidance to ROCm/jax. The
jax_rocm_plugin/build/build.py compatibility entrypoint delegates to
ROCm/jax build/build.py.
The default branch for this repository is rocm-jax-infra. It keeps the
Dockerfiles and infrastructure files needed to build ROCm JAX images.
This branch is not the source of truth for:
- ROCm JAX plugin or PJRT source development.
- Building JAX, jaxlib, ROCm plugin, or ROCm PJRT wheels.
- Running JAX unit tests for wheel validation.

Those workflows belong in https://github.com/ROCm/jax.
Developers and CI jobs that build or test JAX wheels should clone
https://github.com/ROCm/jax and follow the build and test documentation in
that repository.
QA and infrastructure jobs that build ROCm JAX images should continue using the
Dockerfiles in this rocm-jax-infra branch. These image builds consume wheels
produced by the ROCm JAX fork.
Image users should consume the published ROCm JAX images from the configured container registry for their environment.
ROCm JAX PJRT and plugin wheels should be built from the ROCm JAX fork, not from this infrastructure repository:

```bash
git clone https://github.com/ROCm/jax.git
cd jax
```

For a local wheel build, use the ROCm artifact build script in ROCm/jax.
Set JAXCI_HERMETIC_PYTHON_VERSION for the Python ABI you want to build, and
set JAXCI_OUTPUT_DIR to the directory where wheels should be written:

```bash
export JAXCI_OUTPUT_DIR="$PWD/dist"
export JAXCI_CLONE_MAIN_XLA=1

# Build one Python-specific ROCm plugin wheel.
JAXCI_HERMETIC_PYTHON_VERSION=3.12 \
  ./ci/build_rocm_artifacts.sh jax-rocm-plugin

# Build the py3-none ROCm PJRT wheel.
JAXCI_HERMETIC_PYTHON_VERSION=3.12 \
  ./ci/build_rocm_artifacts.sh jax-rocm-pjrt
```

To build plugin wheels for multiple Python versions, repeat the plugin command
with each required Python version, for example 3.11, 3.12, 3.13, and
3.14. PJRT is a py3-none wheel, so it only needs to be built once for the
wheel set.
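
The per-ABI repetition can be scripted. The following is a minimal sketch, assuming it is sourced from the root of a ROCm/jax checkout; build_rocm_wheel_set, BUILD_SCRIPT, and PLUGIN_PYTHON_VERSIONS are hypothetical names introduced here, and BUILD_SCRIPT exists only so the script path can be overridden (for example with a stub during a dry run).

```bash
# Sketch: build one ROCm plugin wheel per Python ABI, then the PJRT wheel once.
# BUILD_SCRIPT, PLUGIN_PYTHON_VERSIONS, and build_rocm_wheel_set are
# hypothetical names for this example, not part of ROCm/jax.
BUILD_SCRIPT="${BUILD_SCRIPT:-./ci/build_rocm_artifacts.sh}"
PLUGIN_PYTHON_VERSIONS="${PLUGIN_PYTHON_VERSIONS:-3.11 3.12 3.13 3.14}"

build_rocm_wheel_set() {
  local py
  JAXCI_OUTPUT_DIR="${JAXCI_OUTPUT_DIR:-$PWD/dist}"
  export JAXCI_OUTPUT_DIR
  export JAXCI_CLONE_MAIN_XLA=1
  mkdir -p "$JAXCI_OUTPUT_DIR" || return

  # The plugin wheel is Python-specific: build it once per ABI.
  # PLUGIN_PYTHON_VERSIONS is intentionally unquoted to split on spaces.
  for py in $PLUGIN_PYTHON_VERSIONS; do
    JAXCI_HERMETIC_PYTHON_VERSION="$py" "$BUILD_SCRIPT" jax-rocm-plugin || return
  done

  # PJRT is a py3-none wheel, so a single build covers every Python version.
  JAXCI_HERMETIC_PYTHON_VERSION=3.12 "$BUILD_SCRIPT" jax-rocm-pjrt
}
```

After sourcing it, run build_rocm_wheel_set; set PLUGIN_PYTHON_VERSIONS to a space-separated list before calling it to limit the ABIs built. The function stops at the first failed build and returns that build's exit status.
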
CI wheel production should use the ROCm/jax artifact workflow and scripts,
including .github/workflows/build_rocm_artifacts.yml,
ci/build_rocm_artifacts.sh, and build/build.py. The resulting wheels are
published by ROCm/jax and consumed by this repository's Docker image workflow
through the configured CloudFront/S3 artifact location.
The ci/build_rocm_artifacts.sh script is the preferred wrapper because it
sets up the ROCm build environment and invokes build/build.py with the ROCm
Bazel configuration. For debugging or custom automation in ROCm/jax, the
equivalent direct build.py shape is:

```bash
export JAXCI_OUTPUT_DIR="$PWD/dist"
export JAXCI_CLONE_MAIN_XLA=1

python build/build.py build \
  --wheels=jax-rocm-plugin \
  --bazel_startup_options="--bazelrc=build/rocm/rocm.bazelrc" \
  --bazel_options=--config=rocm_release_wheel \
  --bazel_options=--config=rocm_rbe \
  --python_version=3.12 \
  --verbose \
  --detailed_timestamped_log \
  --output_path="$JAXCI_OUTPUT_DIR"

python build/build.py build \
  --wheels=jax-rocm-pjrt \
  --bazel_startup_options="--bazelrc=build/rocm/rocm.bazelrc" \
  --bazel_options=--config=rocm_release_wheel \
  --bazel_options=--config=rocm_rbe \
  --python_version=3.12 \
  --verbose \
  --detailed_timestamped_log \
  --output_path="$JAXCI_OUTPUT_DIR"
```

Use --wheels=jax-rocm-plugin for Python-specific plugin wheels and
--wheels=jax-rocm-pjrt for the py3-none PJRT wheel. Change
--python_version for each plugin ABI that needs to be built.
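
When scripting the direct build.py form, a thin wrapper keeps the ROCm Bazel flags in one place so only the wheel name and Python version vary. This is a hedged sketch: build_py_wheel and the PYTHON override are hypothetical, the build.py flags are copied verbatim from the commands above, and it assumes JAXCI_OUTPUT_DIR and JAXCI_CLONE_MAIN_XLA are already exported as shown.

```bash
# Hypothetical wrapper around the direct build.py invocation; run it from the
# root of a ROCm/jax checkout. PYTHON and build_py_wheel are names introduced
# for this sketch. Flags are the ones used in the commands above.
PYTHON="${PYTHON:-python}"

build_py_wheel() {
  local wheel="$1" python_version="$2"
  case "$wheel" in
    jax-rocm-plugin | jax-rocm-pjrt) ;;
    *) echo "unsupported wheel: $wheel" >&2; return 2 ;;
  esac
  if [ -z "${JAXCI_OUTPUT_DIR:-}" ]; then
    echo "set JAXCI_OUTPUT_DIR before building" >&2
    return 2
  fi
  "$PYTHON" build/build.py build \
    --wheels="$wheel" \
    --bazel_startup_options="--bazelrc=build/rocm/rocm.bazelrc" \
    --bazel_options=--config=rocm_release_wheel \
    --bazel_options=--config=rocm_rbe \
    --python_version="$python_version" \
    --verbose \
    --detailed_timestamped_log \
    --output_path="$JAXCI_OUTPUT_DIR"
}
```

For example, call build_py_wheel jax-rocm-plugin once for each of 3.11, 3.12, 3.13, and 3.14, then build_py_wheel jax-rocm-pjrt 3.12 once for the wheel set.
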
The Dockerfiles in this repository are retained for image construction:

- docker/Dockerfile.base-ubu24
- docker/Dockerfile.base-therock-ubu24
- docker/Dockerfile.jax-ubu24
- docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm
- docker/manylinux/Dockerfile.jax-manylinux_2_28-therock

The Dockerfiles depend on the small set of retained infrastructure inputs,
including tools/get_rocm.py, build/requirements.txt,
docker/manylinux/clang.cfg, docker/patches/rocr-intercept-queue-fix.patch,
and a local wheelhouse/ containing ROCm JAX wheels.
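
Because a missing input may only surface partway through docker build, a quick pre-flight check can save time. The sketch below is hypothetical (check_image_build_inputs is not part of this repository) and checks only the inputs named above; the Dockerfiles may need others.

```bash
# Hypothetical pre-flight check, not part of this repository: confirm the
# retained inputs named above exist and wheelhouse/ holds at least one wheel.
# Pass the repository root (defaults to the current directory).
check_image_build_inputs() {
  local root="${1:-.}" missing=0 f
  for f in tools/get_rocm.py build/requirements.txt \
           docker/manylinux/clang.cfg \
           docker/patches/rocr-intercept-queue-fix.patch; do
    if [ ! -f "$root/$f" ]; then
      echo "missing: $f" >&2
      missing=1
    fi
  done
  # If nothing matches, the glob stays literal and the -e test fails.
  set -- "$root"/wheelhouse/*.whl
  if [ ! -e "$1" ]; then
    echo "missing: wheelhouse/*.whl (build or fetch wheels from ROCm/jax)" >&2
    missing=1
  fi
  return "$missing"
}
```

Running check_image_build_inputs before the docker build commands reports every missing input at once and returns non-zero if any are absent.
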
Example image build shape:

```bash
# Build or obtain wheels from https://github.com/ROCm/jax first.
mkdir -p wheelhouse
cp /path/to/jax/wheels/*.whl wheelhouse/

# Build a base image.
docker build \
  -f docker/Dockerfile.base-ubu24 \
  --build-arg ROCM_VERSION=7.2.0 \
  --build-arg ROCM_VERSION_TAG=720 \
  -t ghcr.io/rocm/jax-base-ubu24.rocm720:local \
  .

# Build a JAX image from those wheels.
docker build \
  -f docker/Dockerfile.jax-ubu24 \
  --build-arg ROCM_VERSION_TAG=720 \
  --build-arg BASE_IMAGE_TAG=local \
  --build-arg ROCM_VERSION=7.2.0 \
  --build-arg JAX_VERSION=<jax-version> \
  --build-arg XLA_COMMIT=<xla-commit> \
  --build-arg JAX_COMMIT=<jax-commit> \
  --build-arg ROCM_JAX_COMMIT=<rocm-jax-commit> \
  --build-arg PLUGIN_NAMESPACE=7 \
  -t ghcr.io/rocm/jax-ubu24.rocm720:local \
  .
```

To run the JAX tests against locally built wheels, first install the wheels produced in $JAXCI_OUTPUT_DIR:

```bash
# JAXCI_OUTPUT_DIR is the wheel output directory used for the build above.
pip install "$JAXCI_OUTPUT_DIR"/*.whl
# Install the JAX release that matches the ROCm wheels.
pip install jax==<version>
```

Then confirm that GPU devices are visible:

```python
import jax

print(jax.local_devices())
```

Then run the tests:

```bash
# Change to the jax repository root.
cd jax
pip install pytest pytest-xdist pytest-html pytest-csv flatbuffers hypothesis uv
mkdir -p dist && JAXCI_PYTHON="$(command -v python)" ./ci/run_pytest_rocm.sh
```

The .github/workflows directory and scripts used by those workflows are kept
so existing infrastructure can be migrated deliberately. Existing Docker image
workflows continue to use this repository and build/ci_build for Docker
image construction. Wheel-build and JAX-test workflows should move to the
ROCm/jax artifact model, where .github/workflows/build_rocm_artifacts.yml
runs ci/build_rocm_artifacts.sh and build/build.py. TheRock wheel builds
should use the ROCm/jax TheRock manylinux image override, such as
ghcr.io/rocm/jax-manylinux_2_28-therock-latest:latest.
The reporting workflows for pytest results and Llama performance remain in this repository with their supporting upload scripts.