From 1b6d9ac7b07401ff86e5024115529306e1a9c887 Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Tue, 1 Apr 2025 15:04:29 -0700
Subject: [PATCH] Make GPU CUDA plugin require JAX

Some XLA GPU features require JAX. Rather than only installing the
latest version of JAX in CI, we'll just make the CUDA plugin depend on a
version of JAX that's the same as what's used by PyTorch/XLA on TPU.
(Except the JAX CUDA wheels).
---
 .../workflows/_test_requiring_torch_cuda.yml  |   8 --
 build_util.py                                 | 125 +++++++++++++++++-
 plugins/cuda/setup.py                         |   7 +-
 setup.py                                      |  59 ++-------
 4 files changed, 137 insertions(+), 62 deletions(-)

diff --git a/.github/workflows/_test_requiring_torch_cuda.yml b/.github/workflows/_test_requiring_torch_cuda.yml
index 1cb844e464af..ce15473f8829 100644
--- a/.github/workflows/_test_requiring_torch_cuda.yml
+++ b/.github/workflows/_test_requiring_torch_cuda.yml
@@ -87,14 +87,6 @@ jobs:
         uses: actions/checkout@v4
         with:
           path: pytorch/xla
-      - name: Extra CI deps
-        shell: bash
-        run: |
-          set -x
-          pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-          pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
-          pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
-        if: ${{ matrix.run_triton_tests }}
       - name: Install Triton
         shell: bash
         run: |
diff --git a/build_util.py b/build_util.py
index 487f5116323e..fc8af01e4745 100644
--- a/build_util.py
+++ b/build_util.py
@@ -3,6 +3,129 @@
 import subprocess
 import sys
 import shutil
+from dataclasses import dataclass
+import functools
+
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+@functools.lru_cache
+def get_pinned_packages():
+  """Gets the versions of important pinned dependencies of torch_xla."""
+  return PinnedPackages(
+      use_nightly=True,
+      date='20250320',
+      raw_libtpu_version='0.0.12',
+      raw_jax_version='0.5.4',
+      raw_jaxlib_version='0.5.4',
+  )
+
+
+@functools.lru_cache
+def get_build_version():
+  xla_git_sha, _torch_git_sha = get_git_head_sha(BASE_DIR)
+  version = os.getenv('TORCH_XLA_VERSION', '2.8.0')
+  if check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'):
+    try:
+      version += '+git' + xla_git_sha[:7]
+    except Exception:
+      pass
+  return version
+
+
+@functools.lru_cache
+def get_git_head_sha(base_dir):
+  xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
+                                        cwd=base_dir).decode('ascii').strip()
+  if os.path.isdir(os.path.join(base_dir, '..', '.git')):
+    torch_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
+                                            cwd=os.path.join(
+                                                base_dir,
+                                                '..')).decode('ascii').strip()
+  else:
+    torch_git_sha = ''
+  return xla_git_sha, torch_git_sha
+
+
+def get_jax_cuda_requirements():
+  """Get a list of JAX CUDA requirements for use in setup.py without extra package registries."""
+  pinned_packages = get_pinned_packages()
+  if not pinned_packages.use_nightly:
+    # Stable versions of JAX can be directly installed from PyPI.
+    return [
+        f'jaxlib=={pinned_packages.jaxlib_version}',
+        f'jax=={pinned_packages.jax_version}',
+        f'jax[cuda12]=={pinned_packages.jax_version}',
+    ]
+
+  # Install nightly JAX libraries from the JAX package registries.
+  jax = f'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/jax-{pinned_packages.jax_version}-py3-none-any.whl'
+  jaxlib = []
+  for python_minor_version in [9, 10, 11]:
+    jaxlib.append(
+        f'jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"'
+    )
+
+  # Nightly JAX CUDA wheels; each requirement name must match its wheel's name.
+  jax_cuda = [
+      f'jax-cuda12-pjrt @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_pjrt-{pinned_packages.jax_version}-py3-none-manylinux2014_x86_64.whl'
+  ]
+  for python_minor_version in [9, 10, 11]:
+    jax_cuda.append(
+        f'jax-cuda12-plugin @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_plugin-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"'
+    )
+
+  return [jax] + jaxlib + jax_cuda
+
+
+@dataclass(eq=True, frozen=True)
+class PinnedPackages:
+  use_nightly: bool
+  """Whether to use nightly or stable libtpu and JAX"""
+
+  date: str
+  raw_libtpu_version: str
+  raw_jax_version: str
+  raw_jaxlib_version: str
+
+  @property
+  def libtpu_version(self) -> str:
+    if self.use_nightly:
+      return f'{self.raw_libtpu_version}.dev{self.date}'
+    else:
+      return self.raw_libtpu_version
+
+  @property
+  def jax_version(self) -> str:
+    if self.use_nightly:
+      return f'{self.raw_jax_version}.dev{self.date}'
+    else:
+      return self.raw_jax_version
+
+  @property
+  def jaxlib_version(self) -> str:
+    if self.use_nightly:
+      return f'{self.raw_jaxlib_version}.dev{self.date}'
+    else:
+      return self.raw_jaxlib_version
+
+  @property
+  def libtpu_storage_directory(self) -> str:
+    if self.use_nightly:
+      return 'libtpu-nightly-releases'
+    else:
+      return 'libtpu-lts-releases'
+
+  @property
+  def libtpu_wheel_name(self) -> str:
+    if self.use_nightly:
+      return f'libtpu-{self.libtpu_version}+nightly'
+    else:
+      return f'libtpu-{self.libtpu_version}'
+
+  @property
+  def libtpu_storage_path(self) -> str:
+    return f'https://storage.googleapis.com/{self.libtpu_storage_directory}/wheels/libtpu/{self.libtpu_wheel_name}-py3-none-linux_x86_64.whl'


 def check_env_flag(name: str, default: str = '') -> bool:
