
Commit

Squash commit from dev to sync dev and main.
commit f07a984
Author: Neel Gupta <[email protected]>
Date:   Sun Apr 28 01:27:00 2024 +0100

    added timeout handling for docker container removal

commit a91424a
Author: Neel Gupta <[email protected]>
Date:   Sat Apr 27 11:01:13 2024 +0100

    Init trials now use less regularization and a higher LR

commit 6546b9a
Author: Neel Gupta <[email protected]>
Date:   Fri Apr 26 10:40:07 2024 +0100

    Update data split to 40%

commit 7246119
Author: Neel Gupta <[email protected]>
Date:   Thu Apr 25 17:52:26 2024 +0100

    Enqueue handpicked trials for hypertuning

commit 0a58864
Author: Neel Gupta <[email protected]>
Date:   Thu Apr 25 17:20:30 2024 +0100

    Reported wrong metric

commit 2c4abe5
Author: Neel Gupta <[email protected]>
Date:   Thu Apr 25 12:44:23 2024 +0100

    Hyperparam search: 2 epochs over 20% data

commit d01bf87
Author: Neel Gupta <[email protected]>
Date:   Wed Apr 24 14:54:21 2024 +0100

    Trying to fix optuna

commit a36c09a
Author: Neel Gupta <[email protected]>
Date:   Wed Apr 24 00:03:50 2024 +0100

    Optuna: Using brute force sampler

commit 8bccf9c
Author: Neel Gupta <[email protected]>
Date:   Tue Apr 23 11:39:42 2024 +0100

    Sweep on a full epoch of the dataset

commit 604f397
Author: Neel Gupta <[email protected]>
Date:   Tue Apr 23 00:15:45 2024 +0100

    Fix not averaging the train_acc across a batch

commit 7e765b4
Author: Neel Gupta <[email protected]>
Date:   Mon Apr 22 22:45:08 2024 +0100

    Optuna maximize training accuracy

commit 309f607
Author: Neel Gupta <[email protected]>
Date:   Mon Apr 22 12:36:58 2024 +0100

    Fixing optuna hypertuning feedback loop

commit 1ed4520
Author: Neel Gupta <[email protected]>
Date:   Sat Apr 20 21:34:12 2024 +0100

    Revert "He Initialization"

    This reverts commit 52f682e.

commit 52f682e
Author: Neel Gupta <[email protected]>
Date:   Sat Apr 20 19:55:22 2024 +0100

    He Initialization

commit 8e76e0c
Author: Neel Gupta <[email protected]>
Date:   Sat Apr 20 15:31:14 2024 +0100

    EXP: Full on ctx gate as an MLP for expressiveness

commit e568837
Author: Neel Gupta <[email protected]>
Date:   Sat Apr 20 12:19:25 2024 +0100

    EXP: Lerp + residual connection

commit 5a3f9c9
Author: Neel Gupta <[email protected]>
Date:   Sat Apr 20 02:20:59 2024 +0100

    EXP: Lerp between gated input_array and aggregated history

commit 7580b44
Author: Neel Gupta <[email protected]>
Date:   Fri Apr 19 23:21:42 2024 +0100

    Fixed an error in MLP experiment

commit 5578bf2
Author: Neel Gupta <[email protected]>
Date:   Fri Apr 19 20:17:07 2024 +0100

    EXP: Increasing expressiveness of the MLP

commit 3367b53
Author: Neel Gupta <[email protected]>
Date:   Thu Apr 18 23:19:51 2024 +0100

    EXP: Back to simple sigmoid gating

commit 486a843
Author: Neel Gupta <[email protected]>
Date:   Thu Apr 18 21:00:21 2024 +0100

    EXP: forget_gate + MHSA w/ the LTM gate

commit 3da7c22
Author: Neel Gupta <[email protected]>
Date:   Thu Apr 18 20:11:10 2024 +0100

    EXP: Sigmoid gating + LTM (no concat/add)

commit 63f910b
Author: Neel Gupta <[email protected]>
Date:   Thu Apr 18 18:26:55 2024 +0100

    EXP: MLP over carry state - no concat/add

commit 9541aad
Author: Neel Gupta <[email protected]>
Date:   Thu Apr 18 14:50:59 2024 +0100

    EXP: MLP over forget_gate

commit b053da8
Author: Neel Gupta <[email protected]>
Date:   Thu Apr 18 12:29:06 2024 +0100

    EXP: Forget gate only w/ softmax

commit 692a5e2
Author: Neel Gupta <[email protected]>
Date:   Wed Apr 17 23:58:50 2024 +0100

    EXP: Only forget gate on aggregated (mean) out

commit daa991b
Author: Neel Gupta <[email protected]>
Date:   Wed Apr 17 21:29:51 2024 +0100

    fixed sum

commit 307cbfc
Author: Neel Gupta <[email protected]>
Date:   Wed Apr 17 12:49:51 2024 +0100

    EXP: Aggregating history and conditioning forget gate + LTM on that

commit 735c167
Author: Neel Gupta <[email protected]>
Date:   Wed Apr 17 11:53:52 2024 +0100

    EXP: Softmax for forget gate + swapped ctx x-attn

commit f0417a5
Author: Neel Gupta <[email protected]>
Date:   Wed Apr 17 11:01:45 2024 +0100

    EXP: Remove addition on the LTM gate

commit 79fba12
Author: Neel Gupta <[email protected]>
Date:   Wed Apr 17 10:16:52 2024 +0100

    removed forget gate

commit 94e4aba
Author: Neel Gupta <[email protected]>
Date:   Wed Apr 17 01:06:12 2024 +0100

    EXP: Recurrent LTM-styled context state propagation

commit 19a99b0
Author: Neel Gupta <[email protected]>
Date:   Tue Apr 16 18:51:11 2024 +0100

    EXP: Slightly different weighting by a lone scalar

commit ae2e48f
Author: Neel Gupta <[email protected]>
Date:   Tue Apr 16 15:04:42 2024 +0100

    EXP: weighting each sequence element at every recursive iteration

commit 3601031
Author: Neel Gupta <[email protected]>
Date:   Tue Apr 16 13:00:42 2024 +0100

    EXP: x-attn after first MHSA block

commit f3b334b
Author: Neel Gupta <[email protected]>
Date:   Tue Apr 16 10:54:56 2024 +0100

    EXP: input_arr skip connection

commit 65486fe
Author: Neel Gupta <[email protected]>
Date:   Tue Apr 16 00:07:15 2024 +0100

    EXP: thought skip connection + XLA FLOPs computation

commit dd245ec
Author: Neel Gupta <[email protected]>
Date:   Mon Apr 15 22:35:59 2024 +0100

    EXP: block initial input_arr concat + x_attn on first block only

commit b1c9ed0
Author: Neel Gupta <[email protected]>
Date:   Mon Apr 15 18:55:42 2024 +0100

    EXP: Alternate x-attn or MHSA

commit 0533f17
Author: Neel Gupta <[email protected]>
Date:   Mon Apr 15 17:59:51 2024 +0100

    EXP: All cross-attention

commit bbf36c1
Author: Neel Gupta <[email protected]>
Date:   Mon Apr 15 15:35:11 2024 +0100

    Exp: OG Xattn + fixed iterate_for_steps impl.

commit ab02f05
Author: Neel Gupta <[email protected]>
Date:   Mon Apr 15 11:35:42 2024 +0100

    Fixed UT flow dependence on `input_arr` instead of `x`

commit 52e9987
Author: Neel Gupta <[email protected]>
Date:   Mon Apr 15 01:16:04 2024 +0100

    EXP: LinComb

commit e7f1f96
Author: Neel Gupta <[email protected]>
Date:   Mon Apr 15 00:54:56 2024 +0100

    Added EAI FLOPs calculation

commit 9f28c97
Author: Neel Gupta <[email protected]>
Date:   Sun Apr 14 20:42:13 2024 +0100

    Convert numpy array to jax in the training loop

commit a572c47
Author: Neel Gupta <[email protected]>
Date:   Sun Apr 14 11:44:22 2024 +0100

    EXP: Upgrade to a jax.lax.cond

commit 1cb791d
Author: Neel Gupta <[email protected]>
Date:   Sun Apr 14 11:40:32 2024 +0100

    EXP: Use cross attention

commit 61f5c4c
Author: Neel Gupta <[email protected]>
Date:   Sat Apr 13 23:13:51 2024 +0100

    FLOPs calculation + storing dataset on CPU during recomputation

commit 9f67154
Author: Neel Gupta <[email protected]>
Date:   Thu Apr 11 21:41:12 2024 +0100

    Fixed LR scheduling

commit b3d7540
Author: Neel Gupta <[email protected]>
Date:   Thu Apr 11 16:59:28 2024 +0100

    Putting the numpification of data in the trainloader

commit 7e839ee
Author: Neel Gupta <[email protected]>
Date:   Thu Apr 11 14:30:23 2024 +0100

    Sharding compute_metrics

commit 6c9385e
Author: Neel Gupta <[email protected]>
Date:   Wed Apr 10 23:56:15 2024 +0100

    Fully removed TP code

commit f185723
Author: Neel Gupta <[email protected]>
Date:   Wed Apr 10 22:22:12 2024 +0100

    Don't donate model

commit 3a9b7b3
Author: Neel Gupta <[email protected]>
Date:   Wed Apr 10 01:03:55 2024 +0100

    reduce comms overhead

commit f3c761d
Author: Neel Gupta <[email protected]>
Date:   Tue Apr 9 23:53:23 2024 +0100

    Turn off TP. Only DDP

commit 66bde33
Author: Neel Gupta <[email protected]>
Date:   Mon Apr 8 01:04:26 2024 +0100

    precasting the data to jax arrays

commit 894ffc2
Author: Neel Gupta <[email protected]>
Date:   Mon Apr 8 00:34:32 2024 +0100

    report devices for tuning hyperparams

commit 96efca6
Author: Neel Gupta <[email protected]>
Date:   Sun Apr 7 02:11:54 2024 +0100

    Support for DDP + TP

commit 24accfd
Author: Neel Gupta <[email protected]>
Date:   Sat Apr 6 00:40:36 2024 +0100

    FSDP -> DDP. Dataset length autodetection

commit 7f418f2
Author: Neel Gupta <[email protected]>
Date:   Fri Apr 5 23:25:47 2024 +0100

    all gather sharded model before serializing it

commit 7822b11
Author: Neel Gupta <[email protected]>
Date:   Fri Apr 5 21:33:04 2024 +0100

    cast to jnp array

commit b86b75d
Author: Neel Gupta <[email protected]>
Date:   Fri Apr 5 21:02:16 2024 +0100

    fixed FSDP not doing DP on `compute_metrics`

commit e31cf49
Author: Neel Gupta <[email protected]>
Date:   Fri Apr 5 14:21:24 2024 +0100

    disabling DP in metrics computation

commit 9f0ab4d
Author: Neel Gupta <[email protected]>
Date:   Fri Apr 5 14:12:13 2024 +0100

    sharding metrics computation to be fsdp

commit b81e458
Author: Neel Gupta <[email protected]>
Date:   Fri Apr 5 10:37:25 2024 +0100

    Initial FSDP code

commit 2b942b0
Author: Neel Gupta <[email protected]>
Date:   Thu Apr 4 15:17:49 2024 +0100

    reduced sweep dataset slice

commit 97e8ecd
Author: Neel Gupta <[email protected]>
Date:   Wed Apr 3 14:24:01 2024 +0100

    speedup all_gather

commit 879dc8f
Author: Neel Gupta <[email protected]>
Date:   Tue Apr 2 18:38:04 2024 +0530

    put the LN back

commit 5c180ff
Author: Neel Gupta <[email protected]>
Date:   Tue Apr 2 16:04:19 2024 +0530

    removed LN and fixed sampling

commit 8eecef2
Author: Neel Gupta <[email protected]>
Date:   Tue Apr 2 01:42:54 2024 +0530

    double the sweeps data

commit 12d23b2
Author: Neel Gupta <[email protected]>
Date:   Tue Apr 2 01:15:33 2024 +0530

    Some optimizations + filter_spec-ing

commit 569d37d
Author: Neel Gupta <[email protected]>
Date:   Sun Mar 31 00:07:11 2024 +0530

    use vmap to reduce compilation time for react

commit b98629c
Author: Neel Gupta <[email protected]>
Date:   Sat Mar 30 11:16:42 2024 +0530

    remove tensorflow from dockerfile

commit 35f68b3
Author: Neel Gupta <[email protected]>
Date:   Fri Mar 29 23:31:14 2024 +0530

    fixed baseline scan

commit 89cffb5
Author: Neel Gupta <[email protected]>
Date:   Fri Mar 29 21:36:20 2024 +0530

    force format to numpy

commit 3c30560
Author: Neel Gupta <[email protected]>
Date:   Fri Mar 29 18:01:04 2024 +0530

    Using 40% of data for sweeps ~600M tokens

commit 674740a
Author: Neel Gupta <[email protected]>
Date:   Fri Mar 29 14:38:41 2024 +0530

    fixed IO and partially added vmap on baseline to fix compilation speeds

commit 28b0d02
Author: Neel Gupta <[email protected]>
Date:   Wed Mar 27 17:48:10 2024 +0530

    Refactored data pipeline to be more robust

commit 49d478e
Author: Neel Gupta <[email protected]>
Date:   Wed Mar 27 00:42:21 2024 +0530

    fix ambiguous truth value bug

commit 13cfd9a
Author: Neel Gupta <[email protected]>
Date:   Tue Mar 26 22:57:55 2024 +0530

    Trying to fix the slow dataset split

commit 1da3893
Author: Neel Gupta <[email protected]>
Date:   Tue Mar 26 14:23:30 2024 +0530

    Added docker GH action

commit bcc224e
Author: Neel Gupta <[email protected]>
Date:   Tue Mar 26 14:03:44 2024 +0530

    Uploading dataset to HF & added GH action

commit 0437454
Author: Neel Gupta <[email protected]>
Date:   Mon Mar 25 04:12:58 2024 +0530

    Fixing group naming

commit c714a89
Author: Neel Gupta <[email protected]>
Date:   Mon Mar 25 03:58:01 2024 +0530

    fixed warmup steps for hypertuning runs

commit 042580f
Author: Neel Gupta <[email protected]>
Date:   Mon Mar 25 03:39:41 2024 +0530

    Updated dataset length

commit 825b872
Author: Neel Gupta <[email protected]>
Date:   Mon Mar 25 01:21:23 2024 +0530

    reduced hypertuning epochs

commit d8001e7
Author: Neel Gupta <[email protected]>
Date:   Sun Mar 24 16:06:47 2024 +0530

    Caching datasets and optimizing dataset preprocessing

commit f9291b8
Author: Neel Gupta <[email protected]>
Date:   Sun Mar 24 00:16:24 2024 +0530

    Switched to fast tokenizer

commit ce5c829
Author: Neel Gupta <[email protected]>
Date:   Sat Mar 23 17:00:37 2024 +0530

    Sped up data preprocessing

commit c1693e6
Author: Neel Gupta <[email protected]>
Date:   Sat Mar 23 14:32:07 2024 +0530

    turned off multiproc for first map

commit cca2f8c
Author: Neel Gupta <[email protected]>
Date:   Sat Mar 23 13:16:24 2024 +0530

    force stop multiprocessing for 1st map

commit 6af5617
Author: Neel Gupta <[email protected]>
Date:   Fri Mar 22 23:01:18 2024 +0530

    trying to optimize dataset preproc step

commit ebf7628
Author: Neel Gupta <[email protected]>
Date:   Fri Mar 22 21:14:25 2024 +0530

    remove JIT

commit 8fa3de1
Author: Neel Gupta <[email protected]>
Date:   Fri Mar 22 14:48:53 2024 +0530

    Added minipile dataset

commit 5cc0270
Author: Neel Gupta <[email protected]>
Date:   Thu Mar 21 21:17:40 2024 +0530

    Added OpenWebText dataset support

commit 19877a4
Author: Neel Gupta <[email protected]>
Date:   Wed Mar 20 22:51:04 2024 +0530

    hardcoded k=5

commit 3f207b8
Author: Neel Gupta <[email protected]>
Date:   Wed Mar 20 21:49:38 2024 +0530

    switching to iters_fwd

commit 545e235
Author: Neel Gupta <[email protected]>
Date:   Tue Mar 19 17:28:30 2024 +0530

    Fixed React forward pass not carrying `interim_thought`. Back to n_k

commit e508e98
Author: Neel Gupta <[email protected]>
Date:   Mon Mar 18 23:59:45 2024 +0530

    Added run.sh

commit dbdc3cf
Author: Neel Gupta <[email protected]>
Date:   Mon Mar 18 23:59:18 2024 +0530

    Reverting back to k-only

commit 554898f
Author: Neel Gupta <[email protected]>
Date:   Sun Mar 17 15:29:32 2024 +0530

    Turned off inference-time dropout

commit 703aa3d
Author: Neel Gupta <[email protected]>
Date:   Sun Mar 17 01:53:00 2024 +0530

    Fixed dropout API for baseline model as well

commit 43691e7
Author: Neel Gupta <[email protected]>
Date:   Fri Mar 15 15:34:15 2024 +0530

    Get back vanilla n_k loop

commit 695f1f0
Author: Neel Gupta <[email protected]>
Date:   Thu Mar 14 22:29:18 2024 +0530

    Print hosts, and refresh keys every 75 steps

commit c636d6b
Author: Neel Gupta <[email protected]>
Date:   Wed Mar 13 01:54:31 2024 +0530

    Added scalax for multi-node support

commit 382d3be
Author: Neel Gupta <[email protected]>
Date:   Thu Mar 7 12:29:20 2024 +0000

    Rough draft of initial multi-node setup

commit ffbff6e
Author: Neel Gupta <[email protected]>
Date:   Thu Feb 22 00:25:08 2024 +0000

    Setting iters_fwd as the default and updating n+k by making it less dynamic

commit 84dbfea
Author: Neel Gupta <[email protected]>
Date:   Sun Feb 18 19:34:23 2024 +0000

    Integrating Equinox's while_loop with checkpointing

commit 3cba47d
Author: Neel Gupta <[email protected]>
Date:   Fri Feb 16 23:11:56 2024 +0000

    Fixed the n-pass recipe and other improvements

commit 1427c61
Author: Neel Gupta <[email protected]>
Date:   Tue Feb 13 11:21:56 2024 +0000

    Removed the history in react forward pass + other bells and whistles

commit c267773
Author: Neel Gupta <[email protected]>
Date:   Fri Feb 9 11:22:21 2024 +0000

    Re-added n_k_loop for more experimentation

commit 39f03ff
Author: Neel Gupta <[email protected]>
Date:   Fri Feb 2 11:12:01 2024 +0000

    Slightly updated hypertuning search procedure

commit 73c48ea
Author: Neel Gupta <[email protected]>
Date:   Mon Jan 29 10:41:49 2024 +0000

    Added NaN handling

commit db9d6c5
Author: Neel Gupta <[email protected]>
Date:   Fri Jan 26 22:50:00 2024 +0000

    Initial Optuna integration

commit fcf1965
Author: Neel Gupta <[email protected]>
Date:   Tue Jan 23 14:51:28 2024 +0000

    Removed alpha, fixed some bugs

commit 1a5450a
Author: Neel Gupta <[email protected]>
Date:   Mon Jan 22 23:33:21 2024 +0000

    Fixed hyperparameter tuning

commit 5c4af1c
Author: Neel Gupta <[email protected]>
Date:   Fri Jan 19 14:23:23 2024 +0000

    Added hyperband

commit 8604c24
Author: Neel Gupta <[email protected]>
Date:   Thu Jan 18 15:49:03 2024 +0000

    Added prelim sweep and hypertuning functionality

commit 88543a3
Author: Neel Gupta <[email protected]>
Date:   Fri Jan 12 17:07:20 2024 +0000

    why tf am I jax.checkpoint-ing still

commit 368a5a0
Author: Neel Gupta <[email protected]>
Date:   Wed Jan 10 12:42:55 2024 +0000

    Switch to Lion optimizer

commit 21e7a7d
Author: Neel Gupta <[email protected]>
Date:   Mon Jan 8 22:37:29 2024 +0000

    optimized throughput by 15% and added LR logging
neel04 committed Apr 28, 2024
1 parent 43e80ff commit 849c987
Showing 12 changed files with 242 additions and 139 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/docker-image.yaml
@@ -3,8 +3,12 @@ name: Docker Image CI
on:
push:
branches: [ "dev" ]
paths:
- '**Dockerfile**'
pull_request:
branches: [ "main" ]
paths:
- '**Dockerfile**'

jobs:
build:
3 changes: 2 additions & 1 deletion .gitignore
@@ -166,4 +166,5 @@ cython_debug/
**wandb
*.wandb
*.eqx
*cached_data
*cached_data
*pyrightconfig.json
2 changes: 1 addition & 1 deletion Dockerfile
@@ -19,7 +19,7 @@ RUN pip3 install numpy pandas scipy

RUN pip3 install -U -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN pip3 install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich
RUN pip3 install tensorboard-plugin-profile comet-ml optuna-integration
RUN pip3 install tensorboard-plugin-profile comet-ml optuna-integration plotly

WORKDIR /ReAct_Jax

16 changes: 10 additions & 6 deletions README.md
@@ -25,7 +25,7 @@ python3 inferencer.py --checkpoint_path '/Users/neel/Documents/research/ReAct_Ja
First, get a preemptible TPUv4-8 node as a queued resource:

```bash
gcloud alpha compute tpus queued-resources create node-v4 \
gcloud alpha compute tpus queued-resources create $INSTANCE_NAME \
--node-id node-v4 \
--project react-jax \
--zone us-central2-b \
@@ -38,7 +38,7 @@ gcloud alpha compute tpus queued-resources create node-v4 \
Setup the TPU pod slice with basics:

```bash
gcloud compute tpus tpu-vm ssh node-v4 \
gcloud compute tpus tpu-vm ssh $INSTANCE_NAME \
--zone=us-central2-b --worker=all --command="\
sudo apt-get update; \
sudo snap install nvim --classic; \
@@ -49,7 +49,7 @@ gcloud compute tpus tpu-vm ssh node-v4 \
And then actually kickoff the training by downloading the script and running it:

```bash
gcloud compute tpus tpu-vm ssh node-v4 \
gcloud compute tpus tpu-vm ssh $INSTANCE_NAME \
--zone=us-central2-b --worker=all --command="\
tmux kill-server; sudo rm -rf ./*; \
sleep 3s && wget https://gist.githubusercontent.com/neel04/3bfc7e4d9cd746829b7e72f1b6fac5de/raw/run.sh; \
@@ -59,7 +59,11 @@ gcloud compute tpus tpu-vm ssh node-v4 \
If you get errors regarding workers not being able to sync up at the distributed barrier, do:

```bash
gcloud compute tpus tpu-vm ssh --zone "us-central2-b" "ondem" \
--project "react-jax" \
--command 'sudo docker system prune -f && sudo rm -rf ~/.cache;'
gcloud compute tpus tpu-vm ssh --zone "us-central2-b" $INSTANCE_NAME --worker 'all' --project "react-jax" --command 'sudo docker system prune -f && sudo rm -rf ~/.cache;'
```

If Docker is unresponsive, just restart docker service:

```bash
gcloud compute tpus tpu-vm ssh --zone "us-central2-b" $INSTANCE_NAME --worker 'all' --project "react-jax" --command 'sudo systemctl restart docker'
```
14 changes: 4 additions & 10 deletions ReAct/data/minipile.py
@@ -4,7 +4,6 @@

import datasets
import jax
import jax.numpy as jnp
import numpy as np
from datasets import load_dataset, load_from_disk
from jaxtyping import Array
@@ -15,6 +14,7 @@ class MiniPileDataset:
def __init__(self, split: str = 'train', max_length: int = 512, bsz: int = 256, vocab_dir: str ='./ReAct/data'):
datasets.config.IN_MEMORY_MAX_SIZE = 1e+11

self.cpus = jax.devices("cpu")
self.bsz = bsz
self.max_length = max_length + 1
self.split = split
@@ -93,12 +93,6 @@ def take_subset(self, dataset, elements: int) -> None:

return dataset

def numpify(self, dataset: datasets.Dataset) -> datasets.Dataset:
'''
Convert the dataset to numpy arrays
'''
return jax.tree_map(lambda x: jnp.asarray(x), dataset['text'])

def create_dataloader(self, slice: str = '100%'):
data_path = Path(f'./cached_data/minipile_{self.split}.data')

@@ -110,13 +104,13 @@ def create_dataloader(self, slice: str = '100%'):

dataset.set_format(type='numpy')

return self.numpify(dataset)
return dataset

except (FileNotFoundError, ValueError):
if os.path.exists(data_path):
print(f'Loading dataset from {data_path}...')
dataset = self.load_data(data_path)
return self.numpify(dataset)
return dataset
else:
print(f'Building dataset from scratch... [split: {self.split}] | [bsz: {self.bsz}]')

@@ -139,4 +133,4 @@ def create_dataloader(self, slice: str = '100%'):
self.upload_dataset(dataset,
hub_path=f'Neel-Gupta/minipile-processed_{self.bsz}') # upload the processed dataset to the Hub

return self.numpify(dataset)
return dataset
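
A minimal sketch of what dropping `numpify` implies for callers: the loader now yields NumPy-formatted batches, and the cast to JAX arrays happens per batch in the training loop (optionally pinned to the CPU devices the dataset class now stores). The stand-in loader, shapes, and helper name below are illustrative only, not the repository's actual API.

```python
import jax
import jax.numpy as jnp
import numpy as np

cpu = jax.devices("cpu")[0]

def precast(np_batch: np.ndarray) -> jax.Array:
    # Cast a NumPy batch to a JAX array pinned to the host CPU, mirroring the
    # `jax.devices("cpu")` handle stored in MiniPileDataset.__init__.
    return jax.device_put(np_batch, cpu)

# Illustrative usage with a stand-in loader of tokenized batches (max_length + 1 = 513).
dummy_loader = (np.ones((8, 513), dtype=np.int64) for _ in range(2))

for np_batch in dummy_loader:
    batch = precast(np_batch)       # stays on CPU until the training step needs it
    inputs = batch[:, :-1]          # e.g. split inputs / shifted targets before compute
    targets = batch[:, 1:]
```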
3 changes: 2 additions & 1 deletion ReAct/model/blocks.py
@@ -151,11 +151,12 @@ def __init__(self,

def __call__(self,
input: BFloat16[Array, 'batch in_dim'],
**kwargs):
**kwargs) -> Array:

mask = kwargs.get('mask', None)
mask = jnp.ones_like(self.weight) if mask is None else mask
output = input @ (self.weight * mask.astype(input.dtype)) + self.bias

return output

class LiteAttention(eqx.Module):
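
For readers skimming the diff, here is a self-contained sketch of the call-time weight-mask pattern that `LinearProj.__call__` now follows. `MaskedLinear` is a simplified stand-in written for illustration, not the repository's class.

```python
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array

class MaskedLinear(eqx.Module):
    """Dense layer whose weight matrix can be element-wise masked at call time."""
    weight: Array
    bias: Array

    def __init__(self, in_dim: int, out_dim: int, *, key):
        wkey, _ = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (in_dim, out_dim)) * (in_dim ** -0.5)
        self.bias = jnp.zeros(out_dim)

    def __call__(self, x: Array, **kwargs) -> Array:
        mask = kwargs.get("mask", None)
        mask = jnp.ones_like(self.weight) if mask is None else mask
        return x @ (self.weight * mask.astype(x.dtype)) + self.bias

layer = MaskedLinear(4, 4, key=jax.random.PRNGKey(0))
out = layer(jnp.ones((2, 4)), mask=jnp.tril(jnp.ones((4, 4))))  # zero the upper triangle
```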
81 changes: 50 additions & 31 deletions ReAct/model/react.py
@@ -5,55 +5,71 @@
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray, PyTree

from .blocks import AttentionBlock, LinearProj, LiteAttention
from .blocks import AttentionBlock, LinearProj, LiteAttention, MLP

class RecurrentModule(eqx.Module):
'''
Bunch of AttentionBlocks
Bunch of AttentionBlocks in a pseudo-LSTM fashion
'''
num_blocks: int = eqx.field(static=True)

attention_blocks: PyTree[AttentionBlock]
reshape_layer: eqx.Module
forget_gate: LinearProj
ctx_gate: LinearProj
reshape_layer: LinearProj

def __init__(self,
seqlen: int,
drop_rate: float,
n_heads: int,
num_blocks: int,
bottleneck: int,
key: PRNGKeyArray): # noqa: E501

key: PRNGKeyArray):

self.num_blocks = num_blocks
keys = jax.random.split(key, num_blocks)

make_block: callable = lambda k: AttentionBlock( # noqa: E731
seqlen, n_heads, drop_rate, bottleneck, k
)

self.reshape_layer = LinearProj(bottleneck * 2, bottleneck, key=key)
self.forget_gate = MLP(bottleneck, bottleneck, p=drop_rate, key=key)
self.ctx_gate = MLP(bottleneck, bottleneck, p=drop_rate, key=key)

make_block: callable = lambda k: AttentionBlock(seqlen, n_heads, drop_rate, bottleneck, k) # noqa: E731
self.attention_blocks = eqx.filter(eqx.filter_vmap(make_block)(keys), eqx.is_array_like)

def __call__(self, x: Array,

def __call__(self,
x: Array,
input_arr: Array,
pad_mask: Array,
enable_dropout: bool,
key: PRNGKeyArray) -> Array:
key: PRNGKeyArray) -> Tuple[Array, Array]:

enable_dropout: bool = True
key: PRNGKeyArray = key

keys = jax.random.split(key, self.num_blocks)
dynamic_part, static_part = eqx.partition(self.attention_blocks, eqx.is_array_like,
is_leaf=lambda x: isinstance(x, eqx.nn.Dropout))

x = self.reshape_layer(x) # (batch, seqlen, bottleneck)
x = self.reshape_layer(x) # (seqlen, width * 2) -> (seqlen, width)

def f(input_tup: Tuple[Array, int], _dynamic_bl: PyTree) -> Tuple[Tuple[Array, int], int]:
input_arr, idx = input_tup # i is the iteration index
def f(input_tup: Tuple[Array, int], _dynamic_bl: PyTree) -> Tuple[Tuple[Array, int], None]:
x, idx = input_tup # idx is the iteration index

block = eqx.combine(_dynamic_bl, static_part) # reconstruct the block
output = block(x, x, pad_mask, enable_dropout, key).astype(jnp.bfloat16) # self-attention

return (output, idx + 1), None
x = jax.lax.cond(idx == 0,
lambda: block(x, input_arr, pad_mask, enable_dropout, keys[idx]),
lambda: block(x, x, pad_mask, enable_dropout, keys[idx]))

return (x, idx + 1), x

out = eqx.internal.scan(f=f, init=(input_arr, 0), xs=dynamic_part, kind='lax')[0][0] # throw away idx
out, history = eqx.internal.scan(f=f, init=(x, 0), xs=dynamic_part, kind='lax')
history = history.mean(0)

input_arr *= jax.nn.sigmoid(self.forget_gate(history, True, key))
input_arr += self.ctx_gate(history, True, key)

return out
return out[0], input_arr

class React(eqx.Module):
'''
@@ -63,7 +79,6 @@ class React(eqx.Module):

max_iters: int = eqx.field(static=True)

iters_weights: Array
pos_enc: Array
embed_layer: eqx.nn.Embedding
main_block: LiteAttention
@@ -110,27 +125,31 @@ def positional_encoding(self, seq_len, d_model):
@eqx.filter_jit
def iterate_for_steps(self,
interim_thought: Array,
input_arr: Array,
mask: Array,
iters_to_do: int,
input_arr: Array,
enable_dropout: bool,
key: PRNGKeyArray) -> Array:

# These are constants
# Declaring constants
input_arr = input_arr.astype(jnp.bfloat16)
interim_thought = interim_thought.astype(jnp.bfloat16)
mask = mask.astype(jnp.bfloat16)

def body_fun(thought: Array, _) -> Tuple[PyTree, Array]:
latent = jnp.concatenate([thought, input_arr], axis=-1).astype(jnp.bfloat16)
latent = self.main_block(latent, input_arr, mask, enable_dropout, key).astype(jnp.bfloat16)
latent = jax.vmap(self.post_ln)(latent).astype(jnp.bfloat16) # LN to keep scales tidy
def body_fun(carry: Tuple[Array, Array], idx: int) -> Tuple[Array, Array]:
thought, ctx_state = carry

latent = jnp.concatenate([input_arr, thought], axis=-1) # (seqlen, width * 2)
latent, ctx_state = self.main_block(latent, ctx_state, mask, enable_dropout, key) # (seqlen, width)
latent = jax.vmap(self.post_ln)(latent) # Post-LN for stability

return (latent, ctx_state), ctx_state

return latent, latent
final_val, history = eqx.internal.scan(
f=body_fun, init=(interim_thought, input_arr), xs=jnp.arange(5), kind="checkpointed"
)

final_val, _ = eqx.internal.scan(f=body_fun, init=interim_thought, xs=None, length=5, kind='checkpointed')
#return jnp.einsum('i j k, i -> j k', history, self.iters_weights) # dot-product with iters_weights
return final_val
return final_val[0]

@eqx.filter_jit
def __call__(self,
@@ -149,6 +168,6 @@ def __call__(self,
input_arr = jax.vmap(self.embed_layer)(input_arr) + self.pos_enc # (batch, seqlen, bottleneck)
interim_thought = input_arr.copy() # has to be a copy of the embedded + projected input array

output = self.iterate_for_steps(interim_thought, pad_mask, iters_to_do, input_arr, is_training, key) # (batch, seqlen, bottleneck)
output = self.iterate_for_steps(interim_thought, input_arr, pad_mask, iters_to_do, is_training, key) # (batch, seqlen, bottleneck)

return self.out_head(output), output
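
To make the new carry structure easier to follow, here is a minimal, self-contained sketch of the gating pattern this diff introduces: the mean of the per-block history updates a persistent context via a sigmoid forget gate plus an additive context gate, and the (thought, context) pair is threaded through the outer scan. The shapes, the stand-in "blocks", and the plain `jax.lax.scan` (instead of `eqx.internal.scan` with checkpointing) are simplifications, not the repository's exact implementation.

```python
import jax
import jax.numpy as jnp

def gated_context_update(history, ctx, forget_w, ctx_w):
    # Pseudo-LSTM update per outer iteration: damp the old context with a
    # sigmoid forget gate, then add a transformed summary of the block history.
    summary = history.mean(0)                        # (seqlen, width)
    ctx = ctx * jax.nn.sigmoid(summary @ forget_w)   # forget gate
    ctx = ctx + summary @ ctx_w                      # context gate
    return ctx

def iterate(thought, ctx, forget_w, ctx_w, n_iters: int = 5):
    def body(carry, _):
        thought, ctx = carry
        # Stand-in for the recurrent module: each "block output" is a cheap
        # function of the current thought; the real model runs attention blocks.
        history = jnp.stack([jnp.tanh(thought), jnp.tanh(thought * 0.5)])
        ctx = gated_context_update(history, ctx, forget_w, ctx_w)
        thought = jnp.tanh(thought + ctx)            # next interim thought
        return (thought, ctx), ctx
    (thought, ctx), _ = jax.lax.scan(body, (thought, ctx), jnp.arange(n_iters))
    return thought

seqlen, width = 4, 8
thought = jnp.zeros((seqlen, width))
ctx = jax.random.normal(jax.random.PRNGKey(0), (seqlen, width))
out = iterate(thought, ctx, jnp.eye(width) * 0.1, jnp.eye(width) * 0.1)
```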
67 changes: 53 additions & 14 deletions ReAct/utils/helpers.py
@@ -1,11 +1,61 @@
import math
import os
from typing import Callable, Optional, Tuple

import equinox as eqx
import jax
import jax.numpy as jnp

from jax import tree_util as jtu
from jaxtyping import Array, PRNGKeyArray
from typing import Optional

def convert_flops(params: int) -> str:
if params == 0:
return "0"

size_name = ("", "KFLOPs", "MFLOPs", "GFLOPs", "TFLOPs", "PFLOPs", "EFLOPs", "ZFLOPs", "YFLOPs")
i = int(math.floor(math.log(params, 1000)))
p = math.pow(1000, i)
s = round(params / p, 2)

return "%s %s" % (s, size_name[i])

def calc_performance_metrics(args, my_logger: Callable) -> None:
'''
Estimates FLOPs consumed during a single fwd + bwd pass.
Taken from EleutherAI's GPT-NeoX repo: https://rb.gy/33d6zg
Returns: the total number of FLOPs
'''
iter_factor = 3
args.tokens = args.batch_size * args.seqlen
args.kv_size_ratio = 1

my_logger.warning('! Ignoring activation checkpointing in FLOPs calculation !')

qkv_flops = int(iter_factor * 2 * (1 + 2 * args.kv_size_ratio) * args.num_blocks * args.tokens * args.width * args.width)
attention_matrix_flops = iter_factor * 2 * args.num_blocks * args.tokens * args.seqlen * args.width
attention_over_values_flops = iter_factor * 2 * args.num_blocks * args.tokens * args.seqlen * args.width
linear_projection_flops = iter_factor * 2 * args.num_blocks * args.tokens * args.width * args.width
ffn_flops = iter_factor * 16 * args.num_blocks * args.tokens * args.width * args.width

# handle NewGELU
ffn_flops *= 3.75

embedding_flops = 6 * args.tokens * args.width * args.num_classes
total_flops = qkv_flops + attention_matrix_flops + attention_over_values_flops + linear_projection_flops + ffn_flops + embedding_flops
my_logger.info(f"Total FLOPs for the Model: {convert_flops(total_flops)} for a single fwd + bwd pass\n")


def xla_calc_flops(fn: Callable, static_argnums: Tuple[int], args: Tuple, my_logger: Callable) -> None:
'''
Estimates FLOPs consumed during `fn` execution.
Uses XLA's HLO cost analysis to estimate FLOPs.
Returns: the total number of FLOPs
'''
compiled = jax.jit(fn, static_argnums=static_argnums).lower(*args).compile()
flops = compiled.cost_analysis()[0]['flops']
my_logger.info(f"XLA estimate of Total FLOPs for {fn.__name__}: {convert_flops(int(flops))}\n")

def half_precision(model: eqx.Module) -> eqx.Module:
return jtu.tree_map(lambda x: x.astype(jnp.bfloat16) if eqx.is_inexact_array(x) else x, model)
@@ -59,15 +109,4 @@ def inverted_freq(arr: Array):

inv_weights = (counts.max() / counts) # scale it down

return inv_weights[arr - arr.min()]

if __name__ == '__main__':
import plotly.express as px
import pandas as pd

key = jax.random.PRNGKey(0)
out: Array = get_rand_nums(key, 1, 10, 512, 4)
elems, counts = jnp.unique(out, return_counts=True)
df = pd.DataFrame({'elems': elems, 'counts': counts})
fig = px.bar(df, x='elems', y='counts')
fig.show()
return inv_weights[arr - arr.min()]
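
As a sanity check for the XLA-based estimate above, a standalone sketch of the same `cost_analysis` pattern on a toy function. The `mlp` and its shapes are made up for illustration; note that `cost_analysis()` has returned either a list of per-computation dicts (as `xla_calc_flops` assumes) or a single dict depending on the JAX version, so the sketch handles both.

```python
import jax
import jax.numpy as jnp

def mlp(x, w1, w2):
    return jax.nn.gelu(x @ w1) @ w2

x  = jnp.ones((32, 128))
w1 = jnp.ones((128, 512))
w2 = jnp.ones((512, 128))

# Ahead-of-time lower + compile, then ask XLA for its FLOPs estimate.
compiled = jax.jit(mlp).lower(x, w1, w2).compile()
cost = compiled.cost_analysis()
cost = cost[0] if isinstance(cost, (list, tuple)) else cost
print(f"XLA-estimated FLOPs: {cost['flops']:.3e}")
```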
6 changes: 3 additions & 3 deletions ReAct/utils/logger.py
@@ -40,16 +40,16 @@ def wandb_logger(self, args: dict):
if args.resume:
# args.resume is of the form: "neel/ReAct_Jax/lxxn0x54 + 20"
# we want to extract the run id, i.e "lxxn0x54"
id = args.resume.split("+")[0].split("/")[-1].strip()
wandb_id = args.resume.split("+")[0].split("/")[-1].strip()
else:
id = None
wandb_id = None

wandb.init(project='ReAct_Jax',
config=args,
group=args.group,
mode='online' if jax.process_index() == 0 and args.exp_logging else 'offline',
resume='allow',
id=id,
id=wandb_id,
reinit=True)

wandb.run.log_code(