Commit
* Optimized throughput by 15% and added LR logging
* Switch to Lion optimizer (see the Optax sketch after this list)
* Why am I still jax.checkpoint-ing?
* Added prelim sweep and hypertuning functionality
* Added Hyperband
* Fixed hyperparameter tuning
* Removed alpha, fixed some bugs
* Initial Optuna integration (see the Optuna + Hyperband sketch below)
* Added NaN handling (see the `apply_if_finite` sketch below)
* Slightly updated hypertuning search procedure
* Re-added n_k_loop for more experimentation
* Removed the history in the React forward pass + other bells and whistles
* Fixed the n-pass recipe and other improvements
* Integrating Equinox's while_loop with checkpointing (see the checkpointed-iteration sketch below)
* Setting iters_fwd as the default and making n+k less dynamic
* Rough draft of initial multi-node setup
* Added scalax for multi-node support
* Print hosts, and refresh keys every 75 steps
* Brought back the vanilla n_k loop
* Fixed dropout API for the baseline model as well
* Turned off inference-time dropout
* Reverting back to k-only
* Added run.sh
* Fixed React forward pass not carrying `interim_thought`. Back to n_k
* Switching to iters_fwd
* Hardcoded k=5
* Added OpenWebText dataset support
* Added MiniPile dataset
* Removed JIT
* Trying to optimize the dataset preprocessing step
* Force-stopped multiprocessing for the first map
* Turned off multiprocessing for the first map
* Sped up data preprocessing
* Switched to fast tokenizer (see the data-pipeline sketch below)
* Caching datasets and optimizing dataset preprocessing
* Reduced hypertuning epochs
* Updated dataset length
* Fixed warmup steps for hypertuning runs
* Fixed group naming
* Uploading dataset to HF & added GH Action
* Added Docker GH Action
* Trying to fix the slow dataset split
* Fixed ambiguous-truth-value bug
* Refactored the data pipeline to be more robust
* Fixed IO and partially added vmap on the baseline to fix compilation speeds
* Using 40% of the data for sweeps (~600M tokens)
* Forced format to numpy
* Fixed baseline scan
* Removed TensorFlow from the Dockerfile
* Use vmap to reduce compilation time for React
* Some optimizations + filter_spec-ing
* Doubled the sweeps data
* Removed LN and fixed sampling
* Put the LN back
* Sped up all_gather
* Reduced the sweep dataset slice
* Initial FSDP code (see the sharding sketch below)
* Sharding metrics computation to be FSDP
* Disabled DP in metrics computation
* Fixed FSDP not doing DP on `compute_metrics`
* Cast to jnp array
* All-gather the sharded model before serializing it
* FSDP -> DDP. Dataset length autodetection
* Support for DDP + TP
* Report devices for hyperparameter tuning
* Precasting the data to JAX arrays
* Turned off TP. Only DDP
* Reduced comms overhead
* Don't donate the model
* Fully removed TP code
* Sharding compute_metrics
* Moved the numpification of data into the trainloader
* Fixed LR scheduling
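A minimal sketch of what the "Switch to Lion optimizer", "Fixed LR scheduling", "fixed warmup steps", and "added LR logging" commits plausibly amount to in Optax. The schedule values, the toy params, and the `train_step` shape are assumptions; `optax.lion`, `optax.warmup_cosine_decay_schedule`, and `optax.apply_updates` are real Optax API.

```python
import jax
import jax.numpy as jnp
import optax

# Warmup + cosine decay; the repo's actual hyperparameters are unknown.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=3e-4,
    warmup_steps=1_000,
    decay_steps=50_000,
)

optimizer = optax.lion(learning_rate=schedule, weight_decay=1e-2)

params = {"w": jnp.zeros((8, 8))}  # stand-in for the real model params
opt_state = optimizer.init(params)

def loss_fn(params, x):
    return jnp.sum((params["w"] @ x) ** 2)

@jax.jit
def train_step(params, opt_state, x, step):
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    # "added LR logging": the schedule is a plain function of the step,
    # so the current LR can be recomputed and logged alongside the loss.
    lr = schedule(step)
    return params, opt_state, loss, lr
```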
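For the "Initial Optuna integration" and "Added hyperband" commits, a standard Optuna objective with a Hyperband pruner looks like the sketch below. The search space, epoch count, and `train_one_epoch` stub are placeholders, not the repo's actual setup; the Optuna calls themselves are real API.

```python
import optuna

def train_one_epoch(lr: float, wd: float, epoch: int) -> float:
    # stand-in for a real training epoch; returns a fake validation loss
    return (lr - 3e-4) ** 2 + wd * 0.1 / (epoch + 1)

def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    wd = trial.suggest_float("weight_decay", 1e-4, 1e-1, log=True)

    val_loss = float("inf")
    for epoch in range(10):
        val_loss = train_one_epoch(lr, wd, epoch)
        trial.report(val_loss, step=epoch)
        if trial.should_prune():  # Hyperband decides whether to stop early
            raise optuna.TrialPruned()
    return val_loss

study = optuna.create_study(
    direction="minimize",
    pruner=optuna.pruners.HyperbandPruner(min_resource=1, max_resource=10),
)
study.optimize(objective, n_trials=50)
```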
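One plausible reading of "Added NaN handling" (an assumption about what the commit did, not a confirmed detail) is Optax's standard guard, which skips any update whose gradients contain NaN/inf instead of letting them poison the params:

```python
import optax

optimizer = optax.apply_if_finite(
    optax.lion(learning_rate=3e-4),
    max_consecutive_errors=5,  # raise after 5 non-finite steps in a row
)
```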
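Several commits circle the iterated forward pass: "Why am I still jax.checkpoint-ing?", "Integrating Equinox's while_loop with checkpointing", "hardcoded k=5", and the `interim_thought` carry. The commit log points at `equinox.internal.while_loop` for the dynamic case; the sketch below shows only the simpler fixed-`k` variant with `jax.lax.scan` plus `jax.checkpoint`, and the `block` function is a toy stand-in for the real model.

```python
import jax
import jax.numpy as jnp

def block(thought, x):
    # hypothetical recurrent block: mixes the interim thought with the input
    return jnp.tanh(thought + x)

def iterate(x, thought, k=5):  # "hardcoded k=5"
    # jax.checkpoint rematerializes activations in the backward pass instead
    # of storing all k iterations, trading compute for memory.
    step = jax.checkpoint(lambda t, _: (block(t, x), None))
    thought, _ = jax.lax.scan(step, thought, None, length=k)
    return thought

x = jnp.ones(8)
grad = jax.grad(lambda t: iterate(x, t).sum())(jnp.zeros(8))
```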
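The data-pipeline commits ("Switched to fast tokenizer", "turned off multiproc for first map", "force format to numpy", "Caching datasets") together suggest roughly the following HF `datasets` flow. The dataset name, tokenizer, and column names are illustrative assumptions.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

ds = load_dataset("openwebtext", split="train")

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=512)

# No num_proc here, mirroring "turned off multiproc for first map": fast
# tokenizers already parallelize internally in Rust, so Python
# multiprocessing can make this map slower. `datasets` fingerprints and
# caches the result on disk automatically.
ds = ds.map(tokenize, batched=True, remove_columns=["text"])

# "force format to numpy": hand numpy arrays straight to the trainloader
# so each batch doesn't pay a per-item conversion cost before hitting JAX.
ds = ds.with_format("numpy")
```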
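Finally, the sharding commits converge on plain DDP ("FSDP -> DDP", "Turn off TP. Only DDP"): replicate the params across devices and shard the batch along its leading axis. The sketch below uses `jax.sharding` directly rather than the repo's scalax wrapper, whose API isn't shown here; the matmul "model" is a stand-in.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

replicated = NamedSharding(mesh, P())          # params: same copy everywhere
data_sharded = NamedSharding(mesh, P("data"))  # batch: split across devices

params = jax.device_put(jnp.zeros((512, 512)), replicated)
batch = jax.device_put(jnp.ones((64, 512)), data_sharded)

@jax.jit
def compute_metrics(params, batch):
    # Runs as DP automatically: each device sees its own batch shard, and
    # jit inserts the cross-device reduction needed for the mean.
    logits = batch @ params
    return jnp.mean(logits ** 2)

print(compute_metrics(params, batch))
```

For the "All-gather the sharded model before serializing it" commit, `jax.experimental.multihost_utils.process_allgather` is the usual way to materialize a fully replicated copy on every host before writing a checkpoint.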