[FEAT] Added initial autosharding API
neel04 committed Sep 3, 2024
1 parent f170cb2 commit a08abc5
Showing 4 changed files with 228 additions and 31 deletions.
3 changes: 2 additions & 1 deletion Dockerfile
@@ -20,8 +20,9 @@ RUN pip3 install -U -q jax[tpu] -f https://storage.googleapis.com/jax-releases/l
RUN pip3 install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich
RUN pip3 install -U tensorboard-plugin-profile optuna-integration plotly
RUN pip3 install git+https://github.com/deepmind/jmp
RUN pip3 install git+https://github.com/Findus23/jax-array-info.git

WORKDIR /ReAct_Jax

# Set the entry point to bash
ENTRYPOINT ["/bin/bash"]
ENTRYPOINT ["/bin/bash"]
11 changes: 11 additions & 0 deletions ReAct/utils/helpers.py
@@ -6,6 +6,7 @@
import equinox as eqx
import jax
import jax.numpy as jnp
from jax_array_info import sharding_info
from jaxtyping import Array, PRNGKeyArray, PyTree

from ReAct.model.baseline import GPT
@@ -114,6 +115,16 @@ def half_precision(model: eqx.Module) -> eqx.Module:
lambda x: x.astype(jnp.bfloat16) if eqx.is_inexact_array(x) else x, model
)

def viz_obj(model: PyTree):
    '''Print the shape and sharding of every array leaf in `model`.'''
    model = eqx.filter(model, eqx.is_array)

def viz_fn(leaf):
print(f"\n=== leaf: {leaf.shape} ===\n")
return sharding_info(leaf)

jax.tree_util.tree_map(viz_fn, model)
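# Usage sketch (assuming a model whose arrays have already been sharded, as in
# the __main__ block of ReAct/utils/sharding.py):
#
#   viz_obj(sharded_model)  # prints the shape and sharding of every array leaf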


def megatron_init(weight: Array, key: PRNGKeyArray) -> Array:
"""
    Initialize all the weights with the init scheme from the Megatron paper
196 changes: 196 additions & 0 deletions ReAct/utils/sharding.py
@@ -0,0 +1,196 @@
import os
from abc import ABC, abstractmethod
from typing import List, Tuple

import equinox as eqx
import jax
import jax.tree_util as jtu
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxtyping import Array, PyTree

from ReAct.model.baseline import GPT
from ReAct.utils.helpers import viz_obj


def get_strategy(strategy: str, *args):
strategy = strategy.strip().lower()
match strategy:
case 'ddp':
strat = DDPSharding(*args)

case 'simple mp':
strat = SimpleMPSharding(*args)

case 'megatron':
strat = MegatronSharding(*args)

case _:
raise NotImplementedError(f'Strategy {strategy} does not exist.')

return strat
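# Rough usage sketch (mirrors the __main__ block at the bottom of this file);
# any extra positional args are forwarded to the chosen strategy's constructor,
# e.g. the size of the model-parallel mesh axis:
#
#   strategy = get_strategy('simple mp', 2)   # 2-way model parallelism
#   batch = strategy.shard_data(batch)        # shard inputs along 'data'
#   model = strategy.shard_model(model)       # place params per the strategy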


class Sharding(ABC):
def __init__(self, model_axis: int = 1) -> None:
self.model_axis = model_axis

@abstractmethod
def get_mesh(self) -> Mesh:
...

@abstractmethod
def shard_data(self, tree: PyTree) -> PyTree:
...

@abstractmethod
def shard_model(self, tree: PyTree) -> PyTree:
...

def get_devices(self):
num_devices = len(jax.devices())
return mesh_utils.create_device_mesh((num_devices // self.model_axis, self.model_axis))

@staticmethod
def shard(a: PyTree, _sharding: NamedSharding) -> PyTree:
        return eqx.filter_shard(a, _sharding)

    def add_indices_to_tree(self, tree: PyTree, start_index: int = 0, dims_to_count: int = 3):
        '''
        Pair each leaf with a running index, counting only array leaves whose
        `.ndim` is at least `dims_to_count`; non-array leaves are tagged -999.
        '''
def add_index(leaf, index):
return [leaf, index[0]]

def index_incrementer(leaf: PyTree) -> List[int]:
if not eqx.is_array(leaf):
return [-999]

nonlocal start_index
            start_index += 1 if leaf.ndim >= dims_to_count else 0
return [start_index - 1]

indexed_tree = jtu.tree_map(add_index, tree, jtu.tree_map(index_incrementer, tree))
return indexed_tree
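    # Toy example (hypothetical leaves): for a tree [w0 (3-D), b0 (1-D), w1 (3-D)]
    # this returns [[w0, 0], [b0, 0], [w1, 1]] -- every >= `dims_to_count`-D array
    # gets a fresh running index, lower-rank arrays carry the index of the last
    # such array seen (or -1 before any), and non-array leaves are tagged -999.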

class DDPSharding(Sharding):
def __init__(self, model_axis: int = 1) -> None:
super().__init__(model_axis)
self.mesh = self.get_mesh()

def get_mesh(self) -> Mesh:
num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
        mesh = Mesh(devices, axis_names=('data', 'model'))  # second axis has size 1

return mesh

def shard_data(self, tree: PyTree | Array) -> PyTree | Array:
return self.shard(tree, NamedSharding(self.mesh, P('data')))

def shard_model(self, tree: PyTree) -> PyTree:
return jtu.tree_map(self.ddp_sharding, tree)

    def ddp_sharding(self, leaf: PyTree) -> PyTree:
        # Pure data parallelism: parameters are left as-is; only the batch is
        # sharded (see `shard_data`).
        return leaf

class SimpleMPSharding(Sharding):
    def __init__(self, model_axis: int = 2) -> None:
super().__init__(model_axis)
self.mesh = self.get_mesh()

def get_mesh(self) -> Mesh:
return Mesh(self.get_devices(), axis_names=('data', 'model'))

def shard_data(self, tree: PyTree | Array) -> PyTree | Array:
return self.shard(tree, NamedSharding(self.mesh, P('data')))

def shard_model(self, tree: PyTree) -> PyTree:
return jtu.tree_map(self.simple_sharding, tree)

def simple_sharding(self, leaf: PyTree) -> PyTree:
if not eqx.is_array(leaf):
return leaf

sharding_ = NamedSharding(self.mesh, P())

if leaf.ndim == 1:
sharding_ = NamedSharding(self.mesh, P("model"))

if leaf.ndim >= 2:
sharding_ = NamedSharding(self.mesh, P(None, "model"))

return self.shard(leaf, sharding_)
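    # Net effect: 1-D params are laid out along 'model', >=2-D weights are split
    # along their second dimension, and activations remain sharded along 'data'
    # via `shard_data`.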


class MegatronSharding(Sharding):
    def __init__(self, model_axis: int = 2) -> None:
super().__init__(model_axis)
self.mesh = self.get_mesh()

def get_mesh(self) -> Mesh:
return Mesh(self.get_devices(), axis_names=('data', 'model'))

def shard_data(self, tree: PyTree | Array) -> PyTree | Array:
return self.shard(tree, NamedSharding(self.mesh, P('data')))

def shard_model(self, tree: PyTree) -> PyTree:
is_leaf = lambda x: isinstance(x, list) # noqa: E731
tree = self.add_indices_to_tree(tree, dims_to_count = 3)
return jtu.tree_map(self.megatron_sharding, tree, is_leaf=is_leaf)

def megatron_sharding(self, leaf_and_index: PyTree[Tuple]) -> PyTree:
leaf, idx = leaf_and_index

if not eqx.is_array(leaf):
return leaf

sharding_ = NamedSharding(self.mesh, P())

# LN params and embedding 1Ds
if leaf.ndim == 1:
if max(leaf.shape) >= 2**14:
sharding_ = NamedSharding(self.mesh, P('model'))
else:
sharding_ = NamedSharding(self.mesh, P(None))

# embedding and unembedding
if leaf.ndim == 2:
if max(leaf.shape) >= 2**14:
# shard the bigger index
p_spec = [
"model" if i == leaf.shape.index(max(leaf.shape)) else None
for i in range(len(leaf.shape))
]
sharding_ = NamedSharding(self.mesh, P(*p_spec))
else:
sharding_ = NamedSharding(self.mesh, P(None, 'model'))

if leaf.ndim == 3:
if idx % 2 == 0:
sharding_ = NamedSharding(self.mesh, P(None, None, "model"))
else:
sharding_ = NamedSharding(self.mesh, P(None, "model", None))

return self.shard(leaf, sharding_)
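    # Illustrative layouts produced by the heuristic above (shapes are examples,
    # assuming a ('data', 'model') mesh and the 50304-token vocab from __main__):
    #   (50304, width) embedding           -> P('model', None)       big dim sharded
    #   (width,)       LayerNorm scale     -> P(None)                replicated
    #   (depth, d, 4*d) stacked 3-D weight -> P(None, None, 'model') at even indices,
    #                                         P(None, 'model', None) at odd indices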


if __name__ == "__main__":
    # Spoof 8 CPU devices; this flag must be set before JAX initialises its backend.
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
    assert len(jax.devices()) == 8, "Device count not correctly spoofed"

key = jax.random.PRNGKey(0)
BSZ, SEQLEN, WIDTH = 32, 256, 64

model = GPT(4, SEQLEN, 2, WIDTH, 0.01, 50304, key=key)
    strategy = get_strategy('megatron', 1)

data = jax.numpy.ones((BSZ, SEQLEN))
data = strategy.shard_data(tree=data)
sharded_model = strategy.shard_model(model)

viz_obj(sharded_model)

print('\n ++++++++ Sharded data: +++++++++++\n')
jax.debug.visualize_array_sharding(data)
49 changes: 19 additions & 30 deletions ReAct/utils/trainer.py
@@ -1,21 +1,17 @@
import os
import wandb
from typing import Any, Callable, Optional, Tuple, Union

import equinox as eqx
import threading
import queue
import jax
import jax.experimental.mesh_utils as mesh_utils
import jax.numpy as jnp
import jax.sharding as jshard
import optax
import optuna

from typing import Any, Callable, Optional, Tuple, Union
from jaxtyping import Array, Int, PRNGKeyArray, PyTree
from jmp import Policy
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import wandb
from inferencer import Inferencer
from ReAct.model.baseline import GPT
from ReAct.model.react import React
@@ -33,15 +29,11 @@
_cross_entropy_with_logits_fwd,
cross_entropy_with_logits,
)
from ReAct.utils.sharding import get_strategy

half, full = jnp.bfloat16, jnp.float32
policy = Policy(compute_dtype=half, param_dtype=half, output_dtype=half)

# Setting up distributed stuff
num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jshard.PositionalSharding(devices)
replicated = sharding.replicate()
strategy = get_strategy('megatron', 2)
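# 'megatron'-style sharding with a 2-way model-parallel mesh axis (see ReAct/utils/sharding.py)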

# Stable CE (w/ z-loss) from PaLM
ce_loss = cross_entropy_with_logits
@@ -85,9 +77,8 @@ def make_step(
num_classes: int,
keys: PRNGKeyArray,
):
replicated = sharding.replicate()
model, opt_state = eqx.filter_shard((model, opt_state), replicated)
x, y, pad_mask = eqx.filter_shard((x, y, pad_mask), sharding)
x, y, pad_mask = strategy.shard_data((x, y, pad_mask))
model, opt_state = strategy.shard_model((model, opt_state))

dynamic_model = eqx.filter(model, eqx.is_inexact_array)

@@ -114,7 +105,7 @@ def compute_loss(model: Union[React, GPT], x: Array, y: Array, pad_mask: Array,
model = eqx.apply_updates(model, updates)

# shard the outputs as well
model, opt_state = eqx.filter_shard((model, opt_state), replicated)
model, opt_state = strategy.shard_model((model, opt_state))

return loss, model, opt_state, grads, updates

@@ -143,14 +134,14 @@ def __init__(self,

def evaluate_acc(self, model: Union[React, GPT], is_baseline: bool, loader: DataLoader, eval_iters: int, keys: PRNGKeyArray):

model = eqx.filter_shard(model, replicated)
        model = strategy.shard_model(model)

metrics_sum = jnp.zeros(3) # [acc, loss, ppl]
num_batches = len(loader)

for _, batch in tqdm(enumerate(loader), total=len(loader), desc='Validating'):
seq, label, pad_mask = jnp.asarray(batch['text'])
seq, label, pad_mask = eqx.filter_shard((seq, label, pad_mask), sharding)
seq, label, pad_mask = strategy.shard_data((seq, label, pad_mask))
seq, label, pad_mask = policy.cast_to_compute((seq, label, pad_mask))

acc, loss, ppl = self.compute_metrics(is_baseline, model, seq, label, pad_mask, eval_iters, self.args.num_classes, keys)
@@ -247,10 +238,10 @@ def init_model(self, key: PRNGKeyArray) -> Tuple[PyTree, Union[React, GPT]]:
model = policy.cast_to_param(model)

_, opt_state, model = self.set_optim_and_scheduler(model)
model = eqx.filter_shard(model, replicated)

count_params(model) # prints to stdout
calc_performance_metrics(self.args, self.my_logger) # logs via logger

model, opt_state = strategy.shard_model((model, opt_state))
count_params(model) # prints to stdout
calc_performance_metrics(self.args, self.my_logger) # logs via logger

return opt_state, model

@@ -299,10 +290,8 @@ def compute_metrics(
Computes the accuracy, perplexity, loss of the model w.r.t batch
'''
# sharding everything
model = eqx.filter_shard(model, replicated)
input_arr, label, pad_mask = eqx.filter_shard(
(input_arr, label, pad_mask), sharding
)
model = strategy.shard_model(model)
input_arr, label, pad_mask = strategy.shard_data((input_arr, label, pad_mask))

keys = keys[:input_arr.shape[0], ...] # take a batch_size sized slice of the keys

@@ -358,10 +347,10 @@ def train(self, trial: Optional[Any] = None) -> float:
keys = jax.random.split(epoch_key, self.args.batch_size)

for step, batch in tqdm(enumerate(self.trainloader), total=self.dataset_length, desc=f'Epoch {epoch}'):
step += step_done # for multiple epochs
step += step_done # for multiple epochs

seq, label, pad_mask = jnp.asarray(batch['text'])
seq, label, pad_mask = eqx.filter_shard((seq, label,pad_mask), sharding)
seq, label, pad_mask = jnp.asarray(batch["text"])
seq, label, pad_mask = strategy.shard_data((seq, label, pad_mask))
seq, label, pad_mask = policy.cast_to_compute((seq, label, pad_mask))

loss, model, opt_state, grads, updates = make_step(
