Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion .github/workflows/build-wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ on:
required: false
default: 'plugin_wheels'
type: string
secrets:
rbe_ci_key:
required: true
rbe_ci_cert:
required: true

jobs:
build-plugin-wheels:
Expand Down Expand Up @@ -51,6 +56,13 @@ jobs:
rocm-smi -a || true
rocminfo | grep gfx || true
- uses: actions/checkout@v4
- name: Get RBE cluster keys
env:
RBE_CI_CERT: ${{ secrets.rbe_ci_cert }}
RBE_CI_KEY: ${{ secrets.rbe_ci_key }}
run: |
echo "$RBE_CI_CERT" >> ./jax_rocm_plugin/ci-cert.crt
echo "$RBE_CI_KEY" >> ./jax_rocm_plugin/ci-cert.key
- name: Build plugin wheels
run: |
python3 build/ci_build \
Expand All @@ -59,7 +71,8 @@ jobs:
--rocm-version="${{ inputs.rocm-version }}" \
--rocm-build-job="${{ inputs.rocm-build-job }}" \
--rocm-build-num="${{ inputs.rocm-build-num }}" \
dist_wheels
dist_wheels \
--rbe
- name: Archive plugin wheels
uses: actions/upload-artifact@v4
with:
Expand Down
23 changes: 22 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ jobs:
# fails with a complaint about the pipes module.
python-versions: "3.11,3.12"
rocm-version: ${{ matrix.rocm-version }}
secrets:
rbe_ci_cert: ${{ secrets.RBE_CI_CERT }}
rbe_ci_key: ${{ secrets.RBE_CI_KEY }}
call-build-docker:
needs: call-build-wheels
strategy:
Expand Down Expand Up @@ -59,6 +62,18 @@ jobs:
repository: rocm/jax
ref: rocm-jaxlib-v0.7.1
path: jax
- name: Apply patches to rocm/jax test repo
run: |
pushd jax
git apply ../ci/patches/*
popd
Comment on lines +65 to +69
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be done automatically by Bazel? Why do we need to do this manually?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because these get applied to the rocm/jax that we use for test cases, not the jax-ml/jax we use for wheel builds. Maybe it'd be better to carry this as a change in the rocm/jax repo itself rather than a patch? I'd really rather not carry anything in that repo though, and it's unfortunate that we have to keep jaxlib changes in there right now.

- name: Get RBE cluster keys
env:
RBE_CI_CERT: ${{ secrets.RBE_CI_CERT }}
RBE_CI_KEY: ${{ secrets.RBE_CI_KEY }}
run: |
echo "$RBE_CI_CERT" >> ./jax/ci-cert.crt
echo "$RBE_CI_KEY" >> ./jax/ci-cert.key
- name: Authenticate to GitHub Container Registry
run: |
echo "${{ secrets.GITHUB_TOKEN }}" \
Expand All @@ -70,6 +85,11 @@ jobs:
docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video \
"ghcr.io/rocm/jax-ubu22.rocm${ROCM_VERSION//.}:${GITHUB_SHA}" \
rocm-smi -a || true
- name: Download wheel artifacts
uses: actions/download-artifact@v4
with:
name: plugin_wheels_r${{ matrix.rocm-version }}
path: ./wheelhouse
- name: Run tests
env:
GPU_COUNT: "8"
Expand All @@ -78,7 +98,8 @@ jobs:
ROCM_VERSION: ${{ matrix.rocm-version }}
# TODO: Add the tests/linalg_test.py test back once we fix the XLAClient thing.
run: |

python3 build/ci_build test \
"ghcr.io/rocm/jax-ubu22.rocm${ROCM_VERSION//.}:${GITHUB_SHA}" \
--test-cmd "pytest jax/tests/core_test.py"
--test-cmd "bash ci/jax_rbe/pr_setup.sh && ci/jax_rbe/pr_test.sh 0.7.1 3.12"

6 changes: 6 additions & 0 deletions .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ jobs:
repository: rocm/jax
ref: rocm-jaxlib-v0.7.1
path: jax
- name: Apply patches to rocm/jax test repo
run: |
pushd jax
git apply ../ci/patches/*
popd
- name: Authenticate to GitHub Container Registry
run: |
echo "${{ secrets.GITHUB_TOKEN }}" \
Expand All @@ -95,6 +100,7 @@ jobs:
ROCM_VERSION: ${{ matrix.rocm-version }}
UBUNTU_VERSION: ${{ matrix.ubuntu-version }}
run: |
# (charleshofer) TODO: Switch to RBE once we're able to process the test log information
python3 build/ci_build test \
"ghcr.io/rocm/jax-ubu${UBUNTU_VERSION}.rocm${ROCM_VERSION//.}:${GITHUB_SHA}" \
--test-cmd "python jax_rocm_plugin/build/rocm/run_single_gpu.py -c && \
Expand Down
13 changes: 12 additions & 1 deletion build/ci_build
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def dist_wheels(
rocm_build_job="",
rocm_build_num="",
therock_path=None,
rbe=False,
compiler="gcc",
):
jax_plugin_dir = "jax_rocm_plugin"
Expand All @@ -89,6 +90,9 @@ def dist_wheels(
xla_path = os.path.realpath(os.path.expanduser(xla_source_dir))
cmd.append("--xla-source-dir=%s" % xla_path)

if rbe:
cmd.append("--rbe")

cmd.append("dist_wheels")
subprocess.check_call(cmd, cwd=jax_plugin_dir)

Expand Down Expand Up @@ -201,7 +205,8 @@ def test(image_name, test_cmd=None):
print(f"rocm/xla commit hash: {labels['com.amdgpu.xla_commit']}", flush=True)

if not test_cmd:
test_cmd = "python jax_rocm_plugin/build/rocm/run_single_gpu.py -c && python jax_rocm_plugin/build/rocm/run_multi_gpu.py"
test_cmd = "python jax_rocm_plugin/build/rocm/run_single_gpu.py -c && " \
"python jax_rocm_plugin/build/rocm/run_multi_gpu.py"

gpu_args = [
"--device=/dev/kfd",
Expand Down Expand Up @@ -302,6 +307,11 @@ def parse_args():
subp = p.add_subparsers(dest="action", required=True)

dwp = subp.add_parser("dist_wheels")
dwp.add_argument(
"--rbe",
action="store_true",
help="Use Bazel RBE for building"
)

dtestp = subp.add_parser("test_docker")
dtestp.add_argument("--docker-build-only", action="store_true")
Expand Down Expand Up @@ -355,6 +365,7 @@ def main():
rocm_build_job=args.rocm_build_job,
rocm_build_num=args.rocm_build_num,
therock_path=args.therock_path,
rbe=args.rbe,
compiler=args.compiler,
)

Expand Down
14 changes: 14 additions & 0 deletions ci/jax_rbe/pr_setup.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/bin/bash
#
# Prepare the JAX test checkout for an RBE test run by copying the ROCm
# platform definitions from the plugin tree into it.
#
# All paths are relative to the current working directory, which must
# contain both ./jax and ./jax_rocm_plugin.

set -ex

# Stop early with a clear message if there's no ./jax directory. Without the
# explicit exit the script would carry on to the cp below and fail there with
# a less obvious error.
if [ ! -d ./jax ]; then
    echo "ERROR: ./jax directory does not exist" >&2
    exit 1
fi

# (charleshofer) This might create a dependency we don't want, but we can
# remove this once platform information is stored in the XLA project.
# Give platform information to JAX
cp -r jax_rocm_plugin/platform jax/

47 changes: 47 additions & 0 deletions ci/jax_rbe/pr_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/bin/bash
#
# Run a subset of the JAX GPU tests through Bazel RBE against the ROCm plugin
# wheels found in ./wheelhouse.
#
# Usage: pr_test.sh <jax-version> <python-version>
#   e.g. pr_test.sh 0.7.1 3.12
#
# All paths are relative to the current working directory, which must
# contain ./jax_rocm_plugin and ./wheelhouse (and normally ./jax).

set -ex

# Both arguments are interpolated into wheel file names and requirement pins
# below, so refuse to run with either one missing instead of generating
# entries such as "jaxlib==".
if [ "$#" -lt 2 ]; then
    echo "Usage: $0 <jax-version> <python-version>" >&2
    exit 1
fi

JAX_VERSION="$1"
PYTHON="$2"

# Repository to clone when ./jax is not already present. The default mirrors
# the rocm/jax checkout done by the CI workflow.
# NOTE(review): confirm that a "release/<version>" branch exists in this
# repository; the CI workflow checks out "rocm-jaxlib-v<version>" instead.
JAX_REPO_URL="${JAX_REPO_URL:-https://github.com/rocm/jax.git}"

ROCM_JAX_DIR=$(realpath ".")
JAX_DIR=$(realpath "./jax")
WHEELHOUSE=$(realpath "./wheelhouse")


pushd "${ROCM_JAX_DIR}" || exit

if [ ! -d "${JAX_DIR}" ]; then
    # git clone takes <repository> first and the destination directory second.
    git clone -b "release/${JAX_VERSION}" --depth 1 "${JAX_REPO_URL}" "${JAX_DIR}"
fi

pushd "${JAX_DIR}" || exit

# If we haven't stuck the plugin requirements in the requirements.in file yet, put them in
if ! grep -q jax_rocm7 build/requirements.in; then
    {
        echo "jaxlib==${JAX_VERSION}"
        echo "${WHEELHOUSE}/jax_rocm7_pjrt-${JAX_VERSION}-py3-none-manylinux_2_28_x86_64.whl"
        # The plugin wheel is CPython-specific: "3.12" becomes the "cp312" tag.
        echo "${WHEELHOUSE}/jax_rocm7_plugin-${JAX_VERSION}-cp${PYTHON//.}-cp${PYTHON//.}-manylinux_2_28_x86_64.whl"
    } >> build/requirements.in
fi

python3 build/build.py requirements_update --python="${PYTHON}"
python3 build/build.py build --wheels=jax-rocm-plugin --configure_only --python_version="${PYTHON}"

# NOTE(review): the Bazel binary below is presumably fetched by the build.py
# invocations above; the version in its file name is hard-coded here.
# --//jax:build_jaxlib=false makes the tests use the wheels listed in
# build/requirements.in rather than building jaxlib from source.
./bazel-7.4.1-linux-x86_64 \
    --bazelrc=.bazelrc \
    --bazelrc=../jax_rocm_plugin/rbe.bazelrc \
    test \
    --config=rocm \
    --config=rocm_rbe \
    --noremote_accept_cached \
    --//jax:build_jaxlib=false \
    --action_env=TF_ROCM_AMDGPU_TARGETS="gfx906,gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" \
    --test_verbose_timeout_warnings \
    --test_output=errors \
    //tests:core_test_gpu \
    //tests:linalg_test_gpu

83 changes: 83 additions & 0 deletions ci/patches/0001-Enable-testing-with-ROCm-plugin-wheels.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
From 5e6d27b98b054bd110efb24e4fad34454f513a94 Mon Sep 17 00:00:00 2001
From: Charles Hofer <[email protected]>
Date: Wed, 5 Nov 2025 01:25:50 +0000
Subject: [PATCH] Enable testing with ROCm plugin wheels

---
WORKSPACE | 2 ++
jaxlib/jax.bzl | 8 +++++---
jaxlib/tools/BUILD.bazel | 19 +++++++++++++++++++
3 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/WORKSPACE b/WORKSPACE
index 9061b65ab..c42aa005e 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -31,6 +31,8 @@ python_init_repositories(
"jaxlib*",
"jax_cuda*",
"jax-cuda*",
+ "jax_rocm*",
+ "jax-rocm*",
],
local_wheel_workspaces = ["//jaxlib:jax.bzl"],
requirements = {
diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl
index dfa7657e6..1d5b5f4cf 100644
--- a/jaxlib/jax.bzl
+++ b/jaxlib/jax.bzl
@@ -192,8 +192,10 @@ def _gpu_test_deps():
"//jax_plugins:gpu_plugin_only_test_deps",
],
"//jax:config_build_jaxlib_false": [
- "//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps",
- "//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps",
+ "//jaxlib/tools:rocm_plugin_kernels_wheel",
+ "//jaxlib/tools:rocm_plugin_pjrt_wheel",
+ #"//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps",
+ #"//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps",
],
"//jax:config_build_jaxlib_wheel": [
"//jaxlib/tools:jax_cuda_plugin_py_import",
@@ -300,7 +302,7 @@ def jax_multiplatform_test(
shard_count = test_shards,
tags = test_tags,
main = main,
- exec_properties = tf_exec_properties({"tags": test_tags}),
+ exec_properties = {} #tf_exec_properties({"tags": test_tags}),
)

def jax_generate_backend_suites(backends = []):
diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel
index 58fbfd734..57cbd4252 100644
--- a/jaxlib/tools/BUILD.bazel
+++ b/jaxlib/tools/BUILD.bazel
@@ -564,6 +564,25 @@ py_import(
wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]),
)

+filegroup(
+ name = "rocm_wheel_deps",
+ srcs = [
+ ],
+)
+
+# Targets for importing ROCm plugin wheels
+py_import(
+ name = "rocm_plugin_kernels_wheel",
+ wheel = "@pypi_jax_rocm7_plugin//:whl",
+ wheel_deps = [":rocm_wheel_deps"],
+)
+
+py_import(
+ name = "rocm_plugin_pjrt_wheel",
+ wheel = "@pypi_jax_rocm7_pjrt//:whl",
+ wheel_deps = [":rocm_wheel_deps"],
+)
+
# Mosaic GPU

py_binary(
--
2.34.1

6 changes: 5 additions & 1 deletion jax_rocm_plugin/.bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,13 @@ build:rocm --copt=-Qunused-arguments
build:rocm --action_env=TF_HIPCC_CLANG="1"


# Flag to enable remote config
#############################################################################
# Configuration for running RBE builds and tests
#############################################################################
common --experimental_repo_remote_exec

try-import %workspace%/rbe.bazelrc

#############################################################################
# Some configs to make getting some forms of debug builds. In general, the
# codebase is only regularly built with optimizations. Use 'debug_symbols' to
Expand Down
10 changes: 10 additions & 0 deletions jax_rocm_plugin/build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,12 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
""",
)

parser.add_argument(
"--rbe",
action="store_true",
help="If set, uses Bazel Remote Build Execution",
)

# CUDA Options
cuda_group = parser.add_argument_group("CUDA Options")
cuda_group.add_argument(
Expand Down Expand Up @@ -478,6 +484,10 @@ async def main():
}
target_cpu = wheel_cpus[args.target_cpu] if args.target_cpu is not None else arch

if args.rbe:
logging.debug("Using RBE")
wheel_build_command_base.append("--config=rocm_rbe")

if args.local_xla_path:
logging.debug("Local XLA path: %s", args.local_xla_path)
wheel_build_command_base.append(
Expand Down
11 changes: 11 additions & 0 deletions jax_rocm_plugin/build/rocm/ci_build
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def dist_wheels(
rocm_build_job,
rocm_build_num,
therock_path,
rbe,
compiler,
):
if xla_path:
Expand Down Expand Up @@ -101,6 +102,9 @@ def dist_wheels(
if xla_path:
bw_cmd.extend(["--xla-path", container_xla_path])

if rbe:
bw_cmd.append("--rbe")

bw_cmd.append("/jax")

cmd = ["docker", "run"]
Expand Down Expand Up @@ -225,6 +229,12 @@ def parse_args():
help="Path to XLA source to use during jaxlib build, instead of builtin XLA",
)

p.add_argument(
"--rbe",
action="store_true",
help="Enable Bazel RBE builds for JAX",
)

p.add_argument(
"--compiler",
choices=["gcc", "clang"],
Expand Down Expand Up @@ -252,6 +262,7 @@ def main():
args.rocm_build_job,
args.rocm_build_num,
args.therock_path,
args.rbe,
args.compiler,
)

Expand Down
Loading
Loading