Added official JMP support for real mixed-precision
neel04 committed May 9, 2024
1 parent dc39d92 commit a6d53e8
Showing 5 changed files with 67 additions and 58 deletions.
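In short, the commit replaces the scattered .astype(jnp.bfloat16) calls with a single jmp.Policy that owns the three mixed-precision dtypes: parameters stay in float32 while compute and outputs run in bfloat16. The sketch below illustrates that pattern only; it is not code from this repository, and forward plus the parameter names are hypothetical:

import jax.numpy as jnp
from jmp import Policy

# One policy object encodes the dtypes: fp32 params, bf16 compute and outputs.
policy = Policy(compute_dtype=jnp.bfloat16,
                param_dtype=jnp.float32,
                output_dtype=jnp.bfloat16)

def forward(params, x):
    # Downcast inputs and parameters to the compute dtype on entry...
    x, p = policy.cast_to_compute((x, params))
    y = x @ p['w'] + p['b']
    # ...and cast the result to the output dtype on exit.
    return policy.cast_to_output(y)

params = policy.cast_to_param({'w': jnp.zeros((8, 4)), 'b': jnp.zeros(4)})  # stored in float32
out = forward(params, jnp.ones((2, 8)))                                     # out.dtype == bfloat16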
2 changes: 1 addition & 1 deletion Dockerfile
@@ -18,9 +18,9 @@ RUN pip3 install Ipython matplotlib
RUN pip3 install numpy pandas scipy

RUN pip3 install -U -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN pip3 install -U jaxlib
RUN pip3 install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich
RUN pip3 install tensorboard-plugin-profile comet-ml optuna-integration plotly
RUN pip3 install git+https://github.com/deepmind/jmp

WORKDIR /ReAct_Jax

6 changes: 2 additions & 4 deletions ReAct/model/baseline.py
@@ -87,8 +87,8 @@ def positional_encoding(self, seq_len: int, d_model: int):
of shape (batch_size, max_seq_len, d_model) which would be added
to the sequence embeddings.
'''
position = jnp.arange(seq_len, dtype=jnp.bfloat16).reshape(-1, 1)
div_term = jnp.exp(jnp.arange(0, d_model, 2, dtype=jnp.bfloat16) * -(jnp.log(10000.0) / d_model))
position = jnp.arange(seq_len).reshape(-1, 1)
div_term = jnp.exp(jnp.arange(0, d_model, 2) * -(jnp.log(10000.0) / d_model))
pe = jnp.zeros((seq_len, d_model))

pe = pe.at[:, 0::2].set(jnp.sin(position * div_term))
@@ -104,8 +104,6 @@ def __call__(self,
key: PRNGKeyArray) -> Array:

input_arr = jax.vmap(self.embed_layer)(input_arr) + self.pos_enc
input_arr = input_arr.astype(jnp.bfloat16)

output = self.main_block(input_arr, pad_mask, enable_dropout, key)

return self.out_head(output)
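With the blocks now downcasting their own inputs through the shared policy, baseline.py can build the positional encodings and embeddings in full precision and drop its explicit bfloat16 casts. A minimal sketch of the resulting dtype flow (illustrative shapes, not repository code):

import jax.numpy as jnp
from jmp import Policy

policy = Policy(compute_dtype=jnp.bfloat16, param_dtype=jnp.float32, output_dtype=jnp.bfloat16)

emb = jnp.zeros((512, 128))          # embeddings + positional encoding stay in float32 here
x = policy.cast_to_compute(emb)      # the attention block performs this downcast on entry
print(emb.dtype, x.dtype)            # float32 bfloat16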
86 changes: 46 additions & 40 deletions ReAct/model/blocks.py
@@ -5,6 +5,9 @@
import jax
import jax.numpy as jnp
from jaxtyping import Array, BFloat16, PRNGKeyArray
from jmp import Policy

policy = Policy(compute_dtype=jnp.bfloat16, param_dtype=jnp.float32, output_dtype=jnp.bfloat16)

# ruff: noqa: F722
class AttentionBlock(eqx.Module):
@@ -47,21 +50,11 @@ def _make_self_attention_mask(self,

"""Create self-attention mask from sequence-level mask."""

mask = jnp.ones((self.seqlen, self.seqlen), dtype=jnp.bfloat16)
mask = jnp.ones((self.seqlen, self.seqlen))
mask = jnp.tril(mask)
mask = jnp.expand_dims(mask, 0)
return jnp.repeat(mask, self.n_heads, axis=0)

def _make_mixer_mask(self,
pad_mask: Array):

# Almost same, but we triu instead of tril
# and we don't need to merge with pad_mask
mask = jnp.ones((self.seqlen, self.seqlen)) * pad_mask
mask = jnp.triu(mask)

return mask

def __call__(self,
inp: BFloat16[Array, 'seqlen bottleneck'],
input_arr: Array,
@@ -70,7 +63,7 @@ def __call__(self,
key: PRNGKeyArray):

key_1, key_2 = jax.random.split(key, 2)
inp = inp.astype(jnp.bfloat16)
inp, input_arr, mask = policy.cast_to_compute((inp, input_arr, mask))

x = jax.vmap(self.ln1)(inp)
inp += self.attn_gate(x, input_arr, input_arr,
@@ -80,19 +73,26 @@ def __call__(self,
x = jax.vmap(self.ln2)(inp)
inp += self.mlp(x, enable_dropout=True, key=key_2)

return inp.astype(jnp.bfloat16)
return policy.cast_to_output(inp)


class NewGELU(eqx.Module):
def __call__(self, x: jax.Array, *args) -> jax.Array:
c: float = math.sqrt(2.0 / math.pi)
a: float = 0.044715
return 0.5 * x * (1.0 + jax.nn.tanh(c * (x + a * jnp.power(x, 3.0))))

x = policy.cast_to_compute(x)
output = 0.5 * x * (1.0 + jax.nn.tanh(c * (x + a * jnp.power(x, 3.0))))

return policy.cast_to_output(output)

class MLP(eqx.Module):
'''A simple MLP - w/ Dropout'''
layers: eqx.nn.Sequential

layer_1: eqx.Module
layer_2: eqx.Module
dropout: eqx.nn.Dropout
act: callable

def __init__(self,
input_dim: int,
@@ -102,11 +102,9 @@

key1, key2 = jax.random.split(key, 2)

self.layers = [
LinearProj(input_dim, output_dim * 4, key=key1),
eqx.nn.Lambda(NewGELU()),
LinearProj(output_dim * 4, output_dim, key=key2),
]
self.layer_1 = LinearProj(input_dim, output_dim * 4, key=key1)
self.layer_2 = LinearProj(output_dim * 4, output_dim, key=key2)
self.act = NewGELU()

self.dropout = eqx.nn.Dropout(p=p)

@@ -115,10 +113,14 @@ def __call__(self,
enable_dropout: bool,
key: PRNGKeyArray):

for layer in self.layers:
x = layer(x).astype(jnp.bfloat16)
x = policy.cast_to_compute(x)

return self.dropout(x, key=key, inference=enable_dropout).astype(jnp.bfloat16)
x = self.act(self.layer_1(x))
x = self.layer_2(x)

output = self.dropout(x, key=key, inference=enable_dropout)

return policy.cast_to_output(self.act(output))

class GatedMLP(eqx.Module):
'''
@@ -149,13 +151,15 @@ def __init__(self,
self.activation = NewGELU()

def __call__(self, arr: Array) -> Array:

x = policy.cast_to_compute(arr)

x = self.activation(self.up_proj(arr))
x = jax.vmap(self.ln_1)(x)
x = self.down_proj(x * jax.nn.silu(self.gate(arr)))
x = jax.vmap(self.ln_2)(x)
x = self.activation(jax.vmap(self.ln_2)(x))

return self.activation(x)
return policy.cast_to_output(x)

class LinearProj(eqx.Module):
bias: Optional[jax.Array]
@@ -179,22 +183,24 @@ def __init__(self,
self.use_bias = use_bias

lim = 1 / math.sqrt(input_dim)
self.weight = jax.random.uniform(wkey, (input_dim, output_dim), minval=-lim, maxval=lim).astype(jnp.bfloat16)
self.weight = jax.random.uniform(wkey, (input_dim, output_dim), minval=-lim, maxval=lim)

if use_bias:
self.bias = jax.random.uniform(bkey, (output_dim,), minval=-lim, maxval=lim).astype(jnp.bfloat16)
self.bias = jax.random.uniform(bkey, (output_dim,), minval=-lim, maxval=lim)
else:
self.bias = jnp.zeros((output_dim,)).astype(jnp.bfloat16)
self.bias = jnp.zeros((output_dim,))

def __call__(self,
input: BFloat16[Array, 'batch in_dim'],
arr: BFloat16[Array, 'batch in_dim'],
mask: Optional[Array] = None,
**kwargs) -> Array:

mask = kwargs.get('mask', None)
arr, mask = policy.cast_to_compute((arr, mask))

mask = jnp.ones_like(self.weight) if mask is None else mask
output = input @ (self.weight * mask.astype(input.dtype)) + self.bias
output = arr @ (self.weight * mask.astype(arr.dtype)) + self.bias

return output
return policy.cast_to_output(output)

class LiteAttention(eqx.Module):
input_dim: int = eqx.field(static=True)
@@ -206,8 +212,10 @@ def __init__(self, input_dim: int, key: PRNGKeyArray):

@jax.jit
def __call__(self, x: BFloat16[Array, 'seqlen in_dim'], mask: Array):
mask, x = policy.cast_to_compute((mask, x))
attn_weights = jax.nn.softmax(self.weight(x.T, mask), axis=1) # type: ignore
return x * attn_weights.T
output = x * attn_weights.T
return policy.cast_to_output(output)

class MixerBlock(eqx.Module):
'''
@@ -231,14 +239,12 @@ def __init__(self, input_dim: int, seqlen: int, drop_rate: float, key: PRNGKeyAr
self.token_mixer = LinearProj(seqlen, seqlen, key=key2)

def __call__(self, x: BFloat16[Array, 'seqlen in_dim'], mask: Array, key: PRNGKeyArray):
x, mask = policy.cast_to_compute((x, mask))

arr = x.T
arr = self.act_1(self.token_mixer(arr, key=key, mask=mask))
arr = jax.vmap(self.norm)(arr.T)
x = x + arr
return x + self.act_2(self.channel_mixer(arr, key))

if __name__ == '__main__':
key = jax.random.PRNGKey(0)
LA = LiteAttention(256, key)
test = jax.random.normal(key, (128, 256))
print(LA(test).shape)
output = x + self.act_2(self.channel_mixer(arr, key))

return policy.cast_to_output(output)
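One property of jmp that keeps these blocks simple: the cast helpers map over whole pytrees and cast only floating-point leaves, so mixed tuples of activations and masks can be passed through a single cast_to_compute call, as the blocks above do. A small sketch of that behaviour (not repository code):

import jax.numpy as jnp
from jmp import Policy

policy = Policy(compute_dtype=jnp.bfloat16, param_dtype=jnp.float32, output_dtype=jnp.bfloat16)

x = jnp.ones((4, 8))                        # float32 activations
mask = jnp.tril(jnp.ones((4, 4), bool))     # boolean causal mask
x, mask = policy.cast_to_compute((x, mask))
print(x.dtype, mask.dtype)                  # bfloat16 bool -- only the floating leaf was cast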
23 changes: 12 additions & 11 deletions ReAct/model/react.py
@@ -5,11 +5,13 @@
import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray, PyTree
from jmp import Policy

from .blocks import MLP, AttentionBlock, LinearProj, LiteAttention, NewGELU, GatedMLP
from .blocks import MLP, AttentionBlock, GatedMLP, LinearProj, LiteAttention, NewGELU

# ruff: noqa: E402, E731
policy = Policy(compute_dtype=jnp.bfloat16, param_dtype=jnp.float32, output_dtype=jnp.bfloat16)

# ruff: noqa: E402, E731
class RecurrentModule(eqx.Module):
'''
Bunch of Attention layers in a pseudo-LSTM fashion
@@ -77,7 +79,7 @@ def f(input_tup: Tuple[Array, int], _dynamic_bl: PyTree) -> Tuple[Tuple[Array, i

#ctx_state *= jax.nn.sigmoid(self.forget_gate(hist_lerp, enable_dropout, key))
#ctx_state += self.ctx_gate(hist_lerp, enable_dropout, key)
ctx_state += self.ctx_gate(hist_lerp)
ctx_state = self.ctx_gate(hist_lerp)

return out[0], ctx_state

@@ -113,7 +115,7 @@ def __init__(self,
self.pos_enc = jax.lax.stop_gradient(self.positional_encoding(seqlen, width))

self.main_block = RecurrentModule(seqlen, drop_rate, n_heads, num_blocks, width, key=key2)
self.alpha = jnp.array([0.5], dtype=jnp.bfloat16)
self.alpha = jnp.array([0.5])

self.post_ln = eqx.nn.LayerNorm(width)
self.out_head = LinearProj(width, vocab_size, key=key4)
@@ -124,8 +126,8 @@ def positional_encoding(self, seq_len, d_model):
of shape (batch_size, max_seq_len, d_model) which would be added
to the sequence embeddings.
'''
position = jnp.arange(seq_len, dtype=jnp.bfloat16).reshape(-1, 1)
div_term = jnp.exp(jnp.arange(0, d_model, 2, dtype=jnp.bfloat16) * -(jnp.log(10000.0) / d_model))
position = jnp.arange(seq_len).reshape(-1, 1)
div_term = jnp.exp(jnp.arange(0, d_model, 2) * -(jnp.log(10000.0) / d_model))
pe = jnp.zeros((seq_len, d_model))

pe = pe.at[:, 0::2].set(jnp.sin(position * div_term))
@@ -142,17 +144,14 @@ def iterate_for_steps(self,
enable_dropout: bool,
key: PRNGKeyArray) -> Array:

# Declaring constants
input_arr = input_arr.astype(jnp.bfloat16)
interim_thought = interim_thought.astype(jnp.bfloat16)
mask = mask.astype(jnp.bfloat16)

def body_fun(carry: Tuple[Array, Array], idx: int) -> Tuple[Tuple, Array]:
thought, ctx_state = carry

latent = jnp.concatenate([input_arr, thought], axis=-1) # (seqlen, width * 2)
latent, ctx_state = self.main_block(latent, ctx_state, mask, enable_dropout, key) # (seqlen, width)
latent = jax.vmap(self.post_ln)(latent) # Post-LN for stability

latent = policy.cast_to_output(latent) # mixed precision

return (latent, ctx_state), latent

@@ -179,6 +178,8 @@ def __call__(self,
input_arr = jax.vmap(self.embed_layer)(input_arr) + self.pos_enc # (batch, seqlen, bottleneck)
interim_thought = input_arr.copy() # has to be a copy of the embedded + projected input array

input_arr, interim_thought = policy.cast_to_compute((input_arr, interim_thought))

output = self.iterate_for_steps(interim_thought, input_arr, pad_mask, iters_to_do, is_training, key) # (batch, seqlen, bottleneck)

return self.out_head(output), output
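react.py and trainer.py each construct the same Policy that blocks.py defines at module level. If those three copies ever need to stay in sync, jmp also accepts a string spec, which makes a single shared definition straightforward; a sketch of that alternative (a suggestion, not something this commit does):

import jmp

# Equivalent to Policy(compute_dtype=jnp.bfloat16, param_dtype=jnp.float32, output_dtype=jnp.bfloat16)
policy = jmp.get_policy('params=float32,compute=bfloat16,output=bfloat16')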
8 changes: 6 additions & 2 deletions ReAct/utils/trainer.py
@@ -8,6 +8,7 @@
import optax
import optuna
from jaxtyping import Array, PRNGKeyArray, PyTree
from jmp import Policy
from scalax.sharding import MeshShardingHelper
from scalax.sharding import PartitionSpec as P
from torch.utils.data import DataLoader
@@ -18,9 +19,10 @@
from ReAct.model.react import React
from ReAct.utils.helpers import count_params, load_eqx_obj, save_eqx_obj

from .helpers import broad_to_bsz, calc_performance_metrics, half_precision
from .helpers import broad_to_bsz, calc_performance_metrics

mesh = MeshShardingHelper(axis_dims=[-1], axis_names=['data']) # handle DDP + TP over multi-node
policy = Policy(compute_dtype=jnp.bfloat16, param_dtype=jnp.float32, output_dtype=jnp.bfloat16)

@eqx.filter_jit
def n_k_loop(model: eqx.Module, input_arr: Array, pad_mask: Array, n: Array, k: Array, key: PRNGKeyArray) -> Array:
@@ -117,6 +119,7 @@ def compute_loss(model: eqx.Module, static_model: PyTree, x: Array, y: Array, pa
is_leaf=lambda x: isinstance(x, eqx.nn.Dropout))

loss, grads = compute_loss(diff_model, static_model, x, y, pad_mask, iters_to_do, num_classes, keys)
grads = policy.cast_to_compute(grads) # cast grads to the compute dtype (bfloat16)
updates, opt_state = optim.update(grads, opt_state, model)
model = eqx.apply_updates(model, updates)

@@ -236,7 +239,7 @@ def init_model(self, key: PRNGKeyArray):

# apply the mixed-precision policy (params stay in float32)
if self.bf16:
model = half_precision(model)
model = policy.cast_to_param(model)

_, opt_state, model = self.set_optim_and_scheduler(model)

@@ -326,6 +329,7 @@ def train(self, trial: Optional[Any] = None) -> Tuple[float, int]:
step += step_done # for multiple epochs

seq, label, pad_mask = jnp.asarray(batch['text'])
seq, label, pad_mask = policy.cast_to_compute((seq, label, pad_mask))

loss, model, opt_state = make_step(model, opt_state, filter_spec, seq, label, pad_mask,
self.max_iters, optim, self.num_classes, keys)
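One subtlety in the training step above: policy.cast_to_compute(grads) leaves the gradients in bfloat16, while the parameters being updated are kept in float32 by cast_to_param at init time. A common mixed-precision pattern is the opposite cast, upcasting gradients to the parameter dtype just before the optimizer update; a minimal sketch of that variant (an alternative, not what this commit does):

import jax.numpy as jnp
from jmp import Policy

policy = Policy(compute_dtype=jnp.bfloat16, param_dtype=jnp.float32, output_dtype=jnp.bfloat16)

grads = {'w': jnp.ones((3, 3), jnp.bfloat16)}   # gradients produced in the compute dtype
grads = policy.cast_to_param(grads)             # upcast to float32 before optim.update
print(grads['w'].dtype)                         # float32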
