Commit
Attempt #2 at fixing the Optuna bug
neel04 committed May 13, 2024
1 parent de0bad3 commit 2cf34f8
Showing 3 changed files with 16 additions and 6 deletions.
Dockerfile (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@ RUN pip3 install numpy pandas scipy

 RUN pip3 install -U -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
 RUN pip3 install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich
-RUN pip3 install tensorboard-plugin-profile comet-ml optuna-integration plotly
+RUN pip3 install -U tensorboard-plugin-profile optuna-integration plotly
 RUN pip3 install git+https://github.com/deepmind/jmp

 WORKDIR /ReAct_Jax
ReAct/utils/trainer.py (16 changes: 13 additions & 3 deletions)
@@ -304,6 +304,15 @@ def compute_metrics(self,
         perplexity = jnp.exp(loss)

         return accuracy, loss, perplexity
+
+    def optuna_log(self, trial: Optional[Any], metrics: Tuple[float, int]):
+        '''
+        Logs the metrics to the optuna trial
+        '''
+        loss, epoch = metrics
+
+        if trial is not None:
+            trial.report(loss, epoch)

     def train(self, trial: Optional[Any] = None) -> Tuple[float, int]:
         step_done = 0
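
For context on what the new helper feeds into (this note is not part of the diff): trial.report records an intermediate objective value keyed by a step index, here the epoch, and those values are what Optuna's pruners later inspect. A minimal, self-contained sketch of that mechanism with a toy objective; the values and search behavior are purely illustrative, not taken from this repo:

    import optuna

    def toy_objective(trial: optuna.Trial) -> float:
        loss = 1.0
        for epoch in range(5):
            loss *= 0.5                  # stand-in for one real training epoch
            trial.report(loss, epoch)    # the call optuna_log wraps
        return loss

    study = optuna.create_study(direction='minimize')
    study.optimize(toy_objective, n_trials=1)
    print(study.trials[0].intermediate_values)  # {0: 0.5, 1: 0.25, ...}
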
@@ -354,8 +363,9 @@ def train(self, trial: Optional[Any] = None) -> Tuple[float, int]:
                 step=step
             )

-            #if trial is not None and (trial.should_prune() or jnp.isnan(loss)):
-                #raise optuna.exceptions.TrialPruned()
+            if trial is not None and trial.should_prune():
+                self.optuna_log(trial, (loss, epoch))
+                raise optuna.exceptions.TrialPruned()

             if jnp.isnan(loss):
                 self.my_logger.warning(f'\nLoss is NaN at step {step}')
@@ -405,7 +415,7 @@ def train(self, trial: Optional[Any] = None) -> Tuple[float, int]:
                 self.wandb_logger.save(filepath)

             step_done = step # prepare for next epoch
-            trial.report(loss, epoch) if trial is not None else ...
+            self.optuna_log(trial, (loss, epoch))

             print(f'Epoch {epoch} done!')

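The uncommented block above follows Optuna's standard pruning handshake: report the latest value, ask the pruner whether the trial still looks promising, and raise TrialPruned so the study records it as pruned rather than failed. A rough sketch of that handshake in isolation; the search space and the loss update are placeholders, not this repo's training loop:

    import optuna

    def objective(trial: optuna.Trial) -> float:
        lr = trial.suggest_float('lr', 1e-4, 1e-2)     # illustrative search space
        loss = 1.0
        for epoch in range(20):
            loss *= (1.0 - lr)                         # placeholder for one training epoch
            trial.report(loss, epoch)                  # give the pruner the latest value
            if trial.should_prune():                   # compares this trial against the others
                raise optuna.exceptions.TrialPruned()  # ends the trial as pruned, not failed
        return loss

Such an objective is then handed to study.optimize; pruned trials stop early instead of spending a full training budget.
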
train_model.py (4 changes: 2 additions & 2 deletions)
@@ -134,7 +134,7 @@ def kickoff_optuna(trial, **trainer_kwargs):
     # Regularization hyperparams
     args.lr = trial.suggest_float('lr', 1e-4, 1e-2)
     args.drop_rate = trial.suggest_float('drop_rate', 0.0, 0.1, step=0.01)
-    args.weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-3, step=2e-4)
+    args.weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-3)
     args.warmup_steps = trial.suggest_int('warmup_steps', 0, 500, step=100)

     # Optimizer hyperparams
@@ -158,7 +158,7 @@ def kickoff_optuna(trial, **trainer_kwargs):
     with jax.spmd_mode('allow_all'):
         loss = trainer.train(trial)

-    return loss
+    return jax.numpy.nan_to_num(loss, nan=9e9)

 if __name__ == '__main__':
     key = jax.random.PRNGKey(69)
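Mapping NaN to a large finite penalty matters because Optuna regards an objective that returns NaN as a failed trial; nan_to_num(loss, nan=9e9) instead keeps a diverged run in the study as a very bad but still comparable result. A rough sketch of the study side that would drive kickoff_optuna; the pruner choice, trial count, and empty trainer_kwargs are assumptions, not taken from this repo:

    import functools
    import optuna

    # Assumes kickoff_optuna from train_model.py is in scope
    # (e.g. this sketch lives in that file's __main__ block).
    trainer_kwargs = {}  # stand-in; the real script assembles args, loaders, etc.
    objective = functools.partial(kickoff_optuna, **trainer_kwargs)

    study = optuna.create_study(
        direction='minimize',
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=1),  # assumed pruner choice
    )
    study.optimize(objective, n_trials=25)  # trial count is illustrative
    print(study.best_params)
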
