diff --git a/Dockerfile b/Dockerfile
index a816dd2..8206f8d 100755
--- a/Dockerfile
+++ b/Dockerfile
@@ -19,7 +19,7 @@
 RUN pip3 install numpy pandas scipy
 RUN pip3 install -U -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
 RUN pip3 install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich
-RUN pip3 install tensorboard-plugin-profile comet-ml optuna-integration plotly
+RUN pip3 install -U tensorboard-plugin-profile optuna-integration plotly
 RUN pip3 install git+https://github.com/deepmind/jmp
 
 WORKDIR /ReAct_Jax
diff --git a/ReAct/utils/trainer.py b/ReAct/utils/trainer.py
index 74746af..95bd1f6 100644
--- a/ReAct/utils/trainer.py
+++ b/ReAct/utils/trainer.py
@@ -304,6 +304,15 @@ def compute_metrics(self,
         perplexity = jnp.exp(loss)
 
         return accuracy, loss, perplexity
+
+    def optuna_log(self, trial: Optional[Any], metrics: Tuple[float, int]):
+        '''
+        Logs the metrics to the optuna trial
+        '''
+        loss, epoch = metrics
+
+        if trial is not None:
+            trial.report(loss, epoch)
 
     def train(self, trial: Optional[Any] = None) -> Tuple[float, int]:
         step_done = 0
@@ -354,8 +363,9 @@ def train(self, trial: Optional[Any] = None) -> Tuple[float, int]:
                     step=step
                 )
 
-                #if trial is not None and (trial.should_prune() or jnp.isnan(loss)):
-                    #raise optuna.exceptions.TrialPruned()
+                if trial is not None and trial.should_prune():
+                    self.optuna_log(trial, (loss, epoch))
+                    raise optuna.exceptions.TrialPruned()
 
                 if jnp.isnan(loss):
                     self.my_logger.warning(f'\nLoss is NaN at step {step}')
@@ -405,7 +415,7 @@ def train(self, trial: Optional[Any] = None) -> Tuple[float, int]:
                 self.wandb_logger.save(filepath)
 
             step_done = step # prepare for next epoch
 
-            trial.report(loss, epoch) if trial is not None else ...
+            self.optuna_log(trial, (loss, epoch))
 
             print(f'Epoch {epoch} done!')
 
diff --git a/train_model.py b/train_model.py
index 643dd46..814aefa 100644
--- a/train_model.py
+++ b/train_model.py
@@ -134,7 +134,7 @@ def kickoff_optuna(trial, **trainer_kwargs):
     # Regularization hyperparams
     args.lr = trial.suggest_float('lr', 1e-4, 1e-2)
     args.drop_rate = trial.suggest_float('drop_rate', 0.0, 0.1, step=0.01)
-    args.weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-3, step=2e-4)
+    args.weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-3)
     args.warmup_steps = trial.suggest_int('warmup_steps', 0, 500, step=100)
 
     # Optimizer hyperparams
@@ -158,7 +158,7 @@ def kickoff_optuna(trial, **trainer_kwargs):
     with jax.spmd_mode('allow_all'):
         loss = trainer.train(trial)
 
-    return loss
+    return jax.numpy.nan_to_num(loss, nan=9e9)
 
 if __name__ == '__main__':
     key = jax.random.PRNGKey(69)
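
For context, the changes above hook the trainer into Optuna's report / should_prune / TrialPruned pruning loop and map NaN losses to a large finite penalty before handing them back to the study. The sketch below is not part of the PR: it is a minimal, self-contained objective showing the canonical form of that loop, where the search space, epoch loop, pruner choice, and 9e9 penalty are illustrative placeholders rather than the ReAct_Jax trainer.

import math

import optuna


def objective(trial: optuna.Trial) -> float:
    # Placeholder search space, loosely mirroring kickoff_optuna.
    lr = trial.suggest_float('lr', 1e-4, 1e-2)

    loss = 1.0
    for epoch in range(10):
        loss *= 0.9 + lr  # stand-in for one training epoch

        # Same idea as Trainer.optuna_log: report this epoch's loss, then let
        # the pruner decide whether the trial should stop early.
        trial.report(loss, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    # Diverged runs return a large finite penalty instead of NaN, since Optuna
    # cannot rank NaN objective values (the role of jax.numpy.nan_to_num above).
    return 9e9 if math.isnan(loss) else loss


study = optuna.create_study(direction='minimize', pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)

Returning jax.numpy.nan_to_num(loss, nan=9e9) from kickoff_optuna serves the same purpose as the final line of the objective here: a trial whose loss went NaN is ranked as very bad rather than breaking the study.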