
Commit

Added EAI FLOPs calculation
neel04 committed Apr 14, 2024
1 parent 9f28c97 commit e7f1f96
Showing 2 changed files with 41 additions and 16 deletions.
ReAct/utils/helpers.py: 37 additions & 9 deletions
@@ -1,22 +1,50 @@
+import math
+import os
+from typing import Callable, Optional
+
 import equinox as eqx
 import jax
 import jax.numpy as jnp
 
 from jax import tree_util as jtu
 from jaxtyping import Array, PRNGKeyArray
-from typing import Optional, Callable
 
-def calc_performance_metrics(fn: Callable, static_argnums: tuple[int], args: tuple[int]) -> float:
+def convert_flops(params: int) -> str:
+    if params == 0:
+        return "0"
+
+    size_name = ("", "KFLOPs", "MFLOPs", "GFLOPs", "TFLOPs", "PFLOPs", "EFLOPs", "ZFLOPs", "YFLOPs")
+    i = int(math.floor(math.log(params, 1000)))
+    p = math.pow(1000, i)
+    s = round(params / p, 2)
+
+    return "%s %s" % (s, size_name[i])
+
+def calc_performance_metrics(args, my_logger: Callable) -> None:
     '''
-    Calculate the number of FLOPs and memory requirements
-    for a given function using AOT compilation.
-    Returns the number of FLOPs in PetaFLOPs
+    Estimates FLOPs consumed during a single fwd + bwd pass.
+    Taken from EleutherAI's GPT-NeoX repo: https://rb.gy/33d6zg
+    Returns: the total number of FLOPs
     '''
-    compiled = jax.jit(fn, static_argnums=static_argnums).lower(*args).compile()
-    cost_anal = compiled.cost_analysis()
+    iter_factor = 3
+    args.tokens = args.batch_size * args.seqlen
+    args.kv_size_ratio = 1
+
+    # TODO: Ignores activation checkpointing. Fix this at some point
+    my_logger.warning('! Ignoring activation checkpointing in FLOPs calculation !')
+
+    qkv_flops = int(iter_factor * 2 * (1 + 2 * args.kv_size_ratio) * args.num_classes * args.tokens * args.width * args.width)
+    attention_matrix_flops = iter_factor * 2 * args.num_classes * args.tokens * args.seqlen * args.width
+    attention_over_values_flops = iter_factor * 2 * args.num_classes * args.tokens * args.seqlen * args.width
+    linear_projection_flops = iter_factor * 2 * args.num_classes * args.tokens * args.width * args.width
+    ffn_flops = iter_factor * 16 * args.num_classes * args.tokens * args.width * args.width
+
+    # handle NewGELU
+    ffn_flops *= 3.75
 
-    return cost_anal[0]['flops'] / 1e15
+    embedding_flops = 6 * args.tokens * args.width * args.num_classes
+    total_flops = qkv_flops + attention_matrix_flops + attention_over_values_flops + linear_projection_flops + ffn_flops + embedding_flops
+    my_logger.info(f"Total FLOPs for the Model: {convert_flops(total_flops)} for a single fwd + bwd pass\n")
 
 def half_precision(model: eqx.Module) -> eqx.Module:
     return jtu.tree_map(lambda x: x.astype(jnp.bfloat16) if eqx.is_inexact_array(x) else x, model)
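
For a quick sanity check of the new helpers outside the trainer, here is a minimal usage sketch (not part of the commit): the hyperparameter values, the SimpleNamespace stand-in for the argument object, and the logging setup are all made up, and it assumes the ReAct package and its JAX/Equinox dependencies are importable.

# Usage sketch with hypothetical values; exercises the analytic FLOPs estimate above.
import logging
from types import SimpleNamespace

from ReAct.utils.helpers import calc_performance_metrics, convert_flops

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("flops_demo")

# Made-up hyperparameters; the function reads batch_size, seqlen, num_classes and width,
# and writes tokens and kv_size_ratio back onto the same object.
args = SimpleNamespace(batch_size=256, seqlen=512, num_classes=2, width=384)

calc_performance_metrics(args, logger)   # logs "Total FLOPs for the Model: ... for a single fwd + bwd pass"

# convert_flops also works directly on any raw FLOP count:
print(convert_flops(1_500_000_000_000))  # "1.5 TFLOPs"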
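The deleted code path measured FLOPs empirically, by AOT-compiling the jitted step function and reading XLA's cost analysis, whereas the new version estimates them analytically from the hyperparameters (with iter_factor = 3 standing in for the usual forward-plus-backward cost). A minimal sketch of that removed approach on a toy function, under the assumption that a stand-in matmul replaces make_step and that cost_analysis() may return either a dict or, on older JAX releases, a single-element list of dicts:

# Sketch of the removed XLA cost-analysis approach, on a stand-in matmul.
import jax
import jax.numpy as jnp

def forward(a, b):
    return a @ b

a = jnp.ones((512, 512))
b = jnp.ones((512, 512))

compiled = jax.jit(forward).lower(a, b).compile()
cost = compiled.cost_analysis()

# Normalize across JAX versions: dict on newer releases, [dict] on older ones.
cost = cost[0] if isinstance(cost, (list, tuple)) else cost
print(f"XLA-estimated FLOPs for one call: {cost['flops']:.3e}")  # ~2 * 512**3 for this matmul

Note that this only counts whatever computation is traced into the jitted function, so a forward-only trace does not include backward-pass FLOPs, while the analytic estimate always assumes fwd + bwd.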
ReAct/utils/trainer.py: 4 additions & 7 deletions
@@ -229,7 +229,9 @@ def init_model(self, key: PRNGKeyArray):
         model = half_precision(model)
 
         _, opt_state, model = self.set_optim_and_scheduler(model)
 
+        count_params(model) # prints to stdout
+        calc_performance_metrics(self.args, self.my_logger) # logs via logger
 
         return opt_state, model
 
@@ -322,11 +324,7 @@ def train(self):
 
             if step % 100 == 0:
                 #rndm_n, rndm_k = self.get_n_k(key=keys[step % self.batch_size])
-                pflops_consumed = calc_performance_metrics(make_step,
-                                                           static_argnums=(2, 8, 9),
-                                                           args=(model, opt_state, filter_spec, seq, label,
-                                                                 pad_mask, rndm_n, rndm_k, optim, self.num_classes, keys))
-
+
                 accuracy, loss, perplexity = self.compute_metrics(model, seq, label, pad_mask,
                                                                   self.max_iters, self.num_classes, keys)
 
@@ -339,8 +337,7 @@
                 self.wandb_logger.log(
                     {
                         'Train/loss': loss,
-                        'Train/Lr': self.schedule_fn(epoch + 1 * step).item(),
-                        'Metrics/Step_PFLOPs': pflops_consumed,
+                        'Train/Lr': self.schedule_fn(epoch + 1 * step).item()
                     },
                     step=step
                 )
