Skip to content

Commit

Permalink
Removing scalax and updating docker image
Browse files Browse the repository at this point in the history
  • Loading branch information
neel04 committed Aug 10, 2024
1 parent 3f438e6 commit b7d0131
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 51 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Byte-compiled / optimized / DLL files
**.log
*.log
__pycache__/
*.py[cod]
*$py.class
Expand Down
3 changes: 2 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
"args": [
"--debug",
"--width", "64",
"--batch_size", "32",
"--batch_size", "8",
"--seqlen", "192",
"--num_blocks", "4",
"--epochs", "1",
"--max_iters", "5",
Expand Down
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ RUN apt-get update && \
RUN pip3 install Ipython matplotlib
RUN pip3 install numpy pandas scipy

RUN pip3 install -U numpy==1.26.4
RUN pip3 install -U -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN pip3 install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich
RUN pip3 install -U tensorboard-plugin-profile optuna-integration plotly
Expand Down
2 changes: 1 addition & 1 deletion ReAct/model/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def __init__(self, input_dim: int, key: PRNGKeyArray):
self.input_dim = input_dim
self.weight = LinearProj(input_dim, input_dim, use_bias=True, key=key)

@jax.jit
@eqx.filter_jit
def __call__(self, x: Float[Array, 'seqlen in_dim'], mask: Array):
x = policy.cast_to_compute(x)
attn_weights = jax.nn.softmax(self.weight(x.T, mask), axis=1) # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion ReAct/model/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
self.post_ln = eqx.nn.LayerNorm(width)
self.out_head = LinearProj(width, vocab_size, key=key3)

@partial(jax.jit, static_argnums=(4, 5, 6))
@eqx.filter_jit
def iterate_for_steps(
self,
interim_thought: Array,
Expand Down
108 changes: 63 additions & 45 deletions ReAct/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
import jax.numpy as jnp
import optax
import optuna

from jaxtyping import Array, PRNGKeyArray, PyTree
import jax.experimental.mesh_utils as mesh_utils
import jax.sharding as jshard
from jmp import Policy
from scalax.sharding import MeshShardingHelper
from scalax.sharding import PartitionSpec as P

from torch.utils.data import DataLoader
from tqdm.auto import tqdm

Expand All @@ -32,10 +34,13 @@


half, full = jnp.bfloat16, jnp.float32
policy = Policy(compute_dtype=half, param_dtype=half, output_dtype=half)

# Setting up distributed stuff
mesh = MeshShardingHelper(axis_dims=[-1], axis_names=['data']) # handle DDP + TP over multi-node
policy = Policy(compute_dtype=half, param_dtype=half, output_dtype=half)
num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jshard.PositionalSharding(devices)
replicated = sharding.replicate()

# Stable CE (w/ z-loss) from PaLM
ce_loss = cross_entropy_with_logits
Expand Down Expand Up @@ -94,27 +99,26 @@ def _compute_softmax_cross_entropy_loss(

return loss.sum((-1, -2)).mean() # mean across batch

@partial(
mesh.sjit,
in_shardings=(None, None, P('data'), P('data'), P('data'), None),
args_sharding_constraint=(None, None, P('data'), P('data'), P('data'), None),
out_shardings=None,
static_argnums=(2, 6, 7, 8)
)
def make_step(model: eqx.Module,
opt_state: Tuple[PyTree],
filter_spec: PyTree, # static
x: Array,
y: Array,
pad_mask: Array,
iters_to_do: int, # static
optim: Callable, # static
num_classes: int, # static
keys: List[PRNGKeyArray]):
@eqx.filter_jit(donate="all-except-first")
def make_step(
model: eqx.Module,
opt_state: PyTree,
filter_spec: PyTree,
x: Array,
y: Array,
pad_mask: Array,
iters_to_do: int,
optim: Callable,
num_classes: int,
keys: PRNGKeyArray,
):
replicated =sharding.replicate()
model, opt_state = eqx.filter_shard((model, opt_state), replicated)
x, y, pad_mask = eqx.filter_shard((x, y, pad_mask), sharding)

@eqx.filter_value_and_grad
def compute_loss(model: eqx.Module, static_model: PyTree, x: Array, y: Array, pad_mask: Array,
iters_to_do: int, num_classes: int, keys: PRNGKeyArray) -> int:
iters_to_do: int, num_classes: int, keys: PRNGKeyArray) -> Array:
'''
Computes the loss of the model w.r.t the input. Is a closure for accessing static_model
'''
Expand All @@ -140,6 +144,9 @@ def compute_loss(model: eqx.Module, static_model: PyTree, x: Array, y: Array, pa
updates, opt_state = optim.update(grads, opt_state, model)
model = eqx.apply_updates(model, updates)

# shard the outputs as well
model, opt_state = eqx.filter_shard((model, opt_state), replicated)

return loss, model, opt_state

class Trainer:
Expand Down Expand Up @@ -168,13 +175,20 @@ def __init__(self,
# Assign each arg as a class attribute
self.__dict__.update(vars(self.args))

def evaluate_acc(self, model: eqx.Module, loader: DataLoader, eval_iters: int, keys: List[PRNGKeyArray]):
@eqx.filter_jit(donate='all')
def evaluate_acc(self, model: eqx.Module, is_baseline: bool, loader: DataLoader, eval_iters: int, keys: PRNGKeyArray):

metric = []

model = eqx.filter_shard(model, replicated)

for step, batch in tqdm(enumerate(loader), total=len(loader), desc='Validating'):
seq, label, pad_mask = jnp.asarray(batch['text'])
acc, loss, ppl = self.compute_metrics(model, seq, label, pad_mask, eval_iters, self.num_classes, keys)
seq, label, pad_mask = eqx.filter_shard((seq, label,pad_mask), sharding)
seq, label, pad_mask = policy.cast_to_compute((seq, label, pad_mask))

acc, loss, ppl = self.compute_metrics(is_baseline, model, seq, label, pad_mask, eval_iters, self.num_classes, keys)

metric.extend([acc, loss, ppl])

# Compute cumulatives
Expand Down Expand Up @@ -244,12 +258,13 @@ def init_model(self, key: PRNGKeyArray):
model = policy.cast_to_param(model)

_, opt_state, model = self.set_optim_and_scheduler(model)
model = eqx.filter_shard(model, replicated)

count_params(model) # prints to stdout
calc_performance_metrics(self.args, self.my_logger) # logs via logger

return opt_state, model

def resume_training(self, model: eqx.Module, opt_state: eqx.Module):
# extracting out the paths
run_path, step = self.resume.split('+')
Expand All @@ -268,27 +283,30 @@ def resume_training(self, model: eqx.Module, opt_state: eqx.Module):

return model, opt_state, step

@partial(
mesh.sjit,
in_shardings=(None, P('data'), P('data'), P('data'), None),
out_shardings=None,
args_sharding_constraint=(None, P('data'), P('data'), P('data'), None),
static_argnums=(0, 5, 6)
)
def compute_metrics(self,
model: eqx.Module,
input_arr: Array,
label: Array,
pad_mask: Array,
eval_iters: int, # static
num_classes: int, # static
keys: List[PRNGKeyArray]):
@eqx.filter_jit
def compute_metrics(
self,
is_baseline: bool,
model: eqx.Module,
input_arr: Array,
label: Array,
pad_mask: Array,
eval_iters: int, # static
num_classes: int, # static
keys: List[PRNGKeyArray],
):
'''
Computes the accuracy, perplexity, loss of the model w.r.t batch
'''
# sharding everything
model = eqx.filter_shard(model, replicated)
input_arr, label, pad_mask = eqx.filter_shard(
(input_arr, label, pad_mask), sharding
)

keys = keys[:input_arr.shape[0], ...] # take a batch_size sized slice of the keys

if self.baseline:
if is_baseline:
pred_y = jax.vmap(model, in_axes=(0, 0, None, 0))(input_arr, pad_mask, False, keys)
else:
pred_y = jax.vmap(model, in_axes=(0, None, 0, None, None, 0))(input_arr, eval_iters, pad_mask, False, False, keys)[0]
Expand Down Expand Up @@ -316,7 +334,7 @@ def optuna_log(self, trial: Optional[Any], metrics: Tuple[float, int]):
if trial is not None:
trial.report(loss, progress)

def train(self, trial: Optional[Any] = None) -> Tuple[float, int]:
def train(self, trial: Optional[Any] = None) -> float:
step_done = 0

opt_state, model = self.init_model(self.key)
Expand All @@ -341,14 +359,14 @@ def train(self, trial: Optional[Any] = None) -> Tuple[float, int]:
step += step_done # for multiple epochs

seq, label, pad_mask = jnp.asarray(batch['text'])
seq, label, pad_mask = eqx.filter_shard((seq, label,pad_mask), sharding)
seq, label, pad_mask = policy.cast_to_compute((seq, label, pad_mask))

loss, model, opt_state = make_step(model, opt_state, filter_spec, seq, label, pad_mask,
self.max_iters, optim, self.num_classes, keys)

if step % 100 == 0:
accuracy, loss, perplexity = self.compute_metrics(model, seq, label, pad_mask,
self.max_iters, self.num_classes, keys)
accuracy, loss, perplexity = self.compute_metrics(self.baseline, model, seq, label, pad_mask, self.max_iters, self.num_classes, keys)

train_acc.append(accuracy)
train_loss.append(loss)
Expand Down Expand Up @@ -379,7 +397,7 @@ def train(self, trial: Optional[Any] = None) -> Tuple[float, int]:
train_acc, train_loss, train_ppl = [], [], []

## Validation
(val_acc, val_loss, val_ppl), val_sample = self.evaluate_acc(model, self.valloader, self.max_iters, keys)
(val_acc, val_loss, val_ppl), val_sample = self.evaluate_acc(model, self.baseline, self.valloader, self.max_iters, keys)

self.wandb_logger.log(
{
Expand Down
22 changes: 19 additions & 3 deletions train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@

jax.config.update("jax_compilation_cache_dir", "./ReAct/compilation_cache")

if platform.processor() != 'arm':
jax.distributed.initialize() # don't run on apple sillicon
if platform.processor() != "arm":
jax.distributed.initialize() # don't run on apple sillicon

import optuna
import os

from wandb import Artifact

from jax import config
from jaxtyping import PRNGKeyArray
from optuna.integration.wandb import WeightsAndBiasesCallback
Expand Down Expand Up @@ -62,9 +66,14 @@ def main(key: PRNGKeyArray):
valloader = val_dataset.create_dataloader('-1%:')

# Create optuna hypertununing study
storage = optuna.storages.JournalStorage(
optuna.storages.JournalFileStorage("./journal.log"),
)

study = optuna.create_study(
direction="minimize",
load_if_exists=True,
storage=storage,
sampler=optuna.samplers.TPESampler(
seed=69,
consider_magic_clip=True,
Expand Down Expand Up @@ -159,6 +168,13 @@ def kickoff_optuna(trial, **trainer_kwargs):
# ========= Logging ========
logger = UnifiedLogger(args, level='DEBUG')
my_logger, wandb_logger = logger.my_logger(), logger.wandb_logger(args)

# Store the optuna checkpoint progress
if os.path.isfile('./journal.log'):
artifact = Artifact(name="Optuna_Checkpoint", type="checkpoint")
artifact.add_file(local_path = "./journal.log", name = "optuna_chkp")
artifact.save()

trainer_kwargs['logger'] = (my_logger, wandb_logger)

trainer = Trainer(**trainer_kwargs)
Expand All @@ -175,4 +191,4 @@ def kickoff_optuna(trial, **trainer_kwargs):
if __name__ == '__main__':
key = jax.random.PRNGKey(69)
main(key)
exit(0)
exit(0)

0 comments on commit b7d0131

Please sign in to comment.