@@ -60,7 +183,7 @@ def bazel_build(bazel_target: str,
   ]

   # Remove duplicated flags because they confuse bazel
-  flags = set(bazel_options_from_env() + options)
+  flags = set(list(bazel_options_from_env()) + list(options))
   bazel_argv.extend(flags)

   print(' '.join(bazel_argv), flush=True)
diff --git a/plugins/cuda/setup.py b/plugins/cuda/setup.py
index 2652880c6fd7..6e075e3bbf57 100644
--- a/plugins/cuda/setup.py
+++ b/plugins/cuda/setup.py
@@ -1,4 +1,3 @@
-import datetime
 import os
 import sys

@@ -12,6 +11,6 @@
                        'torch_xla_cuda_plugin/lib', ['--config=cuda'])

 setuptools.setup(
-    # TODO: Use a common version file
-    version=os.getenv('TORCH_XLA_VERSION',
-                      f'2.8.0.dev{datetime.date.today().strftime("%Y%m%d")}'))
+    version=build_util.get_build_version(),
+    install_requires=build_util.get_jax_cuda_requirements(),
+)
diff --git a/setup.py b/setup.py
index 558ad8b97c58..e0fdbd505112 100644
--- a/setup.py
+++ b/setup.py
@@ -55,7 +55,6 @@
 import os
 import requests
 import shutil
-import subprocess
 import sys
 import tempfile
 import zipfile
@@ -63,25 +62,7 @@
 import build_util

 base_dir = os.path.dirname(os.path.abspath(__file__))
-
-USE_NIGHTLY = True  # whether to use nightly or stable libtpu and jax
-
-_date = '20250320'
-_libtpu_version = '0.0.12'
-_jax_version = '0.5.4'
-_jaxlib_version = '0.5.4'
-
-_libtpu_wheel_name = f'libtpu-{_libtpu_version}'
-_libtpu_storage_directory = 'libtpu-lts-releases'
-
-if USE_NIGHTLY:
-  _libtpu_version += f".dev{_date}"
-  _jax_version += f".dev{_date}"
-  _jaxlib_version += f".dev{_date}"
-  _libtpu_wheel_name += f".dev{_date}+nightly"
-  _libtpu_storage_directory = 'libtpu-nightly-releases'
-
-_libtpu_storage_path = f'https://storage.googleapis.com/{_libtpu_storage_directory}/wheels/libtpu/{_libtpu_wheel_name}-py3-none-linux_x86_64.whl'
+pinned_packages = build_util.get_pinned_packages()


 def _get_build_mode():
@@ -90,29 +71,6 @@ def _get_build_mode():
       return sys.argv[i]


-def get_git_head_sha(base_dir):
-  xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
-                                        cwd=base_dir).decode('ascii').strip()
-  if os.path.isdir(os.path.join(base_dir, '..', '.git')):
-    torch_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
-                                            cwd=os.path.join(
-                                                base_dir,
-                                                '..')).decode('ascii').strip()
-  else:
-    torch_git_sha = ''
-  return xla_git_sha, torch_git_sha
-
-
-def get_build_version(xla_git_sha):
-  version = os.getenv('TORCH_XLA_VERSION', '2.8.0')
-  if build_util.check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'):
-    try:
-      version += '+git' + xla_git_sha[:7]
-    except Exception:
-      pass
-  return version
-
-
 def create_version_files(base_dir, version, xla_git_sha, torch_git_sha):
   print('Building torch_xla version: {}'.format(version))
   print('XLA Commit ID: {}'.format(xla_git_sha))
@@ -151,7 +109,7 @@ def maybe_bundle_libtpu(base_dir):
     print('No installed libtpu found. Downloading...')

     with tempfile.NamedTemporaryFile('wb') as whl:
-      resp = requests.get(_libtpu_storage_path)
+      resp = requests.get(pinned_packages.libtpu_storage_path)
       resp.raise_for_status()

       whl.write(resp.content)
@@ -194,8 +152,8 @@ def run(self):
     distutils.command.clean.clean.run(self)


-xla_git_sha, torch_git_sha = get_git_head_sha(base_dir)
-version = get_build_version(xla_git_sha)
+xla_git_sha, torch_git_sha = build_util.get_git_head_sha(base_dir)
+version = build_util.get_build_version()

 build_mode = _get_build_mode()
 if build_mode not in ['clean']:
@@ -226,7 +184,7 @@ class BuildBazelExtension(build_ext.build_ext):
   def run(self):
     for ext in self.extensions:
       self.bazel_build(ext)
-    command.build_ext.build_ext.run(self)
+    command.build_ext.build_ext.run(self)  # type: ignore

   def bazel_build(self, ext):
     if not os.path.exists(self.build_temp):
@@ -328,11 +286,14 @@ def run(self):
         # On Cloud TPU VM install with:
         # pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html
         'tpu': [
-            f'libtpu=={_libtpu_version}',
+            f'libtpu=={pinned_packages.libtpu_version}',
             'tpu-info',
         ],
         # pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-        'pallas': [f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}'],
+        'pallas': [
+            f'jaxlib=={pinned_packages.jaxlib_version}',
+            f'jax=={pinned_packages.jax_version}'
+        ],
     },
     cmdclass={
         'build_ext': BuildBazelExtension,