Skip to content

Commit 442513c

Browse files
[JAX] Add tutorial for integrating TE/JAX quantization into an existing framework (#2423)
* Tutorial for integration te/jax quantization into an existing framework Signed-off-by: Jeremy Berchtold <[email protected]> * add todos Signed-off-by: Jeremy Berchtold <[email protected]> * support nvfp4 sr rng key, move wrapper module into TE itself, fix bfloat16 cast Signed-off-by: Jeremy Berchtold <[email protected]> * update docstrings Signed-off-by: Jeremy Berchtold <[email protected]> * Fix QKV proj and out proj in Flax example transformer layer Signed-off-by: Jeremy Berchtold <[email protected]> * Use fused attention in quickstart_jax example Signed-off-by: Jeremy Berchtold <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remat policy Signed-off-by: Jeremy Berchtold <[email protected]> * add tutorial to docs Signed-off-by: Jeremy Berchtold <[email protected]> * update title Signed-off-by: Jeremy Berchtold <[email protected]> * remove unused dtype from TE DPA module Signed-off-by: Jeremy Berchtold <[email protected]> * Fix notebook title Signed-off-by: Jeremy Berchtold <[email protected]> * Fix lint Signed-off-by: Jeremy Berchtold <[email protected]> * Add explanation of flax module wrapper Signed-off-by: Jeremy Berchtold <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jeremy Berchtold <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5c2f2ff commit 442513c

File tree

9 files changed

+688
-109
lines changed

9 files changed

+688
-109
lines changed

docs/examples/quickstart_jax.ipynb

Lines changed: 76 additions & 85 deletions
Large diffs are not rendered by default.

docs/examples/quickstart_jax_utils.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ def speedometer(
1919
variables: Any,
2020
input: jnp.ndarray,
2121
output_grad: jnp.ndarray,
22-
dropout_key: jax.random.PRNGKey,
2322
model_init_fn: Callable = None,
2423
forward_kwargs: dict = {},
2524
autocast_kwargs: Optional[dict] = None,
2625
timing_iters: int = 50,
2726
warmup_iters: int = 50,
27+
rngs: Dict[str, jax.random.PRNGKey] = None,
2828
) -> None:
2929
"""Measure average runtime for a JAX module
3030
Perform forward and backward passes .
@@ -33,19 +33,21 @@ def speedometer(
3333
autocast_kwargs = {"enabled": False}
3434
model_init_fn = None
3535

36+
if rngs is None:
37+
rngs = {}
38+
3639
train_step_fn = create_train_step_fn(model_apply_fn, autocast_kwargs, forward_kwargs)
3740

3841
# Warm up runs
39-
key = dropout_key
4042
for _ in range(warmup_iters):
41-
key, step_key = jax.random.split(key)
42-
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)
43+
rngs, step_rngs = _split_step_rngs(rngs)
44+
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_rngs)
4345

4446
# Timing runs
4547
start = time.time()
4648
for _ in range(timing_iters):
47-
key, step_key = jax.random.split(key)
48-
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)
49+
rngs, step_rngs = _split_step_rngs(rngs)
50+
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_rngs)
4951
end = time.time()
5052

5153
print(f"Mean time: {(end - start) * 1000 / timing_iters} ms")
@@ -63,8 +65,12 @@ def create_train_step_fn(
6365
if forward_kwargs is None:
6466
forward_kwargs = {}
6567

66-
def loss_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
67-
rngs = {"dropout": dropout_key}
68+
def loss_fn(
69+
variables: Any,
70+
inp: jnp.ndarray,
71+
grad_target: jnp.ndarray,
72+
rngs: Dict[str, jax.random.PRNGKey],
73+
):
6874
with te.autocast(**autocast_kwargs):
6975
# Forward Pass: Apply the model using current parameters and variables
7076
call_kwargs = {**forward_kwargs, "rngs": rngs}
@@ -84,3 +90,16 @@ def fwd_bwd_fn(*args, **kwargs):
8490

8591
# JIT-compile the fwd_bwd_fn
8692
return jax.jit(fwd_bwd_fn)
93+
94+
95+
def _split_step_rngs(
96+
rngs: Dict[str, jax.random.PRNGKey],
97+
) -> Tuple[Dict[str, jax.random.PRNGKey], Dict[str, jax.random.PRNGKey]]:
98+
"""Splits each RNG in the rngs dictionary for a new step."""
99+
step_rngs = {}
100+
new_rngs = {}
101+
for name, key in rngs.items():
102+
new_key, step_key = jax.random.split(key)
103+
new_rngs[name] = new_key
104+
step_rngs[name] = step_key
105+
return new_rngs, step_rngs

0 commit comments

Comments
 (0)