diff --git a/.github/workflows/docker-image.yaml b/.github/workflows/docker-image.yaml
index 1eeb549..fedf77e 100644
--- a/.github/workflows/docker-image.yaml
+++ b/.github/workflows/docker-image.yaml
@@ -3,8 +3,12 @@ name: Docker Image CI
 on:
   push:
     branches: [ "dev" ]
+    paths:
+      - '**Dockerfile**'
   pull_request:
     branches: [ "main" ]
+    paths:
+      - '**Dockerfile**'

 jobs:
   build:
diff --git a/.gitignore b/.gitignore
index c5ca05b..82af572 100644
--- a/.gitignore
+++ b/.gitignore
@@ -166,4 +166,5 @@ cython_debug/
 **wandb
 *.wandb
 *.eqx
-*cached_data
\ No newline at end of file
+*cached_data
+*pyrightconfig.json
diff --git a/Dockerfile b/Dockerfile
index 52cfab0..b202de4 100755
--- a/Dockerfile
+++ b/Dockerfile
@@ -19,7 +19,7 @@ RUN pip3 install numpy pandas scipy
 RUN pip3 install -U -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
 RUN pip3 install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich

-RUN pip3 install tensorboard-plugin-profile comet-ml optuna-integration
+RUN pip3 install tensorboard-plugin-profile comet-ml optuna-integration plotly

 WORKDIR /ReAct_Jax

diff --git a/README.md b/README.md
index 0fc6fe2..48e283d 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ python3 inferencer.py --checkpoint_path '/Users/neel/Documents/research/ReAct_Ja
 First, get a preemptible TPUv4-8 node as a queued resource:

 ```bash
-gcloud alpha compute tpus queued-resources create node-v4 \
+gcloud alpha compute tpus queued-resources create $INSTANCE_NAME \
 --node-id node-v4 \
 --project react-jax \
 --zone us-central2-b \
@@ -38,7 +38,7 @@ gcloud alpha compute tpus queued-resources create node-v4 \
 Setup the TPU pod slice with basics:

 ```bash
-gcloud compute tpus tpu-vm ssh node-v4 \
+gcloud compute tpus tpu-vm ssh $INSTANCE_NAME \
 --zone=us-central2-b --worker=all --command="\
 sudo apt-get update; \
 sudo snap install nvim --classic; \
@@ -49,7 +49,7 @@ gcloud compute tpus tpu-vm ssh node-v4 \
 And then actually kickoff the training by downloading the script and running it:

 ```bash
-gcloud compute tpus tpu-vm ssh node-v4 \
+gcloud compute tpus tpu-vm ssh $INSTANCE_NAME \
 --zone=us-central2-b --worker=all --command="\
 tmux kill-server; sudo rm -rf ./*; \
 sleep 3s && wget https://gist.githubusercontent.com/neel04/3bfc7e4d9cd746829b7e72f1b6fac5de/raw/run.sh; \
@@ -59,7 +59,11 @@ gcloud compute tpus tpu-vm ssh node-v4 \
 If you get errors regarding workers not being able to sync up at the distributed barrier, do:

 ```bash
-gcloud compute tpus tpu-vm ssh --zone "us-central2-b" "ondem" \
---project "react-jax" \
---command 'sudo docker system prune -f && sudo rm -rf ~/.cache;'
+gcloud compute tpus tpu-vm ssh --zone "us-central2-b" $INSTANCE_NAME --worker 'all' --project "react-jax" --command 'sudo docker system prune -f && sudo rm -rf ~/.cache;'
+```
+
+If the Docker daemon is unresponsive, just restart the Docker service:
+
+```bash
+gcloud compute tpus tpu-vm ssh --zone "us-central2-b" $INSTANCE_NAME --worker 'all' --project "react-jax" --command 'sudo systemctl restart docker'
 ```
\ No newline at end of file
diff --git a/ReAct/data/minipile.py b/ReAct/data/minipile.py
index a1b64cd..e9dde50 100644
--- a/ReAct/data/minipile.py
+++ b/ReAct/data/minipile.py
@@ -4,7 +4,6 @@
 import datasets
 import jax
-import jax.numpy as jnp
 import numpy as np
 from datasets import load_dataset, load_from_disk
 from jaxtyping import Array

@@ -15,6 +14,7 @@ class MiniPileDataset:
     def __init__(self, split: str = 'train', max_length: int = 512, bsz: int = 256, vocab_dir: str ='./ReAct/data'):

         datasets.config.IN_MEMORY_MAX_SIZE = 1e+11
+        self.cpus = jax.devices("cpu")
         self.bsz = bsz
         self.max_length = max_length + 1
         self.split = split
@@ -93,12 +93,6 @@ def take_subset(self, dataset, elements: int) -> None:

         return dataset

-    def numpify(self, dataset: datasets.Dataset) -> datasets.Dataset:
-        '''
-        Convert the dataset to numpy arrays
-        '''
-        return jax.tree_map(lambda x: jnp.asarray(x), dataset['text'])
-
     def create_dataloader(self, slice: str = '100%'):

         data_path = Path(f'./cached_data/minipile_{self.split}.data')

@@ -110,13 +104,13 @@ def create_dataloader(self, slice: str = '100%'):
             dataset.set_format(type='numpy')

-            return self.numpify(dataset)
+            return dataset

         except (FileNotFoundError, ValueError):
             if os.path.exists(data_path):
                 print(f'Loading dataset from {data_path}...')
                 dataset = self.load_data(data_path)

-                return self.numpify(dataset)
+                return dataset

             else:
                 print(f'Building dataset from scratch... [split: {self.split}] | [bsz: {self.bsz}]')
@@ -139,4 +133,4 @@ def create_dataloader(self, slice: str = '100%'):

         self.upload_dataset(dataset, hub_path=f'Neel-Gupta/minipile-processed_{self.bsz}') # upload the processed dataset to the Hub

-        return self.numpify(dataset)
\ No newline at end of file
+        return dataset
\ No newline at end of file
diff --git a/ReAct/model/blocks.py b/ReAct/model/blocks.py
index b15913d..5fd890c 100644
--- a/ReAct/model/blocks.py
+++ b/ReAct/model/blocks.py
@@ -151,11 +151,12 @@ def __init__(self,

     def __call__(self,
                  input: BFloat16[Array, 'batch in_dim'],
-                 **kwargs):
+                 **kwargs) -> Array:

         mask = kwargs.get('mask', None)
         mask = jnp.ones_like(self.weight) if mask is None else mask
         output = input @ (self.weight * mask.astype(input.dtype)) + self.bias
+        return output


 class LiteAttention(eqx.Module):
diff --git a/ReAct/model/react.py b/ReAct/model/react.py
index 9347edf..fdd2295 100644
--- a/ReAct/model/react.py
+++ b/ReAct/model/react.py
@@ -5,14 +5,18 @@
 import jax.numpy as jnp
 from jaxtyping import Array, PRNGKeyArray, PyTree

-from .blocks import AttentionBlock, LinearProj, LiteAttention
+from .blocks import AttentionBlock, LinearProj, LiteAttention, MLP

 class RecurrentModule(eqx.Module):
     '''
-    Bunch of AttentionBlocks
+    Bunch of AttentionBlocks in a pseudo-LSTM fashion
     '''
+    num_blocks: int = eqx.field(static=True)
+
     attention_blocks: PyTree[AttentionBlock]
-    reshape_layer: eqx.Module
+    forget_gate: LinearProj
+    ctx_gate: LinearProj
+    reshape_layer: LinearProj

     def __init__(self,
                  seqlen: int,
@@ -20,40 +24,52 @@ def __init__(self,
                  n_heads: int,
                  num_blocks: int,
                  bottleneck: int,
-                 key: PRNGKeyArray): # noqa: E501
-
+                 key: PRNGKeyArray):
+
+        self.num_blocks = num_blocks
         keys = jax.random.split(key, num_blocks)
+
+        make_block: callable = lambda k: AttentionBlock( # noqa: E731
+            seqlen, n_heads, drop_rate, bottleneck, k
+        )

         self.reshape_layer = LinearProj(bottleneck * 2, bottleneck, key=key)
+        self.forget_gate = MLP(bottleneck, bottleneck, p=drop_rate, key=key)
+        self.ctx_gate = MLP(bottleneck, bottleneck, p=drop_rate, key=key)

-        make_block: callable = lambda k: AttentionBlock(seqlen, n_heads, drop_rate, bottleneck, k) # noqa: E731
         self.attention_blocks = eqx.filter(eqx.filter_vmap(make_block)(keys), eqx.is_array_like)
-
-    def __call__(self, x: Array,
+
+    def __call__(self,
+                 x: Array,
                  input_arr: Array,
                  pad_mask: Array,
                  enable_dropout: bool,
-                 key: PRNGKeyArray) -> Array:
+                 key: PRNGKeyArray) -> Tuple[Array, Array]:

-        enable_dropout: bool = True
-        key: PRNGKeyArray = key
-
+        keys = jax.random.split(key, self.num_blocks)

         dynamic_part, static_part = eqx.partition(self.attention_blocks, eqx.is_array_like, is_leaf=lambda x: isinstance(x, eqx.nn.Dropout))

-        x = self.reshape_layer(x) # (batch, seqlen, bottleneck)
+        x = self.reshape_layer(x) # (seqlen, width * 2) -> (seqlen, width)

-        def f(input_tup: Tuple[Array, int], _dynamic_bl: PyTree) -> Tuple[Tuple[Array, int], int]:
-            input_arr, idx = input_tup # i is the iteration index
+        def f(input_tup: Tuple[Array, int], _dynamic_bl: PyTree) -> Tuple[Tuple[Array, int], None]:
+            x, idx = input_tup # idx is the iteration index
             block = eqx.combine(_dynamic_bl, static_part) # reconstruct the block
-            output = block(x, x, pad_mask, enable_dropout, key).astype(jnp.bfloat16) # self-attention
-            return (output, idx + 1), None
+            x = jax.lax.cond(idx == 0,
+                             lambda: block(x, input_arr, pad_mask, enable_dropout, keys[idx]),
+                             lambda: block(x, x, pad_mask, enable_dropout, keys[idx]))
+
+            return (x, idx + 1), x

-        out = eqx.internal.scan(f=f, init=(input_arr, 0), xs=dynamic_part, kind='lax')[0][0] # throw away idx
+        out, history = eqx.internal.scan(f=f, init=(x, 0), xs=dynamic_part, kind='lax')
+        history = history.mean(0)
+
+        input_arr *= jax.nn.sigmoid(self.forget_gate(history, True, key))
+        input_arr += self.ctx_gate(history, True, key)

-        return out
+        return out[0], input_arr

 class React(eqx.Module):
     '''
@@ -63,7 +79,6 @@ class React(eqx.Module):

     max_iters: int = eqx.field(static=True)
-    iters_weights: Array
     pos_enc: Array
     embed_layer: eqx.nn.Embedding
     main_block: LiteAttention

@@ -110,27 +125,31 @@ def positional_encoding(self, seq_len, d_model):
     @eqx.filter_jit
     def iterate_for_steps(self,
                           interim_thought: Array,
+                          input_arr: Array,
                           mask: Array,
                           iters_to_do: int,
-                          input_arr: Array,
                           enable_dropout: bool,
                           key: PRNGKeyArray) -> Array:

-        # These are constants
+        # Declaring constants
         input_arr = input_arr.astype(jnp.bfloat16)
         interim_thought = interim_thought.astype(jnp.bfloat16)
         mask = mask.astype(jnp.bfloat16)

-        def body_fun(thought: Array, _) -> Tuple[PyTree, Array]:
-            latent = jnp.concatenate([thought, input_arr], axis=-1).astype(jnp.bfloat16)
-            latent = self.main_block(latent, input_arr, mask, enable_dropout, key).astype(jnp.bfloat16)
-            latent = jax.vmap(self.post_ln)(latent).astype(jnp.bfloat16) # LN to keep scales tidy
+        def body_fun(carry: Tuple[Array, Array], idx: int) -> Tuple[Array, Array]:
+            thought, ctx_state = carry
+
+            latent = jnp.concatenate([input_arr, thought], axis=-1) # (seqlen, width * 2)
+            latent, ctx_state = self.main_block(latent, ctx_state, mask, enable_dropout, key) # (seqlen, width)
+            latent = jax.vmap(self.post_ln)(latent) # Post-LN for stability
+
+            return (latent, ctx_state), ctx_state

-            return latent, latent
+        final_val, history = eqx.internal.scan(
+            f=body_fun, init=(interim_thought, input_arr), xs=jnp.arange(5), kind="checkpointed"
+        )

-        final_val, _ = eqx.internal.scan(f=body_fun, init=interim_thought, xs=None, length=5, kind='checkpointed')
-        #return jnp.einsum('i j k, i -> j k', history, self.iters_weights) # dot-product with iters_weights
-        return final_val
+        return final_val[0]

     @eqx.filter_jit
     def __call__(self,
@@ -149,6 +168,6 @@ def __call__(self,
         input_arr = jax.vmap(self.embed_layer)(input_arr) + self.pos_enc # (batch, seqlen, bottleneck)
         interim_thought = input_arr.copy() # has to be a copy of the embedded + projected input array

-        output = self.iterate_for_steps(interim_thought, pad_mask, iters_to_do, input_arr, is_training, key) # (batch, seqlen, bottleneck)
+        output = self.iterate_for_steps(interim_thought, input_arr, pad_mask, iters_to_do, is_training, key) # (batch, seqlen, bottleneck)

         return self.out_head(output), output
diff --git a/ReAct/utils/helpers.py b/ReAct/utils/helpers.py
index 0e0a691..01101db 100644
--- a/ReAct/utils/helpers.py
+++ b/ReAct/utils/helpers.py
@@ -1,11 +1,61 @@
+import math
 import os
+from typing import Callable, Optional, Tuple
+
 import equinox as eqx
 import jax
 import jax.numpy as jnp
-
 from jax import tree_util as jtu
 from jaxtyping import Array, PRNGKeyArray
-from typing import Optional
+
+def convert_flops(params: int) -> str:
+    if params == 0:
+        return "0"
+
+    size_name = ("", "KFLOPs", "MFLOPs", "GFLOPs", "TFLOPs", "PFLOPs", "EFLOPs", "ZFLOPs", "YFLOPs")
+    i = int(math.floor(math.log(params, 1000)))
+    p = math.pow(1000, i)
+    s = round(params / p, 2)
+
+    return "%s %s" % (s, size_name[i])
+
+def calc_performance_metrics(args, my_logger: Callable) -> None:
+    '''
+    Estimates FLOPs consumed during a single fwd + bwd pass.
+    Taken from EleutherAI's GPT-NeoX repo: https://rb.gy/33d6zg
+
+    Logs the total number of FLOPs
+    '''
+    iter_factor = 3
+    args.tokens = args.batch_size * args.seqlen
+    args.kv_size_ratio = 1
+
+    my_logger.warning('! Ignoring activation checkpointing in FLOPs calculation !')
+
+    qkv_flops = int(iter_factor * 2 * (1 + 2 * args.kv_size_ratio) * args.num_blocks * args.tokens * args.width * args.width)
+    attention_matrix_flops = iter_factor * 2 * args.num_blocks * args.tokens * args.seqlen * args.width
+    attention_over_values_flops = iter_factor * 2 * args.num_blocks * args.tokens * args.seqlen * args.width
+    linear_projection_flops = iter_factor * 2 * args.num_blocks * args.tokens * args.width * args.width
+    ffn_flops = iter_factor * 16 * args.num_blocks * args.tokens * args.width * args.width
+
+    # handle NewGELU
+    ffn_flops *= 3.75
+
+    embedding_flops = 6 * args.tokens * args.width * args.num_classes
+    total_flops = qkv_flops + attention_matrix_flops + attention_over_values_flops + linear_projection_flops + ffn_flops + embedding_flops
+    my_logger.info(f"Total FLOPs for the Model: {convert_flops(total_flops)} for a single fwd + bwd pass\n")
+
+
+def xla_calc_flops(fn: Callable, static_argnums: Tuple[int], args: Tuple, my_logger: Callable) -> None:
+    '''
+    Estimates FLOPs consumed during `fn` execution.
+    Uses XLA HLO analysis to estimate FLOPs.
+
+    Logs the total number of FLOPs
+    '''
+    compiled = jax.jit(fn, static_argnums=static_argnums).lower(*args).compile()
+    flops = compiled.cost_analysis()[0]['flops']
+    my_logger.info(f"XLA estimate of Total FLOPs for {fn.__name__}: {convert_flops(int(flops))}\n")

 def half_precision(model: eqx.Module) -> eqx.Module:
     return jtu.tree_map(lambda x: x.astype(jnp.bfloat16) if eqx.is_inexact_array(x) else x, model)
@@ -59,15 +109,4 @@ def inverted_freq(arr: Array):

     inv_weights = (counts.max() / counts) # scale it down

-    return inv_weights[arr - arr.min()]
-
-if __name__ == '__main__':
-    import plotly.express as px
-    import pandas as pd
-
-    key = jax.random.PRNGKey(0)
-    out: Array = get_rand_nums(key, 1, 10, 512, 4)
-    elems, counts = jnp.unique(out, return_counts=True)
-    df = pd.DataFrame({'elems': elems, 'counts': counts})
-    fig = px.bar(df, x='elems', y='counts')
-    fig.show()
\ No newline at end of file
+    return inv_weights[arr - arr.min()]
\ No newline at end of file
diff --git a/ReAct/utils/logger.py b/ReAct/utils/logger.py
index 44ccb08..d3f2e45 100644
--- a/ReAct/utils/logger.py
+++ b/ReAct/utils/logger.py
@@ -40,16 +40,16 @@ def wandb_logger(self, args: dict):
         if args.resume:
             # args.resume is of the form: "neel/ReAct_Jax/lxxn0x54 + 20"
             # we want to extract the run id, i.e "lxxn0x54"
-            id = args.resume.split("+")[0].split("/")[-1].strip()
+            wandb_id = args.resume.split("+")[0].split("/")[-1].strip()
         else:
-            id = None
+            wandb_id = None

         wandb.init(project='ReAct_Jax',
                    config=args,
                    group=args.group,
                    mode='online' if jax.process_index() == 0 and args.exp_logging else 'offline',
                    resume='allow',
-                   id=id,
+                   id=wandb_id,
                    reinit=True)

         wandb.run.log_code(
diff --git a/ReAct/utils/trainer.py b/ReAct/utils/trainer.py
index 4499ee8..3630cac 100644
--- a/ReAct/utils/trainer.py
+++ b/ReAct/utils/trainer.py
@@ -1,13 +1,15 @@
 import os
 from functools import partial
-from typing import Callable, List, Tuple
+from typing import Any, Callable, List, Optional, Tuple

 import equinox as eqx
 import jax
 import jax.numpy as jnp
 import optax
+import optuna
 from jaxtyping import Array, PRNGKeyArray, PyTree
-from scalax.sharding import MeshShardingHelper, PartitionSpec as P
+from scalax.sharding import MeshShardingHelper
+from scalax.sharding import PartitionSpec as P
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

@@ -16,7 +18,7 @@
 from ReAct.model.react import React
 from ReAct.utils.helpers import count_params, load_eqx_obj, save_eqx_obj

-from .helpers import broad_to_bsz, half_precision
+from .helpers import broad_to_bsz, calc_performance_metrics, half_precision

 mesh = MeshShardingHelper(axis_dims=[-1], axis_names=['data']) # handle DDP + TP over multi-node

@@ -95,12 +97,12 @@ def make_step(model: eqx.Module,

     @eqx.filter_value_and_grad
     def compute_loss(model: eqx.Module, static_model: PyTree, x: Array, y: Array, pad_mask: Array,
-                     n: int, k: int, num_classes: int, keys: PRNGKeyArray = None) -> Tuple[int, PyTree]:
+                     n: int, k: int, num_classes: int, keys: PRNGKeyArray) -> int:
         '''
-        Computes the loss of the model w.r.t the input. Is a closure for accessing static_model
+        Computes the loss of the model w.r.t the input. Is a closure for accessing static_model
         '''
         model = eqx.combine(model, static_model)
-
+
         if model.__name__ == 'ReAct':
             forward = iters_fwd
         else:
@@ -114,7 +116,7 @@ def compute_loss(model: eqx.Module, static_model: PyTree, x: Array, y: Array, pa

     diff_model, static_model = eqx.partition(model, filter_spec,
                                              is_leaf=lambda x: isinstance(x, eqx.nn.Dropout))
-
+
     loss, grads = compute_loss(diff_model, static_model, x, y, pad_mask, n, k, num_classes, keys)
     updates, opt_state = optim.update(grads, opt_state, model)
     model = eqx.apply_updates(model, updates)
@@ -133,16 +135,18 @@ def __init__(self,
         self.args = args
         self.key = key

-        # unpacking the loaders & loggers
         self.my_logger, self.wandb_logger = logger
         self.trainloader, self.valloader = loaders
         self.dataset_length = len(self.trainloader) * args.batch_size * args.seqlen
+
+        self.text_table = wandb.Table(
+            columns=["Step", "Prompt", "Model Generation", "Type"]
+        )

-        self.my_logger.info(f'Using Args: {self.args}\n')
+        self.my_logger.info(f"Using Args: {self.args}\n")

         # Assign each arg as a class attribute
-        for k, v in vars(self.args).items():
-            setattr(self, k, v)
+        self.__dict__.update(vars(self.args))

     def get_n_k(self, key: PRNGKeyArray) -> Tuple[Array, Array]:
         n_key, k_key = jax.random.split(key, 2)
@@ -162,10 +166,8 @@ def evaluate_acc(self, model: eqx.Module, loader: DataLoader, eval_iters: int, k
         metric = []

         for step, batch in tqdm(enumerate(loader), total=len(loader), desc='Validating'):
-            seq, label, pad_mask = batch
-
+            seq, label, pad_mask = jnp.asarray(batch['text'])
             acc, loss, ppl = self.compute_metrics(model, seq, label, pad_mask, eval_iters, self.num_classes, keys)
-
             metric.extend([acc, loss, ppl])

         # Compute cumulatives
@@ -202,7 +204,7 @@ def get_filterspec(model: eqx.Module) -> PyTree[bool]:
         '''
         Returns a filter spec for the model to filter out the trainable parameters.
         Can be used to freeze or unfreeze certain modules of the model depending on the step and epoch.
-
+
         Args:
             model: The model to filter
         Returns:
@@ -213,9 +215,9 @@ def get_filterspec(model: eqx.Module) -> PyTree[bool]:
             lambda tree: tree.pos_enc, # pos_enc should be frozen
             filter_spec,
             replace=False)
-
+
         return filter_spec
-
+
     def init_model(self, key: PRNGKeyArray):

         if self.baseline:
@@ -228,9 +230,11 @@ def init_model(self, key: PRNGKeyArray):
         # switch to half precision
         if self.bf16:
             model = half_precision(model)
-
+
         _, opt_state, model = self.set_optim_and_scheduler(model)
+
         count_params(model) # prints to stdout
+        calc_performance_metrics(self.args, self.my_logger) # logs via logger

         return opt_state, model
@@ -291,10 +295,14 @@ def compute_metrics(self,

         return accuracy, loss, perplexity

-    def train(self):
+    def train(self, trial: Optional[Any] = None) -> Tuple[float, int]:
         step_done = 0
+
+        rndm_n, rndm_k = self.get_n_k(key=self.key) # initial n and k
+
         opt_state, model = self.init_model(self.key)
         optim, _, _ = self.set_optim_and_scheduler(model)
+        filter_spec = self.get_filterspec(model)

         if self.resume:
             model, opt_state, epoch_done = self.resume_training(model, opt_state)
@@ -302,10 +310,7 @@ def train(self):
             epoch_done = 0

         print(f'Model: {model}')
-
-        rndm_n, rndm_k = self.get_n_k(key=self.key) # initial n and k
-        filter_spec = self.get_filterspec(model)
-
+
         for epoch in range(epoch_done, self.epochs):
             # init empty metrics
             epoch_key = jnp.array([epoch, epoch + 1]).astype(jnp.uint32)
@@ -315,17 +320,16 @@ def train(self):

             for step, batch in tqdm(enumerate(self.trainloader), total=len(self.trainloader), desc=f'Epoch {epoch}'):
                 step += step_done # for multiple epochs
-
-                seq, label, pad_mask = batch
-
+
+                seq, label, pad_mask = jnp.asarray(batch['text'])
+
                 loss, model, opt_state = make_step(model, opt_state, filter_spec, seq, label, pad_mask,
                                                    rndm_n, rndm_k, optim, self.num_classes, keys)

-                if step % 75 == 0:
-                    # cycling through keys to get new n and k
+                if step % 100 == 0:
                     #rndm_n, rndm_k = self.get_n_k(key=keys[step % self.batch_size])

-                    accuracy, loss, perplexity = self.compute_metrics(model, seq, label, pad_mask,
+                    accuracy, loss, perplexity = self.compute_metrics(model, seq, label, pad_mask,
                                                                       self.max_iters, self.num_classes, keys)

                     train_acc.append(accuracy)
@@ -337,12 +341,15 @@ def train(self):
                     self.wandb_logger.log(
                         {
                             'Train/loss': loss,
-                            'Train/Lr': self.schedule_fn(epoch + 1 * step).item(),
+                            'Train/Lr': self.schedule_fn(epoch + 1 * step).item()
                         },
                         step=step
                     )
+
+                if trial is not None and (trial.should_prune() or jnp.isnan(loss)):
+                    raise optuna.exceptions.TrialPruned()

-                # Terminate if loss is NaN
+                # Terminate training if loss is NaN
                 if jnp.isnan(loss):
                     self.my_logger.warning(f'\nLoss is NaN at step {step}')
                     return loss
@@ -390,10 +397,13 @@ def train(self):
                     self.my_logger.info(f'Model saved at {filepath}')
                     self.wandb_logger.save(filepath)

-            print(f'Epoch {epoch} done!')
             step_done = step # prepare for next epoch
+            trial.report(loss, epoch) # report the loss for optuna
+
+            print(f'Epoch {epoch} done!')

-        self.wandb_logger.finish()
+        self.wandb_logger.finish() # Cleanup
+        return loss

     def generate(self, model: eqx.Module, input_arr: Array, metadata: dict, max_new_tokens: int, temperature: float = 0.5):
@@ -404,7 +414,6 @@ def generate(self, model: eqx.Module, input_arr: Array, metadata: dict, max_new_

         key = jax.random.PRNGKey(0)
         inference_model = eqx.nn.inference_mode(model)
-        text_table = wandb.Table(columns=["Step", "Prompt", "Model Generation", "Type"])

         prompt = f'Prompt: {self.decode_fn(input_arr)}'

         for _ in range(max_new_tokens):
@@ -435,7 +444,7 @@ def generate(self, model: eqx.Module, input_arr: Array, metadata: dict, max_new_
         self.my_logger.info(model_gen)

         # log to logger as a table
-        text_table.add_data(metadata['step'], prompt, model_gen, metadata['type'])
-        self.wandb_logger.log({'Generated Samples': text_table})
+        self.text_table.add_data(metadata["step"], prompt, model_gen, metadata["type"])
+        self.wandb_logger.log({"Generated Samples": self.text_table})

-        return input_arr
+        return input_arr
\ No newline at end of file
diff --git a/run.sh b/run.sh
index 53c771a..becf647 100644
--- a/run.sh
+++ b/run.sh
@@ -1,20 +1,25 @@
-a#!/bin/bash
+#!/bin/bash
 BRANCH="dev"
 IMAGE_NAME="docker.io/neel04/react_image:latest"
 CONTAINER_NAME="react_container"

 # arguments for train_model.py
-TRAIN_ARGS="--save_dir ./ReAct/outputs/ --epochs 4 --warmup_steps 300 \
---lr 5e-3 --num_blocks 4 \
---width 384 --batch_size 512 --n_heads 4 --max_iters 5 \
---weight_decay 3e-4 --drop_rate 0.0001 \
---log_interval 1000 --save_interval 1000 --seqlen 192 \
---bf16 --accum_steps 1 --exp_logging" #--tune_hyperparams"
+TRAIN_ARGS="--save_dir ./ReAct/outputs/ --dataset 'minipile' --group 'minipile' \
+--num_blocks 4 --width 384 --n_heads 8 --max_iters 5 --epochs 2 --num_classes 50304 \
+--log_interval 500 --save_interval 2000 --seqlen 512 \
+--bf16 --accum_steps 2 --batch_size 512 \
+--warmup_steps 500 --lr 4.5e-3 \
+--weight_decay 5e-4 --drop_rate 0.01 \
+--exp_logging"

 # Stop all running Docker containers
 echo "Stopping all running Docker containers..."
-sudo docker stop $(sudo docker ps -a -q)
-sudo docker rm -f $(sudo docker ps -a -q)
+
+if ! timeout 300 sudo docker rm -f $CONTAINER_NAME; then
+  echo "Command timed out. Restarting Docker daemon & retrying..."
+  sudo systemctl restart docker
+  sleep 10s; sudo docker rm -f $CONTAINER_NAME
+fi

 # Git stuff
 git clone -b $BRANCH https://github.com/neel04/ReAct_Jax.git
@@ -25,7 +30,7 @@ cd ReAct_Jax/; git pull --all; cd ..

 # Run the Docker container
 echo "Running Docker container..."
-docker run --pull 'always' -v $(pwd)/ReAct_Jax/:/ReAct_Jax/ -e EQX_ON_ERROR=nan --privileged --rm --net=host --name $CONTAINER_NAME -it -d $IMAGE_NAME
+docker run --pull 'always' -v $(pwd)/ReAct_Jax/:/ReAct_Jax/ -e EQX_ON_ERROR=nan -e PJRT_DEVICE=TPU -e XLA_USE_SPMD=1 --privileged --rm --net=host --name $CONTAINER_NAME -it -d $IMAGE_NAME

 # Get docker container ID to copy files
 CONTAINER_ID=$(docker ps -aqf "name=$CONTAINER_NAME")
diff --git a/train_model.py b/train_model.py
index af3a562..f93bd86 100644
--- a/train_model.py
+++ b/train_model.py
@@ -4,7 +4,6 @@
 if platform.processor() != 'arm':
     jax.distributed.initialize() # don't run on apple sillicon

-import jax.numpy as jnp
 import optuna
 from jax import config
 from jax.experimental.compilation_cache import compilation_cache
@@ -43,20 +42,32 @@ def main(key: PRNGKeyArray):
     val_dataset = dataset(split='test', max_length=args.seqlen, bsz=args.batch_size)

     # ========= Training/Hypertuning =========
+    init_hyperparams = [
+        {"lr": 1e-3, "drop_rate": 0.01, "weight_decay": 8e-4, "warmup_steps": 100},
+        {"lr": 7e-3, "drop_rate": 0.01, "weight_decay": 8e-4, "warmup_steps": 0},
+        {"lr": 2e-2, "drop_rate": 0.01, "weight_decay": 8e-4, "warmup_steps": 0},
+    ]

     if args.tune_hyperparams:
         args.group = 'Sweeps' if args.baseline else 'Sweeps_5i'
-
-        trainloader = train_dataset.create_dataloader('40%')
-        valloader = val_dataset.create_dataloader('40%')
-
-        study = optuna.create_study(direction='minimize',
-                                    study_name='ReAct_Jax',
-                                    load_if_exists=True,
-                                    sampler=optuna.samplers.TPESampler(
-                                        seed=69,
-                                        consider_magic_clip=True,
-                                    ))
+        trainloader = train_dataset.create_dataloader("40%")
+        valloader = val_dataset.create_dataloader("40%")
+
+        # Create the optuna hyperparameter-tuning study
+        study = optuna.create_study(
+            direction="minimize",
+            load_if_exists=True,
+            sampler=optuna.samplers.TPESampler(
+                seed=69,
+                consider_magic_clip=True,
+                consider_endpoints=True,
+                n_startup_trials=5,
+            ),
+            pruner=optuna.pruners.MedianPruner(
+                n_startup_trials=5, n_warmup_steps=200, n_min_trials=10
+            ),
+        )

         wandb_kwargs = {
             "project": "ReAct_Jax",
@@ -74,24 +85,40 @@ def main(key: PRNGKeyArray):
         }

         wandbc = WeightsAndBiasesCallback(
-            metric_name='Train/loss',
+            metric_name='Train/acc',
             wandb_kwargs=wandb_kwargs,
             as_multirun=True
         )

-        study.optimize(lambda trial: kickoff_optuna(trial=trial, **trainer_kwargs), n_trials=50, callbacks=[wandbc])
+        # enqueue a few handpicked hyperparams for trials
+        [study.enqueue_trial(hyperparams) for hyperparams in init_hyperparams]
+
+        study.optimize(
+            lambda trial: kickoff_optuna(trial=trial, **trainer_kwargs),
+            n_trials=50,
+            callbacks=[wandbc],
+        )
+
+        fig = optuna.visualization.plot_optimization_history(study)
+        fig.write_html("optuna_plot.html")
+
+        print(f"Best trial: {study.best_trial}")
+        print(f'\nValue: {study.best_trial.value}\nParams: {study.best_trial.params}\n')
     else:
         trainloader = train_dataset.create_dataloader()
         valloader = val_dataset.create_dataloader()

-    logger = UnifiedLogger(args, level='DEBUG')
+    logger = UnifiedLogger(args, level="DEBUG")
     my_logger, wandb_logger = logger.my_logger(), logger.wandb_logger(args)

-    trainer = Trainer(args, logger=(my_logger, wandb_logger),
-                      loaders=(trainloader, valloader),
-                      decode_fn=train_dataset.tok.decode,
-                      key=key)
+    trainer = Trainer(
+        args,
+        logger=(my_logger, wandb_logger),
+        loaders=(trainloader, valloader),
+        decode_fn=train_dataset.tok.decode,
+        key=key,
+    )

     my_logger.info(f"# of all devices: {jax.device_count()}")
     my_logger.info(f"# of hosts: {jax.process_count()}")
@@ -105,10 +132,10 @@ def kickoff_optuna(trial, **trainer_kwargs):

     args.epochs = 1

-    args.lr = trial.suggest_float('lr', 1e-4, 1e-2)
-    args.drop_rate = trial.suggest_float('drop_rate', 0.0, 0.2)
-    args.weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-3)
-    args.warmup_steps = trial.suggest_int('warmup_steps', 0, 500, step=50)
+    args.lr = trial.suggest_float('lr', 1e-4, 1e-2, step=1e-4)
+    args.drop_rate = trial.suggest_float('drop_rate', 0.0, 0.1, step=0.01)
+    args.weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-3, step=2e-4)
+    args.warmup_steps = trial.suggest_int('warmup_steps', 0, 500, step=100)

     args = trainer_kwargs['args']
@@ -124,9 +151,9 @@ def kickoff_optuna(trial, **trainer_kwargs):
     my_logger.info(f"Host id: {jax.process_index()}")

     with jax.spmd_mode('allow_all'):
-        loss = trainer.train()
-
-        return jnp.nan_to_num(loss, nan=9999.0) # return the loss
+        loss = trainer.train(trial)
+
+        return loss

 if __name__ == '__main__':
     compilation_cache.initialize_cache('./compilation_cache')
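
For orientation only (this is not part of the patch): the sketch below re-renders, in plain `jax.numpy`, the gated "pseudo-LSTM" recurrence that the `ReAct/model/react.py` hunks introduce — `RecurrentModule.__call__` gating its context with `forget_gate`/`ctx_gate`, and `React.iterate_for_steps` carrying a `(thought, ctx_state)` pair through a scan. The names `apply_blocks`, `w_reshape`, `w_forget`, and `w_ctx` are hypothetical stand-ins for the repo's vmapped `AttentionBlock` stack, `LinearProj`, and `MLP` modules.

```python
import jax
import jax.numpy as jnp

def apply_blocks(x, block_ws):
    # Stand-in for the AttentionBlock stack scanned over in RecurrentModule:
    # returns the final block output and the mean of per-block outputs ("history").
    def f(h, w):
        h = jax.nn.gelu(h @ w)
        return h, h
    out, history = jax.lax.scan(f, x, block_ws)
    return out, history.mean(axis=0)

def iterate(input_arr, thought, w_reshape, block_ws, w_forget, w_ctx, n_iters=5):
    # Mirrors React.iterate_for_steps: carry (thought, ctx_state), refine the thought
    # each iteration, and gate the context state with the blocks' mean output.
    def body(carry, _):
        thought, ctx_state = carry
        latent = jnp.concatenate([input_arr, thought], axis=-1)  # (seqlen, width * 2)
        latent = latent @ w_reshape                              # (seqlen, width)
        latent, history = apply_blocks(latent, block_ws)
        # Gated context update, analogous to forget_gate / ctx_gate in the diff:
        ctx_state = ctx_state * jax.nn.sigmoid(history @ w_forget) + history @ w_ctx
        return (latent, ctx_state), None
    (final_thought, _), _ = jax.lax.scan(body, (thought, input_arr), None, length=n_iters)
    return final_thought

if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    seqlen, width, n_blocks = 8, 16, 4
    ks = jax.random.split(key, 5)
    x = jax.random.normal(ks[0], (seqlen, width))
    w_reshape = 0.02 * jax.random.normal(ks[1], (2 * width, width))
    block_ws = 0.02 * jax.random.normal(ks[2], (n_blocks, width, width))
    w_forget = 0.02 * jax.random.normal(ks[3], (width, width))
    w_ctx = 0.02 * jax.random.normal(ks[4], (width, width))
    print(iterate(x, x, w_reshape, block_ws, w_forget, w_ctx).shape)  # (8, 16)
```

In the actual patch the gates act on `input_arr` in place (`input_arr *= sigmoid(forget_gate(...)); input_arr += ctx_gate(...)`) and the block stack is an Equinox PyTree driven by `eqx.internal.scan`, with the scan over iterations checkpointed for memory.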