From 47d6ea1ecd8fbbab9fff274877d8ffe110f1f706 Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Tue, 1 Apr 2025 15:04:29 -0700 Subject: [PATCH 1/4] Make GPU CUDA plugin require JAX Some XLA GPU features require JAX. Rather than only installing the latest version of JAX in CI, we'll just make the CUDA plugin depend on a version of JAX that's the same as what's used by PyTorch/XLA on TPU. (Except the JAX CUDA wheels). --- .../workflows/_test_requiring_torch_cuda.yml | 6 - build_util.py | 150 +++++++++++++++++- plugins/cuda/setup.py | 7 +- setup.py | 85 ++-------- 4 files changed, 161 insertions(+), 87 deletions(-) diff --git a/.github/workflows/_test_requiring_torch_cuda.yml b/.github/workflows/_test_requiring_torch_cuda.yml index 9861e4fba161..94a8797dd1ba 100644 --- a/.github/workflows/_test_requiring_torch_cuda.yml +++ b/.github/workflows/_test_requiring_torch_cuda.yml @@ -97,12 +97,6 @@ jobs: uses: actions/checkout@v4 with: path: pytorch/xla - - name: Extra CI deps - if: inputs.has_code_changes == 'true' && matrix.run_triton_tests - shell: bash - run: | - set -x - pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html - name: Install Triton if: inputs.has_code_changes == 'true' shell: bash diff --git a/build_util.py b/build_util.py index 487f5116323e..1d8f4a52df70 100644 --- a/build_util.py +++ b/build_util.py @@ -1,8 +1,154 @@ import os -from typing import Iterable +from collections.abc import Iterable import subprocess import sys import shutil +from dataclasses import dataclass +import functools + +import platform + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + + +@functools.lru_cache +def get_pinned_packages(): + """Gets the versions of important pinned dependencies of torch_xla.""" + return PinnedPackages( + use_nightly=True, + date='20250424', + raw_libtpu_version='0.0.14', + raw_jax_version='0.6.1', + raw_jaxlib_version='0.6.1', + ) + + +@functools.lru_cache +def get_build_version(): + xla_git_sha, _torch_git_sha = get_git_head_sha(BASE_DIR) + version = os.getenv('TORCH_XLA_VERSION', '2.8.0') + if check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'): + try: + version += '+git' + xla_git_sha[:7] + except Exception: + pass + return version + + +@functools.lru_cache +def get_git_head_sha(base_dir): + xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], + cwd=base_dir).decode('ascii').strip() + if os.path.isdir(os.path.join(base_dir, '..', '.git')): + torch_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], + cwd=os.path.join( + base_dir, + '..')).decode('ascii').strip() + else: + torch_git_sha = '' + return xla_git_sha, torch_git_sha + + +@functools.lru_cache +def get_jax_install_requirements(): + """Get a list of JAX requirements for use in setup.py without extra package registries.""" + pinned_packages = get_pinned_packages() + if not pinned_packages.use_nightly: + # Stable versions of JAX can be directly installed from PyPI. + return [ + f'jaxlib=={pinned_packages.jaxlib_version}', + f'jax=={pinned_packages.jax_version}', + ] + + # Install nightly JAX libraries from the JAX package registries. + # TODO(https://github.com/pytorch/xla/issues/9064): This URL needs to be + # updated to use the new JAX package registry for any JAX builds after Apr 28, 2025. + jax = f'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/jax-{pinned_packages.jax_version}-py3-none-any.whl' + jaxlib = [] + for python_minor_version in [9, 10, 11]: + jaxlib.append( + f'jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' + ) + return [jax] + jaxlib + + +@functools.lru_cache +def get_jax_cuda_requirements(): + """Get a list of JAX CUDA requirements for use in setup.py without extra package registries.""" + pinned_packages = get_pinned_packages() + jax_requirements = get_jax_install_requirements() + + # Install nightly JAX CUDA libraries. + jax_cuda = [ + f'jax-cuda12-plugin @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_pjrt-{pinned_packages.jax_version}-py3-none-manylinux2014_x86_64.whl' + ] + for python_minor_version in [9, 10, 11]: + jax_cuda.append( + f'jax-cuda12-pjrt @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_plugin-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' + ) + + return jax_requirements + jax_cuda + + +@dataclass(eq=True, frozen=True) +class PinnedPackages: + use_nightly: bool + """Whether to use nightly or stable libtpu and JAX""" + + date: str + """The date of the libtpu and jax build""" + + raw_libtpu_version: str + """libtpu version string in [major].[minor].[patch] format.""" + + raw_jax_version: str + """jax version string in [major].[minor].[patch] format.""" + + raw_jaxlib_version: str + """jaxlib version string in [major].[minor].[patch] format.""" + + @property + def libtpu_version(self) -> str: + if self.use_nightly: + return f'{self.raw_libtpu_version}.dev{self.date}' + else: + return self.raw_libtpu_version + + @property + def jax_version(self) -> str: + if self.use_nightly: + return f'{self.raw_jax_version}.dev{self.date}' + else: + return self.raw_jax_version + + @property + def jaxlib_version(self) -> str: + if self.use_nightly: + return f'{self.raw_jaxlib_version}.dev{self.date}' + else: + return self.raw_jaxlib_version + + @property + def libtpu_storage_directory(self) -> str: + if self.use_nightly: + return 'libtpu-nightly-releases' + else: + return 'libtpu-lts-releases' + + @property + def libtpu_wheel_name(self) -> str: + if self.use_nightly: + return f'libtpu-{self.libtpu_version}+nightly' + else: + return f'libtpu-{self.libtpu_version}' + + @property + def libtpu_storage_path(self) -> str: + platform_machine = platform.machine() + # The suffix can be changed when the version is updated. Check + # https://storage.googleapis.com/libtpu-wheels/index.html for correct name. + suffix = f"py3-none-manylinux_2_31_{platform_machine}" + return f'https://storage.googleapis.com/{self.libtpu_storage_directory}/wheels/libtpu/{self.libtpu_wheel_name}-{suffix}.whl' def check_env_flag(name: str, default: str = '') -> bool: @@ -60,7 +206,7 @@ def bazel_build(bazel_target: str, ] # Remove duplicated flags because they confuse bazel - flags = set(bazel_options_from_env() + options) + flags = set(list(bazel_options_from_env()) + list(options)) bazel_argv.extend(flags) print(' '.join(bazel_argv), flush=True) diff --git a/plugins/cuda/setup.py b/plugins/cuda/setup.py index 2652880c6fd7..6e075e3bbf57 100644 --- a/plugins/cuda/setup.py +++ b/plugins/cuda/setup.py @@ -1,4 +1,3 @@ -import datetime import os import sys @@ -12,6 +11,6 @@ 'torch_xla_cuda_plugin/lib', ['--config=cuda']) setuptools.setup( - # TODO: Use a common version file - version=os.getenv('TORCH_XLA_VERSION', - f'2.8.0.dev{datetime.date.today().strftime("%Y%m%d")}')) + version=build_util.get_build_version(), + install_requires=build_util.get_jax_cuda_requirements(), +) diff --git a/setup.py b/setup.py index 3eae00e2796a..ec9953fa68af 100644 --- a/setup.py +++ b/setup.py @@ -56,41 +56,14 @@ import re import requests import shutil -import subprocess import sys import tempfile import zipfile import build_util -import platform - -platform_machine = platform.machine() - base_dir = os.path.dirname(os.path.abspath(__file__)) - -USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax - -_date = '20250424' - -_libtpu_version = '0.0.14' -_jax_version = '0.6.1' -_jaxlib_version = '0.6.1' - -if USE_NIGHTLY: - _libtpu_version += f".dev{_date}" - _jax_version += f'.dev{_date}' - _jaxlib_version += f'.dev{_date}' - _libtpu_wheel_name = f'libtpu-{_libtpu_version}.dev{_date}+nightly-py3-none-manylinux_2_31_{platform_machine}' - _libtpu_storage_directory = 'libtpu-nightly-releases' -else: - # The postfix can be changed when the version is updated. Check - # https://storage.googleapis.com/libtpu-wheels/index.html for correct - # versioning. - _libtpu_wheel_name = f'libtpu-{_libtpu_version}-py3-none-manylinux_2_31_{platform_machine}' - _libtpu_storage_directory = 'libtpu-lts-releases' - -_libtpu_storage_path = f'https://storage.googleapis.com/{_libtpu_storage_directory}/wheels/libtpu/{_libtpu_wheel_name}.whl' +pinned_packages = build_util.get_pinned_packages() def _get_build_mode(): @@ -99,29 +72,6 @@ def _get_build_mode(): return sys.argv[i] -def get_git_head_sha(base_dir): - xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], - cwd=base_dir).decode('ascii').strip() - if os.path.isdir(os.path.join(base_dir, '..', '.git')): - torch_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], - cwd=os.path.join( - base_dir, - '..')).decode('ascii').strip() - else: - torch_git_sha = '' - return xla_git_sha, torch_git_sha - - -def get_build_version(xla_git_sha): - version = os.getenv('TORCH_XLA_VERSION', '2.8.0') - if build_util.check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'): - try: - version += '+git' + xla_git_sha[:7] - except Exception: - pass - return version - - def create_version_files(base_dir, version, xla_git_sha, torch_git_sha): print('Building torch_xla version: {}'.format(version)) print('XLA Commit ID: {}'.format(xla_git_sha)) @@ -160,7 +110,7 @@ def maybe_bundle_libtpu(base_dir): print('No installed libtpu found. Downloading...') with tempfile.NamedTemporaryFile('wb') as whl: - resp = requests.get(_libtpu_storage_path) + resp = requests.get(pinned_packages.libtpu_storage_path) resp.raise_for_status() whl.write(resp.content) @@ -203,8 +153,8 @@ def run(self): distutils.command.clean.clean.run(self) -xla_git_sha, torch_git_sha = get_git_head_sha(base_dir) -version = get_build_version(xla_git_sha) +xla_git_sha, torch_git_sha = build_util.get_git_head_sha(base_dir) +version = build_util.get_build_version() build_mode = _get_build_mode() if build_mode not in ['clean']: @@ -353,24 +303,6 @@ def link_packages(self): f.write(path + "\n") -def _get_jax_install_requirements(): - if not USE_NIGHTLY: - # Stable versions of JAX can be directly installed from PyPI. - return [ - f'jaxlib=={_jaxlib_version}', - f'jax=={_jax_version}', - ] - - # Install nightly JAX libraries from the JAX package registries. - jax = f'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/jax-{_jax_version}-py3-none-any.whl' - jaxlib = [] - for python_minor_version in [9, 10, 11]: - jaxlib.append( - f'jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-{_jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' - ) - return [jax] + jaxlib - - setup( name=os.environ.get('TORCH_XLA_PACKAGE_NAME', 'torch_xla'), version=version, @@ -411,7 +343,7 @@ def _get_jax_install_requirements(): # to Python 3.10 'importlib_metadata>=4.6;python_version<"3.10"', # Some torch operations are lowered to HLO via JAX. - *_get_jax_install_requirements(), + *build_util.get_jax_install_requirements(), ], package_data={ 'torch_xla': ['lib/*.so*',], @@ -430,13 +362,16 @@ def _get_jax_install_requirements(): # On Cloud TPU VM install with: # pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html 'tpu': [ - f'libtpu=={_libtpu_version}', + f'libtpu=={pinned_packages.libtpu_version}', 'tpu-info', ], # As of https://github.com/pytorch/xla/pull/8895, jax is always a dependency of torch_xla. # However, this no-op extras_require entrypoint is left here for backwards compatibility. # pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - 'pallas': [f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}'], + 'pallas': [ + f'jaxlib=={pinned_packages.jaxlib_version}', + f'jax=={pinned_packages.jax_version}' + ], }, cmdclass={ 'build_ext': BuildBazelExtension, From 2e67486665b10cf2e8be6a7fea9e53785984518d Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Fri, 23 May 2025 18:29:17 +0000 Subject: [PATCH 2/4] actually add jax dependencies to the CUDA plugin --- .github/workflows/_test_requiring_torch_cuda.yml | 3 +++ build_util.py | 4 ++-- plugins/cuda/pyproject.toml | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/_test_requiring_torch_cuda.yml b/.github/workflows/_test_requiring_torch_cuda.yml index 94a8797dd1ba..79f5e05fa6d6 100644 --- a/.github/workflows/_test_requiring_torch_cuda.yml +++ b/.github/workflows/_test_requiring_torch_cuda.yml @@ -85,6 +85,9 @@ jobs: echo "Check if CUDA is available for PyTorch..." python -c "import torch; assert torch.cuda.is_available()" echo "CUDA is available for PyTorch." + echo "Check if CUDA is available for PyTorch/XLA..." + PJRT_DEVICE=CUDA python -c "import torch; import torch_xla; print(torch.tensor([1,2,3], device='xla')); assert torch_xla.runtime.device_type() == 'CUDA'" + echo "CUDA is available for PyTorch/XLA." - name: Checkout PyTorch Repo if: inputs.has_code_changes == 'true' uses: actions/checkout@v4 diff --git a/build_util.py b/build_util.py index 1d8f4a52df70..edb22f22ac23 100644 --- a/build_util.py +++ b/build_util.py @@ -80,11 +80,11 @@ def get_jax_cuda_requirements(): # Install nightly JAX CUDA libraries. jax_cuda = [ - f'jax-cuda12-plugin @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_pjrt-{pinned_packages.jax_version}-py3-none-manylinux2014_x86_64.whl' + f'jax-cuda12-pjrt @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_pjrt-{pinned_packages.jax_version}-py3-none-manylinux2014_x86_64.whl' ] for python_minor_version in [9, 10, 11]: jax_cuda.append( - f'jax-cuda12-pjrt @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_plugin-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' + f'jax-cuda12-plugin @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_plugin-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' ) return jax_requirements + jax_cuda diff --git a/plugins/cuda/pyproject.toml b/plugins/cuda/pyproject.toml index d44a2ea3bd53..a5e65c6bc69d 100644 --- a/plugins/cuda/pyproject.toml +++ b/plugins/cuda/pyproject.toml @@ -9,7 +9,7 @@ authors = [ ] description = "PyTorch/XLA CUDA Plugin" requires-python = ">=3.8" -dynamic = ["version"] +dynamic = ["version", "dependencies"] [tool.setuptools.package-data] torch_xla_cuda_plugin = ["lib/*.so"] From c2024cb8d7bace685c275702ac43aed7eebcc8e7 Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Fri, 23 May 2025 19:52:47 +0000 Subject: [PATCH 3/4] Don't bundle jax wheels into the cuda plugin artifact --- .../roles/build_plugin/tasks/main.yaml | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/infra/ansible/roles/build_plugin/tasks/main.yaml b/infra/ansible/roles/build_plugin/tasks/main.yaml index 142d29c3718f..65282bbe1520 100644 --- a/infra/ansible/roles/build_plugin/tasks/main.yaml +++ b/infra/ansible/roles/build_plugin/tasks/main.yaml @@ -6,12 +6,27 @@ - name: Build PyTorch/XLA CUDA Plugin ansible.builtin.command: - cmd: pip wheel -w /dist plugins/cuda -v + cmd: pip wheel plugins/cuda -v chdir: "{{ (src_root, 'pytorch/xla') | path_join }}" environment: "{{ env_vars }}" when: accelerator == "cuda" -- name: Find CUDA plugin wheel pytorch/xla/dist +- name: Find the built CUDA plugin wheel + ansible.builtin.find: + paths: "{{ (src_root, 'pytorch/xla') | path_join }}" # Look in the dir where pip saved the wheel + patterns: "torch_xla_cuda_plugin-*.whl" + recurse: no + when: accelerator == "cuda" + register: built_plugin_wheel_info + +- name: Copy the CUDA plugin wheel to /dist + ansible.builtin.copy: + src: "{{ item.path }}" + dest: "/dist/{{ item.path | basename }}" # Ensure only the filename is used for dest + loop: "{{ built_plugin_wheel_info.files }}" + when: accelerator == "cuda" and built_plugin_wheel_info.files | length > 0 + +- name: Find CUDA plugin wheel in /dist ansible.builtin.find: path: "/dist" pattern: "torch_xla_cuda_plugin*.whl" From 9a11f5d8eb00952bf752f1edc6a01ea1034d9d8b Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Fri, 23 May 2025 22:27:05 +0000 Subject: [PATCH 4/4] Update cudnn to 9.8.0 --- infra/ansible/config/cuda_deps.yaml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/infra/ansible/config/cuda_deps.yaml b/infra/ansible/config/cuda_deps.yaml index 3732bb0f93ec..9609d96e6209 100644 --- a/infra/ansible/config/cuda_deps.yaml +++ b/infra/ansible/config/cuda_deps.yaml @@ -1,22 +1,22 @@ # Versions of cuda dependencies for given cuda versions. # Note: wrap version in quotes to ensure they're treated as strings. cuda_deps: - # List all libcudnn8 versions with `apt list -a libcudnn8` + # Find package versions from https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ libcudnn: - "12.8": libcudnn9-cuda-12=9.1.1.17-1 - "12.6": libcudnn9-cuda-12=9.1.1.17-1 - "12.4": libcudnn9-cuda-12=9.1.1.17-1 - "12.3": libcudnn9-cuda-12=9.1.1.17-1 + "12.8": libcudnn9-cuda-12=9.8.0.87-1 + "12.6": libcudnn9-cuda-12=9.8.0.87-1 + "12.4": libcudnn9-cuda-12=9.8.0.87-1 + "12.3": libcudnn9-cuda-12=9.8.0.87-1 "12.1": libcudnn8=8.9.2.26-1+cuda12.1 "12.0": libcudnn8=8.8.0.121-1+cuda12.0 "11.8": libcudnn8=8.7.0.84-1+cuda11.8 "11.7": libcudnn8=8.5.0.96-1+cuda11.7 "11.2": libcudnn8=8.1.1.33-1+cuda11.2 libcudnn-dev: - "12.8": libcudnn9-dev-cuda-12=9.1.1.17-1 - "12.6": libcudnn9-dev-cuda-12=9.1.1.17-1 - "12.4": libcudnn9-dev-cuda-12=9.1.1.17-1 - "12.3": libcudnn9-dev-cuda-12=9.1.1.17-1 + "12.8": libcudnn9-dev-cuda-12=9.8.0.87-1 + "12.6": libcudnn9-dev-cuda-12=9.8.0.87-1 + "12.4": libcudnn9-dev-cuda-12=9.8.0.87-1 + "12.3": libcudnn9-dev-cuda-12=9.8.0.87-1 "12.1": libcudnn8-dev=8.9.2.26-1+cuda12.1 "12.0": libcudnn8-dev=8.8.0.121-1+cuda12.0 "11.8": libcudnn8-dev=8.7.0.84-1+cuda11.8