diff --git a/Dockerfile b/Dockerfile
index 3a63c54..15ccd6b 100755
--- a/Dockerfile
+++ b/Dockerfile
@@ -20,8 +20,9 @@ RUN pip3 install -U -q jax[tpu] -f https://storage.googleapis.com/jax-releases/l
 RUN pip3 install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich
 RUN pip3 install -U tensorboard-plugin-profile optuna-integration plotly
 RUN pip3 install git+https://github.com/deepmind/jmp
+RUN pip3 install git+https://github.com/Findus23/jax-array-info.git
 
 WORKDIR /ReAct_Jax
 
 # Set the entry point to bash
-ENTRYPOINT ["/bin/bash"]
\ No newline at end of file
+ENTRYPOINT ["/bin/bash"]
diff --git a/ReAct/utils/helpers.py b/ReAct/utils/helpers.py
index 0148583..934a5b0 100644
--- a/ReAct/utils/helpers.py
+++ b/ReAct/utils/helpers.py
@@ -6,6 +6,7 @@
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+from jax_array_info import sharding_info
 from jaxtyping import Array, PRNGKeyArray, PyTree
 
 from ReAct.model.baseline import GPT
@@ -114,6 +115,16 @@ def half_precision(model: eqx.Module) -> eqx.Module:
         lambda x: x.astype(jnp.bfloat16) if eqx.is_inexact_array(x) else x, model
     )
 
+def viz_obj(model: PyTree):
+    model = eqx.filter(model, eqx.is_array)
+
+    def viz_fn(leaf):
+        print(f"\n=== leaf: {leaf.shape} ===\n")
+        return sharding_info(leaf)
+
+    jax.tree_util.tree_map(viz_fn, model)
+
+
 def megatron_init(weight: Array, key: PRNGKeyArray) -> Array:
     """
     Init all the weights with the Megatron paper init
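A minimal sketch of how the new viz_obj helper can be used, assuming the same 8-device host spoofing that ReAct/utils/sharding.py (below) relies on; the eqx.nn.Linear toy module and the single replicated mesh are illustrative only, not code from this repo:

    import os
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # spoof 8 CPU devices

    import equinox as eqx
    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    from ReAct.utils.helpers import viz_obj

    # Toy module, fully replicated across the spoofed devices.
    linear = eqx.nn.Linear(4, 4, key=jax.random.PRNGKey(0))
    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
    linear = eqx.filter_shard(linear, NamedSharding(mesh, P()))

    viz_obj(linear)  # prints shape + sharding_info (from jax-array-info) for every array leaf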
diff --git a/ReAct/utils/sharding.py b/ReAct/utils/sharding.py
new file mode 100644
index 0000000..46626e8
--- /dev/null
+++ b/ReAct/utils/sharding.py
@@ -0,0 +1,196 @@
+import os
+from abc import ABC, abstractmethod
+from typing import List, Tuple
+
+import equinox as eqx
+import jax
+import jax.tree_util as jtu
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from jaxtyping import Array, PyTree
+
+from ReAct.model.baseline import GPT
+from ReAct.utils.helpers import viz_obj
+
+
+def get_strategy(strategy: str, *args):
+    strategy = strategy.strip().lower()
+    match strategy:
+        case 'ddp':
+            strat = DDPSharding(*args)
+
+        case 'simple mp':
+            strat = SimpleMPSharding(*args)
+
+        case 'megatron':
+            strat = MegatronSharding(*args)
+
+        case _:
+            raise NotImplementedError(f'Strategy {strategy} does not exist.')
+
+    return strat
+
+
+class Sharding(ABC):
+    def __init__(self, model_axis: int = 1) -> None:
+        self.model_axis = model_axis
+
+    @abstractmethod
+    def get_mesh(self) -> Mesh:
+        ...
+
+    @abstractmethod
+    def shard_data(self, tree: PyTree) -> PyTree:
+        ...
+
+    @abstractmethod
+    def shard_model(self, tree: PyTree) -> PyTree:
+        ...
+
+    def get_devices(self):
+        num_devices = len(jax.devices())
+        return mesh_utils.create_device_mesh((num_devices // self.model_axis, self.model_axis))
+
+    @staticmethod
+    def shard(a: PyTree, _sharding: NamedSharding) -> PyTree:
+        return eqx.filter_shard(a, _sharding)
+
+    def add_indices_to_tree(self, tree: PyTree, start_index: int = 0, dims_to_count: int = 3):
+        '''
+        Pairs every leaf with an index as `[leaf, index]`. The running index is only
+        incremented for leaves with `.ndim >= dims_to_count`, so stacked per-layer
+        weights can later be sharded according to their position in the tree.
+        '''
+        def add_index(leaf, index):
+            return [leaf, index[0]]
+
+        def index_incrementer(leaf: PyTree) -> List[int]:
+            if not eqx.is_array(leaf):
+                return [-999]
+
+            nonlocal start_index
+            start_index += 1 if leaf.ndim >= dims_to_count else 0
+            return [start_index - 1]
+
+        indexed_tree = jtu.tree_map(add_index, tree, jtu.tree_map(index_incrementer, tree))
+        return indexed_tree
+
+
+class DDPSharding(Sharding):
+    def __init__(self, model_axis: int = 1) -> None:
+        super().__init__(model_axis)
+        self.mesh = self.get_mesh()
+
+    def get_mesh(self) -> Mesh:
+        num_devices = len(jax.devices())
+        devices = mesh_utils.create_device_mesh((num_devices, 1))
+        mesh = Mesh(devices, axis_names=('data', None))
+
+        return mesh
+
+    def shard_data(self, tree: PyTree | Array) -> PyTree | Array:
+        return self.shard(tree, NamedSharding(self.mesh, P('data')))
+
+    def shard_model(self, tree: PyTree) -> PyTree:
+        return jtu.tree_map(self.ddp_sharding, tree)
+
+    def ddp_sharding(self, leaf: PyTree) -> PyTree:
+        return leaf  # DDP keeps parameters replicated; only the data is sharded
+
+
+class SimpleMPSharding(Sharding):
+    def __init__(self, model_axis: int = 2) -> None:
+        super().__init__(model_axis)
+        self.mesh = self.get_mesh()
+
+    def get_mesh(self) -> Mesh:
+        return Mesh(self.get_devices(), axis_names=('data', 'model'))
+
+    def shard_data(self, tree: PyTree | Array) -> PyTree | Array:
+        return self.shard(tree, NamedSharding(self.mesh, P('data')))
+
+    def shard_model(self, tree: PyTree) -> PyTree:
+        return jtu.tree_map(self.simple_sharding, tree)
+
+    def simple_sharding(self, leaf: PyTree) -> PyTree:
+        if not eqx.is_array(leaf):
+            return leaf
+
+        sharding_ = NamedSharding(self.mesh, P())
+
+        if leaf.ndim == 1:
+            sharding_ = NamedSharding(self.mesh, P("model"))
+
+        if leaf.ndim >= 2:
+            sharding_ = NamedSharding(self.mesh, P(None, "model"))
+
+        return self.shard(leaf, sharding_)
+
+
+class MegatronSharding(Sharding):
+    def __init__(self, model_axis: int = 2) -> None:
+        super().__init__(model_axis)
+        self.mesh = self.get_mesh()
+
+    def get_mesh(self) -> Mesh:
+        return Mesh(self.get_devices(), axis_names=('data', 'model'))
+
+    def shard_data(self, tree: PyTree | Array) -> PyTree | Array:
+        return self.shard(tree, NamedSharding(self.mesh, P('data')))
+
+    def shard_model(self, tree: PyTree) -> PyTree:
+        is_leaf = lambda x: isinstance(x, list)  # noqa: E731
+        tree = self.add_indices_to_tree(tree, dims_to_count=3)
+        return jtu.tree_map(self.megatron_sharding, tree, is_leaf=is_leaf)
+
+    def megatron_sharding(self, leaf_and_index: PyTree[Tuple]) -> PyTree:
+        leaf, idx = leaf_and_index
+
+        if not eqx.is_array(leaf):
+            return leaf
+
+        sharding_ = NamedSharding(self.mesh, P())
+
+        # LN params and embedding 1Ds
+        if leaf.ndim == 1:
+            if max(leaf.shape) >= 2**14:
+                sharding_ = NamedSharding(self.mesh, P('model'))
+            else:
+                sharding_ = NamedSharding(self.mesh, P(None))
+
+        # embedding and unembedding
+        if leaf.ndim == 2:
+            if max(leaf.shape) >= 2**14:
+                # shard along the bigger (vocab-sized) dimension
+                p_spec = [
+                    "model" if i == leaf.shape.index(max(leaf.shape)) else None
+                    for i in range(len(leaf.shape))
+                ]
+                sharding_ = NamedSharding(self.mesh, P(*p_spec))
+            else:
+                sharding_ = NamedSharding(self.mesh, P(None, 'model'))
+
+        # stacked per-layer weights: alternate the sharded axis by position
+        if leaf.ndim == 3:
+            if idx % 2 == 0:
+                sharding_ = NamedSharding(self.mesh, P(None, None, "model"))
+            else:
+                sharding_ = NamedSharding(self.mesh, P(None, "model", None))
+
+        return self.shard(leaf, sharding_)
+
+
+if __name__ == "__main__":
+    # must be set before the first jax.devices() call for the spoofing to take effect
+    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
+    assert len(jax.devices()) == 8, "Hosts not correctly spoofed"
+
+    key = jax.random.PRNGKey(0)
+    BSZ, SEQLEN, WIDTH = 32, 256, 64
+
+    model = GPT(4, SEQLEN, 2, WIDTH, 0.01, 50304, key=key)
+    strategy = get_strategy('megatron', 2)
+
+    data = jax.numpy.ones((BSZ, SEQLEN))
+    data = strategy.shard_data(tree=data)
+    sharded_model = strategy.shard_model(model)
+
+    viz_obj(sharded_model)
+
+    print('\n ++++++++ Sharded data: +++++++++++\n')
+    jax.debug.visualize_array_sharding(data)
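A quick sanity check of the PartitionSpecs MegatronSharding assigns on the spoofed 8-device (4 x 2) ('data', 'model') mesh. The toy leaf shapes below stand in for embedding / LayerNorm / stacked per-layer weights and are not the real GPT shapes; dict leaves are visited in sorted-key order, so "attn" gets index 0 and "mlp" index 1:

    import os
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

    import jax.numpy as jnp
    from ReAct.utils.sharding import get_strategy

    strategy = get_strategy('megatron', 2)   # 4-way data x 2-way model parallelism
    toy_params = {
        "embed": jnp.zeros((50304, 64)),     # 2-D with a dim >= 2**14 -> P('model', None)
        "ln": jnp.zeros((64,)),              # small 1-D               -> P(None)
        "attn": jnp.zeros((4, 64, 256)),     # 3-D, even index         -> P(None, None, 'model')
        "mlp": jnp.zeros((4, 256, 64)),      # 3-D, odd index          -> P(None, 'model', None)
    }

    for name, leaf in strategy.shard_model(toy_params).items():
        print(name, leaf.shape, leaf.sharding.spec)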
spoofed" + + key = jax.random.PRNGKey(0) + BSZ, SEQLEN, WIDTH = 32, 256, 64 + + model = GPT(4, SEQLEN, 2, WIDTH, 0.01, 50304, key=key) + strategy= get_strategy('megatron', 1) + + data = jax.numpy.ones((BSZ, SEQLEN)) + data = strategy.shard_data(tree=data) + sharded_model = strategy.shard_model(model) + + viz_obj(sharded_model) + + print('\n ++++++++ Sharded data: +++++++++++\n') + jax.debug.visualize_array_sharding(data) diff --git a/ReAct/utils/trainer.py b/ReAct/utils/trainer.py index bd815f7..3a64a7d 100644 --- a/ReAct/utils/trainer.py +++ b/ReAct/utils/trainer.py @@ -1,21 +1,17 @@ import os -import wandb +from typing import Any, Callable, Optional, Tuple, Union + import equinox as eqx -import threading -import queue import jax -import jax.experimental.mesh_utils as mesh_utils import jax.numpy as jnp -import jax.sharding as jshard import optax import optuna - -from typing import Any, Callable, Optional, Tuple, Union from jaxtyping import Array, Int, PRNGKeyArray, PyTree from jmp import Policy from torch.utils.data import DataLoader from tqdm.auto import tqdm +import wandb from inferencer import Inferencer from ReAct.model.baseline import GPT from ReAct.model.react import React @@ -33,15 +29,11 @@ _cross_entropy_with_logits_fwd, cross_entropy_with_logits, ) +from ReAct.utils.sharding import get_strategy half, full = jnp.bfloat16, jnp.float32 policy = Policy(compute_dtype=half, param_dtype=half, output_dtype=half) - -# Setting up distributed stuff -num_devices = len(jax.devices()) -devices = mesh_utils.create_device_mesh((num_devices, 1)) -sharding = jshard.PositionalSharding(devices) -replicated = sharding.replicate() +strategy = get_strategy('megatron', 2) # Stable CE (w/ z-loss) from PaLM ce_loss = cross_entropy_with_logits @@ -85,9 +77,8 @@ def make_step( num_classes: int, keys: PRNGKeyArray, ): - replicated = sharding.replicate() - model, opt_state = eqx.filter_shard((model, opt_state), replicated) - x, y, pad_mask = eqx.filter_shard((x, y, pad_mask), sharding) + x, y, pad_mask = strategy.shard_data((x, y, pad_mask)) + model, opt_state = strategy.shard_model((model, opt_state)) dynamic_model = eqx.filter(model, eqx.is_inexact_array) @@ -114,7 +105,7 @@ def compute_loss(model: Union[React, GPT], x: Array, y: Array, pad_mask: Array, model = eqx.apply_updates(model, updates) # shard the outputs as well - model, opt_state = eqx.filter_shard((model, opt_state), replicated) + model, opt_state = strategy.shard_model((model, opt_state)) return loss, model, opt_state, grads, updates @@ -143,14 +134,14 @@ def __init__(self, def evaluate_acc(self, model: Union[React, GPT], is_baseline: bool, loader: DataLoader, eval_iters: int, keys: PRNGKeyArray): - model = eqx.filter_shard(model, replicated) + model = strategy.shard_model((model)) metrics_sum = jnp.zeros(3) # [acc, loss, ppl] num_batches = len(loader) for _, batch in tqdm(enumerate(loader), total=len(loader), desc='Validating'): seq, label, pad_mask = jnp.asarray(batch['text']) - seq, label, pad_mask = eqx.filter_shard((seq, label, pad_mask), sharding) + seq, label, pad_mask = strategy.shard_data((seq, label, pad_mask)) seq, label, pad_mask = policy.cast_to_compute((seq, label, pad_mask)) acc, loss, ppl = self.compute_metrics(is_baseline, model, seq, label, pad_mask, eval_iters, self.args.num_classes, keys) @@ -247,10 +238,10 @@ def init_model(self, key: PRNGKeyArray) -> Tuple[PyTree, Union[React, GPT]]: model = policy.cast_to_param(model) _, opt_state, model = self.set_optim_and_scheduler(model) - model = eqx.filter_shard(model, replicated) 
@@ -299,10 +290,8 @@ def compute_metrics(
         Computes the accuracy, perplexity, loss of the model w.r.t batch
         '''
         # sharding everything
-        model = eqx.filter_shard(model, replicated)
-        input_arr, label, pad_mask = eqx.filter_shard(
-            (input_arr, label, pad_mask), sharding
-        )
+        model = strategy.shard_model(model)
+        input_arr, label, pad_mask = strategy.shard_data((input_arr, label, pad_mask))
 
         keys = keys[:input_arr.shape[0], ...] # take a batch_size sized slice of the keys
 
@@ -358,10 +347,10 @@ def train(self, trial: Optional[Any] = None) -> float:
             keys = jax.random.split(epoch_key, self.args.batch_size)
 
             for step, batch in tqdm(enumerate(self.trainloader), total=self.dataset_length, desc=f'Epoch {epoch}'):
-                step += step_done # for multiple epochs
+                step += step_done  # for multiple epochs
 
-                seq, label, pad_mask = jnp.asarray(batch['text'])
-                seq, label, pad_mask = eqx.filter_shard((seq, label,pad_mask), sharding)
+                seq, label, pad_mask = jnp.asarray(batch["text"])
+                seq, label, pad_mask = strategy.shard_data((seq, label, pad_mask))
                 seq, label, pad_mask = policy.cast_to_compute((seq, label, pad_mask))
 
                 loss, model, opt_state, grads, updates = make_step(