diff --git a/.github/workflows/_test_requiring_torch_cuda.yml b/.github/workflows/_test_requiring_torch_cuda.yml index 1cb844e464a..ce15473f882 100644 --- a/.github/workflows/_test_requiring_torch_cuda.yml +++ b/.github/workflows/_test_requiring_torch_cuda.yml @@ -87,14 +87,6 @@ jobs: uses: actions/checkout@v4 with: path: pytorch/xla - - name: Extra CI deps - shell: bash - run: | - set -x - pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html - pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html - if: ${{ matrix.run_triton_tests }} - name: Install Triton shell: bash run: | diff --git a/build_util.py b/build_util.py index 487f5116323..fc8af01e474 100644 --- a/build_util.py +++ b/build_util.py @@ -3,6 +3,129 @@ import subprocess import sys import shutil +from dataclasses import dataclass +import functools + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + + +@functools.lru_cache +def get_pinned_packages(): + """Gets the versions of important pinned dependencies of torch_xla.""" + return PinnedPackages( + use_nightly=True, + date='20250320', + raw_libtpu_version='0.0.12', + raw_jax_version='0.5.4', + raw_jaxlib_version='0.5.4', + ) + + +@functools.lru_cache +def get_build_version(): + xla_git_sha, _torch_git_sha = get_git_head_sha(BASE_DIR) + version = os.getenv('TORCH_XLA_VERSION', '2.8.0') + if check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'): + try: + version += '+git' + xla_git_sha[:7] + except Exception: + pass + return version + + +@functools.lru_cache +def get_git_head_sha(base_dir): + xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], + cwd=base_dir).decode('ascii').strip() + if os.path.isdir(os.path.join(base_dir, '..', '.git')): + torch_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], + cwd=os.path.join( + base_dir, + '..')).decode('ascii').strip() + else: + torch_git_sha = '' + return xla_git_sha, torch_git_sha + + +def get_jax_cuda_requirements(): + """Get a list of JAX CUDA requirements for use in setup.py without extra package registries.""" + pinned_packages = get_pinned_packages() + if not pinned_packages.use_nightly: + # Stable versions of JAX can be directly installed from PyPI. + return [ + f'jaxlib=={pinned_packages.jaxlib_version}', + f'jax=={pinned_packages.jax_version}', + f'jax[cuda12]=={pinned_packages.jax_version}', + ] + + # Install nightly JAX libraries from the JAX package registries. + jax = f'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/jax-{pinned_packages.jax_version}-py3-none-any.whl' + jaxlib = [] + for python_minor_version in [9, 10, 11]: + jaxlib.append( + f'jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' + ) + + # Install nightly JAX CUDA libraries. + jax_cuda = [ + f'jax-cuda12-plugin @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_pjrt-{pinned_packages.jax_version}-py3-none-manylinux2014_x86_64.whl' + ] + for python_minor_version in [9, 10, 11]: + jax_cuda.append( + f'jax-cuda12-pjrt @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_plugin-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' + ) + + return [jax] + jaxlib + jax_cuda + + +@dataclass(eq=True, frozen=True) +class PinnedPackages: + use_nightly: bool + """Whether to use nightly or stable libtpu and JAX""" + + date: str + raw_libtpu_version: str + raw_jax_version: str + raw_jaxlib_version: str + + @property + def libtpu_version(self) -> str: + if self.use_nightly: + return f'{self.raw_libtpu_version}.dev{self.date}' + else: + return self.raw_libtpu_version + + @property + def jax_version(self) -> str: + if self.use_nightly: + return f'{self.raw_jax_version}.dev{self.date}' + else: + return self.raw_jax_version + + @property + def jaxlib_version(self) -> str: + if self.use_nightly: + return f'{self.raw_jaxlib_version}.dev{self.date}' + else: + return self.raw_jaxlib_version + + @property + def libtpu_storage_directory(self) -> str: + if self.use_nightly: + return 'libtpu-nightly-releases' + else: + return 'libtpu-lts-releases' + + @property + def libtpu_wheel_name(self) -> str: + if self.use_nightly: + return f'libtpu-{self.libtpu_version}+nightly' + else: + return f'libtpu-{self.libtpu_version}' + + @property + def libtpu_storage_path(self) -> str: + return f'https://storage.googleapis.com/{self.libtpu_storage_directory}/wheels/libtpu/{self.libtpu_wheel_name}-py3-none-linux_x86_64.whl' def check_env_flag(name: str, default: str = '') -> bool: @@ -60,7 +183,7 @@ def bazel_build(bazel_target: str, ] # Remove duplicated flags because they confuse bazel - flags = set(bazel_options_from_env() + options) + flags = set(list(bazel_options_from_env()) + list(options)) bazel_argv.extend(flags) print(' '.join(bazel_argv), flush=True) diff --git a/plugins/cuda/setup.py b/plugins/cuda/setup.py index 2652880c6fd..6e075e3bbf5 100644 --- a/plugins/cuda/setup.py +++ b/plugins/cuda/setup.py @@ -1,4 +1,3 @@ -import datetime import os import sys @@ -12,6 +11,6 @@ 'torch_xla_cuda_plugin/lib', ['--config=cuda']) setuptools.setup( - # TODO: Use a common version file - version=os.getenv('TORCH_XLA_VERSION', - f'2.8.0.dev{datetime.date.today().strftime("%Y%m%d")}')) + version=build_util.get_build_version(), + install_requires=build_util.get_jax_cuda_requirements(), +) diff --git a/setup.py b/setup.py index 558ad8b97c5..e0fdbd50511 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,6 @@ import os import requests import shutil -import subprocess import sys import tempfile import zipfile @@ -63,25 +62,7 @@ import build_util base_dir = os.path.dirname(os.path.abspath(__file__)) - -USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax - -_date = '20250320' -_libtpu_version = '0.0.12' -_jax_version = '0.5.4' -_jaxlib_version = '0.5.4' - -_libtpu_wheel_name = f'libtpu-{_libtpu_version}' -_libtpu_storage_directory = 'libtpu-lts-releases' - -if USE_NIGHTLY: - _libtpu_version += f".dev{_date}" - _jax_version += f".dev{_date}" - _jaxlib_version += f".dev{_date}" - _libtpu_wheel_name += f".dev{_date}+nightly" - _libtpu_storage_directory = 'libtpu-nightly-releases' - -_libtpu_storage_path = f'https://storage.googleapis.com/{_libtpu_storage_directory}/wheels/libtpu/{_libtpu_wheel_name}-py3-none-linux_x86_64.whl' +pinned_packages = build_util.get_pinned_packages() def _get_build_mode(): @@ -90,29 +71,6 @@ def _get_build_mode(): return sys.argv[i] -def get_git_head_sha(base_dir): - xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], - cwd=base_dir).decode('ascii').strip() - if os.path.isdir(os.path.join(base_dir, '..', '.git')): - torch_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], - cwd=os.path.join( - base_dir, - '..')).decode('ascii').strip() - else: - torch_git_sha = '' - return xla_git_sha, torch_git_sha - - -def get_build_version(xla_git_sha): - version = os.getenv('TORCH_XLA_VERSION', '2.8.0') - if build_util.check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'): - try: - version += '+git' + xla_git_sha[:7] - except Exception: - pass - return version - - def create_version_files(base_dir, version, xla_git_sha, torch_git_sha): print('Building torch_xla version: {}'.format(version)) print('XLA Commit ID: {}'.format(xla_git_sha)) @@ -151,7 +109,7 @@ def maybe_bundle_libtpu(base_dir): print('No installed libtpu found. Downloading...') with tempfile.NamedTemporaryFile('wb') as whl: - resp = requests.get(_libtpu_storage_path) + resp = requests.get(pinned_packages.libtpu_storage_path) resp.raise_for_status() whl.write(resp.content) @@ -194,8 +152,8 @@ def run(self): distutils.command.clean.clean.run(self) -xla_git_sha, torch_git_sha = get_git_head_sha(base_dir) -version = get_build_version(xla_git_sha) +xla_git_sha, torch_git_sha = build_util.get_git_head_sha(base_dir) +version = build_util.get_build_version() build_mode = _get_build_mode() if build_mode not in ['clean']: @@ -226,7 +184,7 @@ class BuildBazelExtension(build_ext.build_ext): def run(self): for ext in self.extensions: self.bazel_build(ext) - command.build_ext.build_ext.run(self) + command.build_ext.build_ext.run(self) # type: ignore def bazel_build(self, ext): if not os.path.exists(self.build_temp): @@ -328,11 +286,14 @@ def run(self): # On Cloud TPU VM install with: # pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html 'tpu': [ - f'libtpu=={_libtpu_version}', + f'libtpu=={pinned_packages.libtpu_version}', 'tpu-info', ], # pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - 'pallas': [f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}'], + 'pallas': [ + f'jaxlib=={pinned_packages.jaxlib_version}', + f'jax=={pinned_packages.jax_version}' + ], }, cmdclass={ 'build_ext': BuildBazelExtension,