Skip to content

Commit

Permalink
Revamped profiling logic
Browse files Browse the repository at this point in the history
  • Loading branch information
neel04 committed Nov 23, 2024
1 parent 880dcc7 commit 8067e74
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 15 deletions.
3 changes: 2 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ RUN pip3 install numpy pandas scipy

# NOTE(review): GitHub diff rendering — removed (pre-commit) and added
# (post-commit) lines appear back to back with the +/- markers stripped.
RUN pip3 install -U -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN pip3 install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich
# removed: tensorboard-plugin-profile used to ride along on this general tooling line
RUN pip3 install -U tensorboard-plugin-profile optuna-integration plotly lm-eval pdbpp
# added: same line minus the profiler plugin (moved to the dedicated line below)
RUN pip3 install -U optuna-integration plotly lm-eval pdbpp
RUN pip3 install git+https://github.com/deepmind/jmp
RUN pip3 install git+https://github.com/Findus23/jax-array-info.git
# added: explicit TensorBoard profiling stack — tensorflow is required for the
# profile plugin to serve traces; cloud-tpu-profiler presumably for TPU capture
# (matches the same line added to run.sh in this commit).
RUN pip3 install -q tensorflow tensorboard-plugin-profile "cloud-tpu-profiler>=2.3.0"

WORKDIR /ReAct_Jax

Expand Down
25 changes: 14 additions & 11 deletions ReAct/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,26 @@

# NOTE(review): GitHub diff rendering of ReAct/utils/helpers.py — old and new
# versions of each method appear back to back with +/- markers stripped, so
# this region is NOT valid Python as pasted. Comments below tag which lines
# are removed (pre-commit) vs added (post-commit).
# Thin wrapper around jax.profiler for capturing a TensorBoard trace of the
# training loop.
class Profiler:
def __init__(
# removed: profiling used to be opt-in (default False)
self, activate_profiler: bool = False, logdir: str = "./profiles/"
# added: profiling now defaults to ON — NOTE(review): confirm this flip is
# intentional and not a leftover debugging change; tracing every run has a cost.
self, activate_profiler: bool = True, logdir: str = "./profiles/"
) -> None:
# added: steps to skip before tracing, presumably so JIT compilation /
# warmup noise is excluded from the trace — TODO confirm
self.warmup_steps = 20
self.activate_profiler = activate_profiler
self.logdir = logdir

# removed: old start_prof started the trace on first call, unconditionally
# of the training step.
def start_prof(self) -> None:
if self.activate_profiler:
print(f'Started TensorBoard Profiler at: {self.logdir}')
jax.profiler.start_trace(self.logdir)
# added: new start_prof is called every step (see trainer.py hunk) and only
# begins tracing exactly when step == warmup_steps. Since stop_prof fires on
# the same step, the captured window is a single training step —
# NOTE(review): verify a one-step trace is the intended capture size.
def start_prof(self, step: int) -> None:
if step == self.warmup_steps:
if self.activate_profiler:
print(f'Started TensorBoard Profiler at: {self.logdir}')
jax.profiler.start_trace(self.logdir)

# removed: old stop_prof ended the trace on the first profiled step and then
# killed the whole process via exit() (see removed line near the bottom).
def stop_prof(self, output: Any) -> None:
if self.activate_profiler:
output = output.block_until_ready() # wait for output
jax.profiler.stop_trace()
print(f'Stopped Profiler at: {self.logdir}')
# added: new stop_prof ends the trace at the same warmup step, then disables
# the profiler so later steps run untraced and training continues.
def stop_prof(self, output: Any, step: int) -> None:
if step == self.warmup_steps:
if self.activate_profiler:
# block_until_ready forces the async-dispatched computation to
# finish before the trace is closed, so the step is fully captured.
output = output.block_until_ready() # wait for output
jax.profiler.stop_trace()
print(f'Stopped Profiler at: {self.logdir}')

# removed: hard process exit after profiling — this is the core behavior
# change of the commit.
exit()
# added: deactivate instead of exiting; note this runs even when
# activate_profiler was already False (harmless, but outside the inner if).
self.activate_profiler = False

def convert_flops(params: int) -> str:
if params == 0:
Expand Down
5 changes: 2 additions & 3 deletions ReAct/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,9 @@ def train(self, trial: Optional[Any] = None) -> float:
epoch_key = jnp.array([epoch, epoch + 1]).astype(jnp.uint32)
keys = jax.random.split(epoch_key, self.args.batch_size)

prof.start_prof()

for step, batch in tqdm(enumerate(self.trainloader), total=self.dataset_length, desc=f'Epoch {epoch}'):
step += step_done # for multiple epochs
prof.start_prof(step)

seq, label, pad_mask = jnp.asarray(batch["text"])
seq, label, pad_mask = strategy.shard_cast((seq, label, pad_mask))
Expand All @@ -415,7 +414,7 @@ def train(self, trial: Optional[Any] = None) -> float:
num_classes=self.args.num_classes,
)

prof.stop_prof(loss) # end trace if profiled
prof.stop_prof(loss, step) # end trace if profiled

if step % 100 == 0:
accuracy, loss, perplexity = self.compute_metrics(keys, model, self.args.baseline, seq, label, pad_mask, self.args.max_iters, self.args.num_classes)
Expand Down
1 change: 1 addition & 0 deletions run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ if [ ! -f "$FLAG_FILE" ]; then
uv pip install -q optuna-integration wandb lm-eval nvitop pdbpp
uv pip install -q git+https://github.com/deepmind/jmp
uv pip install -q git+https://github.com/Findus23/jax-array-info.git
uv pip install -q tensorflow tensorboard-plugin-profile "cloud-tpu-profiler>=2.3.0"
# ------------------
# Create the flag file
touch "$FLAG_FILE"
Expand Down

0 comments on commit 8067e74

Please sign in to comment.