FLOPs calculation + storing dataset on CPU during recomputation
neel04 committed Apr 13, 2024
1 parent 9f67154 commit 61f5c4c
Showing 6 changed files with 46 additions and 49 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -166,4 +166,5 @@ cython_debug/
 **wandb
 *.wandb
 *.eqx
-*cached_data
+*cached_data
+*pyrightconfig.json
4 changes: 1 addition & 3 deletions README.md
@@ -59,7 +59,5 @@ gcloud compute tpus tpu-vm ssh node-v4 \
 If you get errors regarding workers not being able to sync up at the distributed barrier, do:
 
 ```bash
-gcloud compute tpus tpu-vm ssh --zone "us-central2-b" "ondem" \
-    --project "react-jax" \
-    --command 'sudo docker system prune -f && sudo rm -rf ~/.cache;'
+gcloud compute tpus tpu-vm ssh --zone "us-central2-b" "ondem" --worker 'all' --project "react-jax" --command 'sudo docker system prune -f && sudo rm -rf ~/.cache;'
 ```
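(The consolidated command also adds `--worker 'all'`, so the cache cleanup runs on every host in the TPU pod instead of only the worker you SSH into.)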
15 changes: 5 additions & 10 deletions ReAct/data/minipile.py
@@ -15,6 +15,7 @@ class MiniPileDataset:
     def __init__(self, split: str = 'train', max_length: int = 512, bsz: int = 256, vocab_dir: str ='./ReAct/data'):
         datasets.config.IN_MEMORY_MAX_SIZE = 1e+11
 
+        self.cpus = jax.devices("cpu")
         self.bsz = bsz
         self.max_length = max_length + 1
         self.split = split
@@ -93,12 +94,6 @@ def take_subset(self, dataset, elements: int) -> None:
 
         return dataset
 
-    def numpify(self, dataset: datasets.Dataset) -> datasets.Dataset:
-        '''
-        Convert the dataset to numpy arrays
-        '''
-        return jax.tree_map(lambda x: jnp.asarray(x), dataset['text'])
-
     def create_dataloader(self, slice: str = '100%'):
         data_path = Path(f'./cached_data/minipile_{self.split}.data')
 
@@ -108,9 +103,9 @@ def create_dataloader(self, slice: str = '100%'):
 
             print(f'Loaded {self.split} dataset from HuggingFace Hub')
 
-            dataset.set_format(type='numpy')
+            dataset.set_format(type='jax')
 
-            return self.numpify(dataset)
+            return dataset
 
         except (FileNotFoundError, ValueError):
             if os.path.exists(data_path):
@@ -134,9 +129,9 @@ def create_dataloader(self, slice: str = '100%'):
             dataset = dataset.map(self.shift_tokens, batched=True, batch_size=self.bsz,
                                   keep_in_memory=True, drop_last_batch=True, num_proc=None)
 
-            dataset.set_format(type='numpy')
+            dataset.set_format(type='jax')
 
             self.upload_dataset(dataset,
                                 hub_path=f'Neel-Gupta/minipile-processed_{self.bsz}') # upload the processed dataset to the Hub
 
-            return self.numpify(dataset)
+            return dataset
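Context for this change: `set_format(type='jax')` makes the dataset return `jax.numpy` arrays directly, removing the need for the old `numpify` pass, while the new `self.cpus = jax.devices("cpu")` handle suggests batches are meant to stay in host memory while the dataset is recomputed. A minimal sketch of that host-pinning pattern (the `to_host` helper is hypothetical, not this repo's code):

```python
import jax
import jax.numpy as jnp

cpu = jax.devices('cpu')[0]  # host-side device; always available, even on TPU backends

def to_host(batch):
    # Hypothetical helper: pin every array in the batch to host RAM so that
    # accelerator HBM is not consumed while the dataset is being rebuilt.
    return jax.tree_util.tree_map(lambda x: jax.device_put(jnp.asarray(x), cpu), batch)
```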
26 changes: 13 additions & 13 deletions ReAct/utils/helpers.py
@@ -5,7 +5,18 @@
 
 from jax import tree_util as jtu
 from jaxtyping import Array, PRNGKeyArray
-from typing import Optional
+from typing import Optional, Callable
+
+def calc_performance_metrics(fn: Callable, static_argnums: tuple[int, ...], args: tuple) -> float:
+    '''
+    Calculate the FLOPs cost of a function using AOT compilation
+    and XLA's cost model. Returns the cost in PetaFLOPs.
+    '''
+    compiled = jax.jit(fn, static_argnums=static_argnums).lower(*args).compile()
+    cost_anal = compiled.cost_analysis()
+
+    return cost_anal[0]['flops'] / 1e15
 
 def half_precision(model: eqx.Module) -> eqx.Module:
     return jtu.tree_map(lambda x: x.astype(jnp.bfloat16) if eqx.is_inexact_array(x) else x, model)
@@ -59,15 +70,4 @@ def inverted_freq(arr: Array):
 
     inv_weights = (counts.max() / counts) # scale it down
 
-    return inv_weights[arr - arr.min()]
-
-if __name__ == '__main__':
-    import plotly.express as px
-    import pandas as pd
-
-    key = jax.random.PRNGKey(0)
-    out: Array = get_rand_nums(key, 1, 10, 512, 4)
-    elems, counts = jnp.unique(out, return_counts=True)
-    df = pd.DataFrame({'elems': elems, 'counts': counts})
-    fig = px.bar(df, x='elems', y='counts')
-    fig.show()
+    return inv_weights[arr - arr.min()]
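For context on the new `calc_performance_metrics`: JAX's ahead-of-time path (`jit(...).lower(...).compile()`) exposes XLA's cost model without ever executing the function. A self-contained sketch of the same mechanism on a toy matmul (illustrative only; note that `cost_analysis()` returns a one-element list on some JAX versions and a plain dict on others):

```python
import jax
import jax.numpy as jnp

def matmul(a, b):
    return a @ b

a = jnp.ones((1024, 1024))
b = jnp.ones((1024, 1024))

# Lower and compile without running; XLA attaches its cost model to the artifact.
compiled = jax.jit(matmul).lower(a, b).compile()
cost = compiled.cost_analysis()

# Normalize across JAX versions: list-of-dicts vs. a single dict.
flops = (cost[0] if isinstance(cost, (list, tuple)) else cost)['flops']
print(f'{flops / 1e9:.2f} GFLOPs')  # ~2 * 1024**3 ≈ 2.15 GFLOPs for this matmul
```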
45 changes: 24 additions & 21 deletions ReAct/utils/trainer.py
@@ -7,7 +7,8 @@
 import jax.numpy as jnp
 import optax
 from jaxtyping import Array, PRNGKeyArray, PyTree
-from scalax.sharding import MeshShardingHelper, PartitionSpec as P
+from scalax.sharding import MeshShardingHelper
+from scalax.sharding import PartitionSpec as P
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
@@ -16,7 +17,7 @@
 from ReAct.model.react import React
 from ReAct.utils.helpers import count_params, load_eqx_obj, save_eqx_obj
 
-from .helpers import broad_to_bsz, half_precision
+from .helpers import broad_to_bsz, calc_performance_metrics, half_precision
 
 mesh = MeshShardingHelper(axis_dims=[-1], axis_names=['data']) # handle DDP + TP over multi-node
 
@@ -95,12 +96,12 @@ def make_step(model: eqx.Module,
 
 @eqx.filter_value_and_grad
 def compute_loss(model: eqx.Module, static_model: PyTree, x: Array, y: Array, pad_mask: Array,
-                 n: int, k: int, num_classes: int, keys: PRNGKeyArray = None) -> Tuple[int, PyTree]:
+                 n: int, k: int, num_classes: int, keys: PRNGKeyArray) -> int:
     '''
     Computes the loss of the model w.r.t the input. Is a closure for accessing static_model
     '''
     model = eqx.combine(model, static_model)
 
     if model.__name__ == 'ReAct':
         forward = iters_fwd
     else:
@@ -114,7 +115,7 @@ def make_step(model: eqx.Module,
 
     diff_model, static_model = eqx.partition(model, filter_spec,
                                              is_leaf=lambda x: isinstance(x, eqx.nn.Dropout))
 
     loss, grads = compute_loss(diff_model, static_model, x, y, pad_mask, n, k, num_classes, keys)
     updates, opt_state = optim.update(grads, opt_state, model)
     model = eqx.apply_updates(model, updates)
@@ -137,7 +138,7 @@ def __init__(self,
         self.my_logger, self.wandb_logger = logger
         self.trainloader, self.valloader = loaders
         self.dataset_length = len(self.trainloader) * args.batch_size * args.seqlen
 
         self.my_logger.info(f'Using Args: {self.args}\n')
 
         # Assign each arg as a class attribute
@@ -162,10 +163,8 @@ def evaluate_acc(self, model: eqx.Module, loader: DataLoader, eval_iters: int, k
         metric = []
 
         for step, batch in tqdm(enumerate(loader), total=len(loader), desc='Validating'):
-            seq, label, pad_mask = batch
-
+            seq, label, pad_mask = batch['text']
             acc, loss, ppl = self.compute_metrics(model, seq, label, pad_mask, eval_iters, self.num_classes, keys)
-
             metric.extend([acc, loss, ppl])
 
         # Compute cumulatives
@@ -202,7 +201,7 @@ def get_filterspec(model: eqx.Module) -> PyTree[bool]:
         '''
         Returns a filter spec for the model to filter out the trainable parameters.
         Can be used to freeze or unfreeze certain modules of the model depending on the step and epoch.
 
         Args:
             model: The model to filter
         Returns:
@@ -213,9 +212,9 @@ def get_filterspec(model: eqx.Module) -> PyTree[bool]:
                                    lambda tree: tree.pos_enc, # pos_enc should be frozen
                                    filter_spec,
                                    replace=False)
 
         return filter_spec
 
     def init_model(self, key: PRNGKeyArray):
 
         if self.baseline:
@@ -228,7 +227,7 @@ def init_model(self, key: PRNGKeyArray):
         # switch to half precision
         if self.bf16:
             model = half_precision(model)
 
         _, opt_state, model = self.set_optim_and_scheduler(model)
         count_params(model) # prints to stdout
 
@@ -315,17 +314,20 @@ def train(self):
 
            for step, batch in tqdm(enumerate(self.trainloader), total=len(self.trainloader), desc=f'Epoch {epoch}'):
                step += step_done # for multiple epochs
-               seq, label, pad_mask = batch
-
+               seq, label, pad_mask = batch['text']
 
                loss, model, opt_state = make_step(model, opt_state, filter_spec, seq, label, pad_mask,
                                                   rndm_n, rndm_k, optim, self.num_classes, keys)
 
-               if step % 75 == 0:
-                   # cycling through keys to get new n and k
+               if step % 100 == 0:
                    #rndm_n, rndm_k = self.get_n_k(key=keys[step % self.batch_size])
 
-                   accuracy, loss, perplexity = self.compute_metrics(model, seq, label, pad_mask,
+                   pflops_consumed = calc_performance_metrics(make_step,
+                                                              static_argnums=(2, 8, 9),
+                                                              args=(model, opt_state, filter_spec, seq, label,
+                                                                    pad_mask, rndm_n, rndm_k, optim, self.num_classes, keys))
+
+                   accuracy, loss, perplexity = self.compute_metrics(model, seq, label, pad_mask,
                                                                      self.max_iters, self.num_classes, keys)
 
                    train_acc.append(accuracy)
@@ -338,6 +340,7 @@ def train(self):
                        {
                            'Train/loss': loss,
                            'Train/Lr': self.schedule_fn(epoch + 1 * step).item(),
+                           'Metrics/Step_PFLOPs': pflops_consumed,
                        },
                        step=step
                    )
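One detail worth noting in the new call: `static_argnums=(2, 8, 9)` marks `filter_spec`, `optim`, and `num_classes` — the non-array Python objects in `make_step`'s positional signature — as static, since `lower()` can only trace array-like arguments. A toy sketch of the same pattern (names are illustrative, not from this repo):

```python
import jax
import jax.numpy as jnp

def scaled_update(params, lr, repeats):
    # `repeats` drives a Python loop, so it must be marked static for tracing.
    for _ in range(repeats):
        params = jax.tree_util.tree_map(lambda p: p - lr * p, params)
    return params

params = {'w': jnp.ones((8, 8))}
compiled = jax.jit(scaled_update, static_argnums=(2,)).lower(params, 1e-3, 3).compile()
print(compiled.cost_analysis())  # per-call FLOPs, bytes accessed, etc.
```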
2 changes: 1 addition & 1 deletion run.sh
@@ -1,4 +1,4 @@
-a#!/bin/bash
+#!/bin/bash
 BRANCH="dev"
 IMAGE_NAME="docker.io/neel04/react_image:latest"
 CONTAINER_NAME="react_container"
