
Commit

Hyperparam search: 2 epochs over 20% data
neel04 committed Apr 25, 2024
1 parent d01bf87 commit 2c4abe5
Showing 4 changed files with 9 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -19,7 +19,7 @@ RUN pip3 install numpy pandas scipy

RUN pip3 install -U -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN pip3 install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich
-RUN pip3 install tensorboard-plugin-profile comet-ml optuna-integration
+RUN pip3 install tensorboard-plugin-profile comet-ml optuna-integration plotly

WORKDIR /ReAct_Jax

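Presumably, `plotly` is added here because Optuna's built-in visualization helpers use it as their plotting backend. A minimal sketch of inspecting a finished sweep, assuming a hypothetical study name and storage URL:

```python
import optuna
from optuna.visualization import plot_intermediate_values, plot_optimization_history

# Reload a completed study (study name and storage URL are hypothetical)
# and render interactive plotly figures.
study = optuna.load_study(study_name='sweep', storage='sqlite:///sweep.db')
plot_optimization_history(study).show()   # best value vs. trial number
plot_intermediate_values(study).show()    # per-epoch values sent via trial.report
```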
8 changes: 4 additions & 4 deletions README.md
@@ -25,7 +25,7 @@ python3 inferencer.py --checkpoint_path '/Users/neel/Documents/research/ReAct_Ja
First, get a preemptible TPUv4-8 node as a queued resource:

```bash
-gcloud alpha compute tpus queued-resources create node-v4 \
+gcloud alpha compute tpus queued-resources create $INSTANCE_NAME \
--node-id node-v4 \
--project react-jax \
--zone us-central2-b \
@@ -38,7 +38,7 @@ gcloud alpha compute tpus queued-resources create node-v4 \
Set up the TPU pod slice with the basics:

```bash
-gcloud compute tpus tpu-vm ssh node-v4 \
+gcloud compute tpus tpu-vm ssh $INSTANCE_NAME \
--zone=us-central2-b --worker=all --command="\
sudo apt-get update; \
sudo snap install nvim --classic; \
@@ -49,7 +49,7 @@ gcloud compute tpus tpu-vm ssh node-v4 \
And then actually kick off the training by downloading the script and running it:

```bash
-gcloud compute tpus tpu-vm ssh node-v4 \
+gcloud compute tpus tpu-vm ssh $INSTANCE_NAME \
--zone=us-central2-b --worker=all --command="\
tmux kill-server; sudo rm -rf ./*; \
sleep 3s && wget https://gist.githubusercontent.com/neel04/3bfc7e4d9cd746829b7e72f1b6fac5de/raw/run.sh; \
@@ -59,5 +59,5 @@ gcloud compute tpus tpu-vm ssh node-v4 \
If you get errors regarding workers not being able to sync up at the distributed barrier, do:

```bash
-gcloud compute tpus tpu-vm ssh --zone "us-central2-b" "ondem" --worker 'all' --project "react-jax" --command 'sudo docker system prune -f && sudo rm -rf ~/.cache;'
+gcloud compute tpus tpu-vm ssh --zone "us-central2-b" $INSTANCE_NAME --worker 'all' --project "react-jax" --command 'sudo docker system prune -f && sudo rm -rf ~/.cache;'
```
1 change: 1 addition & 0 deletions ReAct/utils/trainer.py
@@ -398,6 +398,7 @@ def train(self, trial: Optional[Any] = None) -> Tuple[float, int]:
self.wandb_logger.save(filepath)

print(f'Epoch {epoch} done!')
+trial.report(train_acc, epoch) # report the accuracy for optuna
step_done = step # prepare for next epoch

self.wandb_logger.finish() # Cleanup
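The new `trial.report(train_acc, epoch)` line exposes an intermediate metric to Optuna once per epoch. A minimal sketch of that pattern, where the placeholder accuracy and the pruning check are illustrative assumptions rather than part of this commit (the study that consumes these reports is created in train_model.py; see the sketch after that diff):

```python
import optuna

def objective(trial: optuna.Trial) -> float:
    train_acc = 0.0
    for epoch in range(2):                 # the sweep now runs 2 epochs per trial
        train_acc = 0.4 + 0.2 * epoch      # placeholder for the epoch's training accuracy
        trial.report(train_acc, epoch)     # intermediate value, keyed by the epoch number
        if trial.should_prune():           # lets the study's pruner stop weak trials early
            raise optuna.TrialPruned()
    return train_acc
```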
8 changes: 3 additions & 5 deletions train_model.py
@@ -46,8 +46,8 @@ def main(key: PRNGKeyArray):
if args.tune_hyperparams:
args.group = 'Sweeps' if args.baseline else 'Sweeps_5i'

-trainloader = train_dataset.create_dataloader('40%')
-valloader = val_dataset.create_dataloader('40%')
+trainloader = train_dataset.create_dataloader('20%')
+valloader = val_dataset.create_dataloader('20%')

# Create Optuna hyperparameter-tuning study
study = optuna.create_study(
@@ -119,7 +119,7 @@ def main(key: PRNGKeyArray):
def kickoff_optuna(trial, **trainer_kwargs):
args = trainer_kwargs['args']

-args.epochs = 1
+args.epochs = 2

args.lr = trial.suggest_float('lr', 1e-4, 1e-3, step=1e-4)
args.drop_rate = trial.suggest_float('drop_rate', 0.0, 0.1, step=0.02)
@@ -142,8 +142,6 @@ def kickoff_optuna(trial, **trainer_kwargs):
with jax.spmd_mode('allow_all'):
loss = trainer.train(trial)

-trial.report(loss, 1)
-
return loss

if __name__ == '__main__':
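Putting the train_model.py pieces together: a minimal, self-contained sketch of the sweep setup. Only the search ranges come from the diff above; the pruner choice, trial count, optimization direction, and the toy metric are assumptions for illustration.

```python
import optuna

def kickoff(trial: optuna.Trial) -> float:
    # Search space as in train_model.py.
    lr = trial.suggest_float('lr', 1e-4, 1e-3, step=1e-4)
    drop_rate = trial.suggest_float('drop_rate', 0.0, 0.1, step=0.02)

    # Toy metric standing in for trainer.train(trial); the real objective trains
    # for 2 epochs over 20% of the data and returns the resulting loss.
    return (lr - 3e-4) ** 2 + drop_rate

study = optuna.create_study(
    direction='minimize',                  # assumes the returned metric is a loss
    pruner=optuna.pruners.MedianPruner(),  # assumed pruner; it consumes the trial.report values
)
study.optimize(kickoff, n_trials=10)
print(study.best_params)
```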
