From 68f1200547254d630c2e3c239ff463cd175317cf Mon Sep 17 00:00:00 2001
From: Mark Lee
Date: Sat, 9 Dec 2023 21:57:07 -0500
Subject: [PATCH] Bump to jax 0.4.21. (#237)

---
 axlearn/common/utils_spmd.py | 33 +--------------------------------
 pyproject.toml               | 10 +++++-----
 2 files changed, 6 insertions(+), 37 deletions(-)

diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py
index 658aced5e..7a37f48c8 100644
--- a/axlearn/common/utils_spmd.py
+++ b/axlearn/common/utils_spmd.py
@@ -3,13 +3,10 @@
 """SPMD related utils."""

 import logging
-import socket
 from typing import Optional

 import jax
-import jax.numpy as jnp
 import portpicker
-from jax.experimental import multihost_utils

 _jax_distributed_initialized = False

@@ -56,11 +53,7 @@ def setup(
                     "distributed_coordinator, num_processes, and process_id "
                     "should all be None for tpu backend."
                 )
-            jax.distributed.initialize(
-                coordinator_address=_infer_tpu_coordinator_address(),
-                num_processes=jax.process_count(),
-                process_id=jax.process_index(),
-            )
+            jax.distributed.initialize()
         else:
             if distributed_coordinator is None and num_processes is None and process_id is None:
                 logging.info(
@@ -89,27 +82,3 @@ def setup(
                 process_id=process_id,
             )
     _jax_distributed_initialized = True
-
-
-def _infer_tpu_coordinator_address() -> str:
-    """Infers a viable JAX coordination address on TPU (including over multiple TPU slices).
-
-    TODO(markblee,tom_gunter): Delete this when multi-slice init is fully supported by JAX.
-
-    Returns:
-        A coordinator address string as "ip:port".
-    """
-    slice_local_coordinator_ip = socket.gethostbyname(socket.gethostname())
-    # E.g. "172.31.4.83".
-    slice_local_coordinator_ip_as_nums = [int(num) for num in slice_local_coordinator_ip.split(".")]
-    # E.g. [172, 31, 4, 83].
-    global_coordinator_ip_as_nums = multihost_utils.broadcast_one_to_all(
-        jnp.asarray(slice_local_coordinator_ip_as_nums)
-    )
-    global_coordinator_ip = ".".join([str(num) for num in global_coordinator_ip_as_nums])
-    # E.g. "172.31.4.83" on all hosts on all slices.
-    global_coordinator_port = multihost_utils.broadcast_one_to_all(
-        jnp.asarray(portpicker.pick_unused_port())
-    )
-    global_coordinator_address = f"{global_coordinator_ip}:{global_coordinator_port}"
-    return global_coordinator_address
diff --git a/pyproject.toml b/pyproject.toml
index 834f2d250..7ab35b4b9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,8 +17,8 @@ dependencies = [
     "chex<0.1.81", # chex 0.1.81 depends on numpy>=1.25.0.
     "flax==0.7.4", # only for checkpoints.
     "importlab==0.7", # breaks pytype on 0.8
-    "jax>=0.4.18,<=0.4.20", # jax 0.4.20 runs into issues on GPU.
-    "jaxlib>=0.4.18,<=0.4.20",
+    "jax==0.4.21",
+    "jaxlib==0.4.21",
     "nltk==3.7", # for text preprocessing
     "numpy<1.24", # needed to pin to < 1.24; tf ragged_tensor depends on deprecated np.object.
     "optax==0.1.7", # optimizers (0.1.0 has known bugs).
@@ -43,8 +43,8 @@ apple-silicon = [
     "absl-py",
     "chex>=0.1.7",
     "flax==0.7.4", # only for checkpoints.
-    "jax>=0.4.18,<=0.4.20",
-    "jaxlib>=0.4.18,<=0.4.20",
+    "jax==0.4.21",
+    "jaxlib==0.4.21",
     "nltk==3.7", # for text preprocessing
     "optax>=0.1.7", # optimizers (0.1.0 has known bugs).
     "portpicker",
@@ -94,7 +94,7 @@ gcp = [
 # Note: Specify -f https://storage.googleapis.com/jax-releases/libtpu_releases.html during install.
 tpu = [
     "axlearn[gcp]",
-    "jax[tpu]==0.4.20", # must be >=0.4.19 for compat with v5p.
+    "jax[tpu]==0.4.21", # must be >=0.4.19 for compat with v5p.
 ]
 # Vertex AI tensorboard.
 vertexai_tensorboard = [
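
Note on the change (a sketch, not part of the patch): starting with jax 0.4.21, calling
jax.distributed.initialize() with no arguments on TPU lets JAX discover the coordinator
address, process count, and process index itself; per the deleted TODO, this is what makes
the hand-rolled _infer_tpu_coordinator_address broadcast above unnecessary. A minimal usage
sketch of the resulting call path follows; the keyword name jax_backend is assumed from the
hunk context and may not match the full setup() signature at this commit:

    import jax
    from axlearn.common import utils_spmd

    # On TPU, distributed_coordinator, num_processes, and process_id must all stay None;
    # setup() now delegates coordinator discovery to jax.distributed.initialize().
    utils_spmd.setup(jax_backend="tpu")

    # Once initialized, every host can see the global topology.
    print(jax.process_index(), jax.process_count(), len(jax.devices()))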