WTF is wrong with git (#13)

* optimized throughput by 15% and added LR logging

* Switch to Lion optimizer (see sketch after this list)

* why tf am I jax.checkpoint-ing still

* Added prelim sweep and hypertuning functionality

* Added hyperband

* Fixed hyperparameter tuning

* Removed alpha, fixed some bugs

* Initial Optuna integration (see sketch after this list)

* Added NaN handling (see sketch after this list)

* Slightly updated hypertuning search procedure

* Re-added n_k_loop for more experimentation

* Removed the history in react forward pass + other bells and whistles

* Fixed the n-pass recipe and other improvements

* Integrating equinox's while_loop with checkpointing (see sketch after this list)

* Setting iters_fwd as the default and updating n+k by making it less dynamic

* Rough draft of initial multi-node setup

* Added scalax for multi-node support

* Print hosts, and refresh keys every 75 steps

* Get back vanilla n_k loop

* Fixed dropout API for baseline model as well

* Turned off inference-time dropout

* Reverting back to k-only

* Added run.sh

* Fixed React forward pass not carrying `interim_thought`. Back to n_k

* switching to iters_fwd

* hardcoded k=5

* Added OpenWebText dataset support

* Added minipile dataset

* remove JIT

* trying to optimize dataset preproc step

* force stop multiprocessing for 1st map

* turned off multiproc for first map

* Sped up data preprocessing

* Switched to fast tokenizer

* Caching datasets and optimizing dataset preprocessing (see sketch after this list)

* reduced hypertuning epochs

* Updated dataset length

* fixed warmup steps for hypertuning runs

* Fixing group naming

* Uploading dataset to HF & added GH action

* Added docker GH action

* Trying to fix the slow dataset split

* fix ambiguous truth value bug

* Refactored data pipeline to be more robust

* fixed IO and partially added vmap on baseline to improve compilation speed

* Using 40% of data for sweeps (~600M tokens)

* force format to numpy

* fixed baseline scan

* remove tensorflow from dockerfile

* use vmap to reduce compilation time for react (see sketch after this list)

* Some optimizations + filter_spec-ing

* double the sweeps data

* removed LN and fixed sampling

* put the LN back

* speed up all_gather

* reduced sweep dataset slice

* Initial FSDP code

* sharding metrics computation to be fsdp

* disabling DP in metrics computation

* fixed FSDP not doing DP on `compute_metrics`

* cast to jnp array

* all gather sharded model before serializing it

* FSDP -> DDP. Dataset length autodetection (see sketch after this list)

* Support for DDP + TP

* report devices for tuning hyperparams

* precasting the data to jax arrays

* Turn off TP. Only DDP

* reduce comms overhead

* Don't donate model

* Fully removed TP code

* Sharding compute_metrics

* Putting the numpification of data in the trainloader

* Fixed LR scheduling
neel04 authored Apr 11, 2024
1 parent e78828b commit 43e80ff
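
Sketch for "Switch to Lion optimizer": the codebase is JAX, so the natural implementation is optax's `optax.lion`. A minimal sketch, assuming a recent optax; the learning rate and weight decay are illustrative, not the repo's values.

```python
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros((3, 3))}   # stand-in for the model pytree
grads = {"w": jnp.ones((3, 3))}     # stand-in gradients

# Lion tracks a single momentum buffer and applies the sign of the update,
# so it is lighter on memory than AdamW; it typically wants a smaller LR
# and a larger weight decay than AdamW.
optimizer = optax.lion(learning_rate=1e-4, b1=0.9, b2=0.99, weight_decay=0.1)
opt_state = optimizer.init(params)

updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```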
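Sketch for the Optuna/hyperband commits: the usual shape of such an integration is an objective that reports intermediate validation loss so the Hyperband pruner can kill weak trials early. The search space, metric, and trial count below are hypothetical placeholders, not read from the repo.

```python
import optuna

def objective(trial: optuna.Trial) -> float:
    # Hypothetical search space; the repo's actual hyperparameters may differ.
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    warmup_steps = trial.suggest_int("warmup_steps", 100, 2000)

    val_loss = float("inf")
    for step in range(10):                 # stand-in for the training loop
        val_loss = 1.0 / (step + 1) + lr   # placeholder validation metric
        trial.report(val_loss, step)
        if trial.should_prune():           # Hyperband stops weak trials early
            raise optuna.TrialPruned()
    return val_loss

study = optuna.create_study(
    direction="minimize",
    pruner=optuna.pruners.HyperbandPruner(),
)
study.optimize(objective, n_trials=50)
print(study.best_params)
```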
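Sketch for "Added NaN handling": the commit doesn't say how it was implemented; one standard optax approach, assumed here, is to zero out NaN gradients and clip before the optimizer step.

```python
import optax

# optax.zero_nans() replaces NaNs in the incoming gradients with zeros
# (and records in its state whether any were seen); global-norm clipping
# then bounds the update that reaches the optimizer.
optimizer = optax.chain(
    optax.zero_nans(),
    optax.clip_by_global_norm(1.0),
    optax.lion(learning_rate=1e-4),
)
```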
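Sketch for "Integrating equinox's while_loop with checkpointing": Equinox ships a while_loop under `equinox.internal` with a "checkpointed" mode that makes a bounded loop reverse-mode differentiable by recomputing intermediates rather than storing them (plausibly what the earlier jax.checkpoint grumbling was about). This is an internal, unstable API; the sketch assumes its signature at the time.

```python
import equinox.internal as eqxi
import jax.numpy as jnp

def cond(carry):
    i, _ = carry
    return i < 10                     # run ten iterations

def body(carry):
    i, x = carry
    return i + 1, jnp.tanh(x)         # stand-in for one recurrent step

# kind="checkpointed" enables gradients through the loop via treeverse-style
# rematerialization; an explicit max_steps bound is required.
_, x_out = eqxi.while_loop(
    cond, body, (0, jnp.ones(4)), max_steps=16, kind="checkpointed"
)
```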
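Sketch for the preprocessing commits ("Switched to fast tokenizer", "Caching datasets..."): with Hugging Face datasets, a batched map through a fast (Rust-backed) tokenizer with num_proc workers is the standard speed-up, and map results are cached on disk so re-runs skip the work. The dataset id ("JeanKaddour/minipile"), tokenizer, and lengths are illustrative assumptions.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", use_fast=True)  # Rust tokenizer
ds = load_dataset("JeanKaddour/minipile", split="train")

def tokenize(batch):
    return tok(batch["text"], truncation=True, max_length=512)

# batched=True hands the fast tokenizer whole batches at once; num_proc
# parallelizes across processes; results are cached for later runs.
ds = ds.map(tokenize, batched=True, num_proc=8, remove_columns=["text"])
ds.set_format("numpy")  # cf. "force format to numpy" above
```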
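Sketch for the compile-time commits ("use vmap to reduce compilation time", "fixed baseline scan"): the common JAX recipe is to stack the weights of identical blocks along a leading axis and drive them with lax.scan, so XLA traces and compiles the block body once instead of once per layer; an equivalent Equinox formulation builds the stacked layers with eqx.filter_vmap. A generic sketch of the pattern, not the repo's model code:

```python
import jax
import jax.numpy as jnp

n_layers, d = 12, 256
key = jax.random.PRNGKey(0)
# One weight matrix per layer, stacked on a leading axis: (n_layers, d, d).
stacked_w = jax.random.normal(key, (n_layers, d, d)) * 0.02

def block(x, w):
    return jax.nn.gelu(x @ w)          # stand-in for a transformer block

@jax.jit
def forward(x, stacked_w):
    # lax.scan traces `block` once and reuses the compiled body for every
    # layer; an unrolled Python loop would compile n_layers copies.
    def step(x, w):
        return block(x, w), None
    x, _ = jax.lax.scan(step, x, stacked_w)
    return x

y = forward(jnp.ones((8, d)), stacked_w)
```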
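Sketch for the sharding commits ("Initial FSDP code", "FSDP -> DDP"): in modern JAX both are expressed through a device mesh plus shardings, and DDP is simply "replicate the parameters, split the batch". A minimal single-host sketch with illustrative names and shapes:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
replicated = NamedSharding(mesh, P())           # params: full copy per device
batch_split = NamedSharding(mesh, P("data"))    # batch: split along axis 0

params = jax.device_put({"w": jnp.zeros((256, 256))}, replicated)
batch = jax.device_put(jnp.ones((32, 256)), batch_split)

@jax.jit
def forward(params, batch):
    # XLA propagates the shardings, so the output stays split along the
    # batch axis. Under FSDP the params would instead be sharded with
    # P("data") on a weight axis and gathered before serialization
    # (cf. "all gather sharded model before serializing it").
    return batch @ params["w"]

out = forward(params, batch)
```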