diff --git a/bonsai/models/qwen3/modeling.py b/bonsai/models/qwen3/modeling.py
index 020262c0..52e7d078 100644
--- a/bonsai/models/qwen3/modeling.py
+++ b/bonsai/models/qwen3/modeling.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import dataclasses
+import math
 from functools import partial
 from typing import Tuple, TypeAlias
 
@@ -209,7 +210,10 @@ def __init__(self, cfg: ModelConfig, batch_size: int, cache_size: int, dtype: jn
 
 Cache: TypeAlias = list[LayerCache]
 
 
-def init_cache(cfg: ModelConfig, batch_size: int, cache_size: int, dtype: jnp.dtype = jnp.bfloat16) -> Cache:
+def init_cache(
+    cfg: ModelConfig, batch_size: int, token_len: int, generate_steps: int, dtype: jnp.dtype = jnp.bfloat16
+) -> Cache:
+    cache_size = 2 ** math.ceil(math.log2(max(token_len + generate_steps, 1)))  # Pad for a sharding-friendly size.
     return [LayerCache(cfg, batch_size, cache_size, dtype) for _ in range(cfg.num_layers)]
 
@@ -308,7 +312,7 @@ def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         self.scale = self.head_dim**-0.5
 
     @jax.named_scope("attention")
-    def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, num_right_pads: int) -> Array:
+    def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array) -> Array:
         query_proj = shard(self.q_norm(self.q_proj(x)), self.shd_cfg.act_btnh)
         key_proj = shard(self.k_norm(self.k_proj(x)), self.shd_cfg.act_btnh)
         value_proj = shard(self.v_proj(x), self.shd_cfg.act_btnh)
@@ -341,7 +345,7 @@ def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, num_r
         attn_weights = jax.nn.softmax(attn_logits.astype(jnp.float32), axis=-1).astype(attn_logits.dtype)
         qkv = jnp.einsum("BHGTS,BSHD->BTHGD", attn_weights, cache.v_cache.value).reshape((b, t, qh, d))
 
-        cache.cur_ind.value = cache.cur_ind.value + t - num_right_pads
+        cache.cur_ind.value = cache.cur_ind.value + t
         return shard(self.o_proj(qkv), self.shd_cfg.act_btd)
 
     @property
@@ -380,9 +384,9 @@ def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         self.post_attention_layernorm = RMSNorm(cfg.emb_dim, cfg, rngs=rngs)
         self.mlp = MLP(cfg=cfg, rngs=rngs)
 
-    def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, num_right_pads: int) -> Array:
+    def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array) -> Array:
         inputs_normalized = self.input_layernorm(x)
-        attn_output = x + self.attn(inputs_normalized, cache, segment_ids, num_right_pads)
+        attn_output = x + self.attn(inputs_normalized, cache, segment_ids)
         outputs = attn_output + self.mlp(self.post_attention_layernorm(attn_output))
         return outputs
 
@@ -396,10 +400,10 @@ def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
             einsum_str="BTD,DV->BTV", shape=(cfg.emb_dim, cfg.vocab_size), shd=cfg.shd_cfg.emb_dv, rngs=rngs
         )
 
-    def __call__(self, tokens, segment_ids, num_right_pads, cache):
+    def __call__(self, tokens, segment_ids, cache):
         x = self.embedder.encode(tokens)
         for i, layer in enumerate(self.layers):
-            x = layer(x, cache[i], segment_ids, num_right_pads)
+            x = layer(x, cache[i], segment_ids)
         logits = self.lm_head(self.final_norm(x))
         return logits
 
@@ -410,7 +414,6 @@ def forward(
 ) -> tuple[Array, nnx.State]:
     model, cache = nnx.merge(graphdef, state)
     segment_ids = 1 * (tokens != pad_id)
-    num_right_pads = count_right_pads(segment_ids, out_sharding=P(None))
-    logits = model(tokens, segment_ids, num_right_pads, cache)
+    logits = model(tokens, segment_ids, cache)
     state = jax.tree.leaves(nnx.state((model, cache)))
-    return logits[:, -num_right_pads - 1], state
+    return logits[:, -1], state
diff --git a/bonsai/models/qwen3/tests/run_model.py b/bonsai/models/qwen3/tests/run_model.py
index badecdcb..160eff85 100644
--- a/bonsai/models/qwen3/tests/run_model.py
+++ b/bonsai/models/qwen3/tests/run_model.py
@@ -25,7 +25,7 @@
 from transformers import AutoTokenizer
 
 from bonsai.models.qwen3 import modeling, params
-from bonsai.utils import Sampler
+from bonsai.utils import GreedySampler, Sampler
 
 
 def tokenize(tokenizer, input: list[str]):
@@ -37,9 +37,8 @@ def tokenize(tokenizer, input: list[str]):
         for l in input
     ]
     lines = [tokenizer.encode(line) for line in lines]
-    max_l = max(len(line) for line in lines)  # Right-align, left-padding to the max token length.
-    buffer_len = 2 ** math.ceil(math.log2(max(max_l, 1)))  # Pad the sequence to power-of-two buffer length.
-    return jnp.array([np.pad(l, (max_l - len(l), buffer_len - max_l), constant_values=pad_idx) for l in lines]), max_l
+    max_len = max(len(line) for line in lines)  # Right-align, left-padding to the max token length.
+    return jnp.array([np.pad(l, (max_len - len(l), 0), constant_values=pad_idx) for l in lines])
 
 
 def run_model():
@@ -50,17 +49,20 @@
     mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(AxisType.Explicit, AxisType.Explicit))
     jax.set_mesh(mesh)
 
-    query = ["Why is the sky blue instead of any other color like purple?", "Who am I?"]
+    query = [
+        "Why is the sky blue instead of any other color like purple?",
+        "Who am I?",
+        "Tell me 10 flavors of ice creams.",
+    ]
     tokenizer = AutoTokenizer.from_pretrained(model_ckpt_path)
-    tokens, max_len = tokenize(tokenizer, query)
-    batch_size, _ = tokens.shape
-    cache_size, gen_steps = 128, 11
-    assert cache_size >= max_len + gen_steps, f"Cache size ({cache_size}) must be >= 22,018 + {gen_steps}"
+    tokens = tokenize(tokenizer, query)
+    batch_size, token_len = tokens.shape
+    generate_steps = 256
+    cache = modeling.init_cache(config, batch_size, token_len, generate_steps)
 
     model = params.create_model_from_safe_tensors(model_ckpt_path, config, mesh)
-    cache = modeling.init_cache(config, batch_size, cache_size)
     graphdef, state = nnx.split((model, cache))
     state = jax.tree.leaves(state)  # Better perf from flattened jax state due to no pytree trasversals.
@@ -71,33 +73,26 @@ def run_model():
     logits, state = modeling.forward(graphdef, state, tokens, tokenizer.pad_token_id)
     next_tokens = sampler(logits, key=key)
 
-    # decode - warmup
+    # decode
     tokens_list = [next_tokens]
-    logits, state = modeling.forward(graphdef, state, next_tokens, tokenizer.pad_token_id)
-    next_tokens = sampler(logits, key=key)
-    tokens_list.append(next_tokens)
-
-    # profile
-    jax.profiler.start_trace("/tmp/profile-data")
-    for i in range(5):
-        logits, state = modeling.forward(graphdef, state, next_tokens, tokenizer.pad_token_id)
-        next_tokens = sampler(logits, key=key)
-        tokens_list.append(next_tokens)
-    jax.block_until_ready(tokens_list)
-    jax.profiler.stop_trace()
-
-    decode_steps = 128
-    for i in range(decode_steps):
+    finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
+    for i in range(generate_steps):
         logits, state = modeling.forward(graphdef, state, next_tokens, tokenizer.pad_token_id)
         next_tokens = sampler(logits, key=key)
+        finished = finished | (next_tokens.squeeze(-1) == tokenizer.eos_token_id)
         tokens_list.append(next_tokens)
+        if finished.all():
+            break
 
-    tokens_list = jnp.concatenate(tokens_list, axis=-1)
+    all_output_tokens = jax.device_get(jnp.concatenate(tokens_list, axis=-1))
     for i, q in enumerate(query):
         print(f"User:\n {q}")
-        print(
-            f"Answer:\n {tokenizer.decode(tokens_list.at[i].get(out_sharding=P(None)), skip_special_tokens=True)}\n\n"
-        )
+        seq_tokens = all_output_tokens[i]
+        eos_idx = np.where(seq_tokens == tokenizer.eos_token_id)[0]
+        if eos_idx.size > 0:
+            seq_tokens = seq_tokens[: eos_idx[0]]
+        decoded = tokenizer.decode(seq_tokens, skip_special_tokens=True)
+        print(f"Answer:\n {decoded}\n\n")
 
 
 if __name__ == "__main__":
diff --git a/bonsai/models/qwen3/tests/test_outputs_qwen3.py b/bonsai/models/qwen3/tests/test_outputs_qwen3.py
index 6e2121ae..9f85c371 100644
--- a/bonsai/models/qwen3/tests/test_outputs_qwen3.py
+++ b/bonsai/models/qwen3/tests/test_outputs_qwen3.py
@@ -86,9 +86,8 @@ def _nnx_forward_logits(self, cache: modeling.Cache, tokens: jax.Array, dtype: D
         """Forward pass for the nnx model"""
         segment_ids = 1 * (tokens != self.tokenizer.pad_token_id)
         x = self.nnx_model.embedder.encode(tokens).astype(dtype)
-        right_pads = modeling.count_right_pads(segment_ids, out_sharding=P(None))
         for i, layer in enumerate(self.nnx_model.layers):
-            x = layer(x, cache[i], segment_ids, right_pads).astype(dtype)
+            x = layer(x, cache[i], segment_ids).astype(dtype)
         nnx_logits = self.nnx_model.lm_head(self.nnx_model.final_norm(x))
         return nnx_logits
 
@@ -112,7 +111,7 @@ def _process_hf_tokens(self, query: list[str]):
 
     def _init_nnx_cache(self, batch_size: int):
         return modeling.init_cache(
-            cfg=self.bonsai_config, batch_size=batch_size, cache_size=self.cache_size, dtype=jnp.float32
+            cfg=self.bonsai_config, batch_size=batch_size, token_len=10, generate_steps=32, dtype=jnp.float32
         )
 
     def test_embedder(self):
@@ -135,7 +134,7 @@ def test_decoder_layer(self):
         nnx_cache = self._init_nnx_cache(self.batch_size)
         torch_inputs = self._setup_torch_attn(tx)
 
-        jy, ty = nm(jx, nnx_cache[0], jnp.ones((self.batch_size, self.num_input_tokens)), 0), tm(**torch_inputs)
+        jy, ty = nm(jx, nnx_cache[0], jnp.ones((self.batch_size, self.num_input_tokens))), tm(**torch_inputs)
         torch.testing.assert_close(torch.tensor(jy), ty)
 
     def test_all_decoder_layers(self):
@@ -146,7 +145,7 @@
 
         jx = jax.random.normal(jax.random.key(0), shape=shape)
         tx = torch.tensor(jx)
-        jy = nm(jx, nc, jnp.ones((self.batch_size, self.num_input_tokens)), 0)
+        jy = nm(jx, nc, jnp.ones((self.batch_size, self.num_input_tokens)))
         torch_inputs = self._setup_torch_attn(tx)
         ty = tm.to(torch.float32)(**torch_inputs)
         torch.testing.assert_close(torch.tensor(jy), ty, atol=self.relaxed_tol, rtol=self.relaxed_tol)
@@ -172,7 +171,7 @@ def test_self_attn(self):
         torch_inputs = self._setup_torch_attn(tx)
         nnx_cache = self._init_nnx_cache(self.batch_size)
 
-        jy = nm(jx, nnx_cache[0], jnp.ones((self.batch_size, self.num_input_tokens), dtype=jnp.float32), 0)
+        jy = nm(jx, nnx_cache[0], jnp.ones((self.batch_size, self.num_input_tokens), dtype=jnp.float32))
         ty = tm(**torch_inputs)[0]
         torch.testing.assert_close(torch.tensor(jy), ty)
 
@@ -267,7 +266,8 @@ def test_sin_cos(self):
 
     def test_full(self):
         query = ["Why is the sky blue instead of any other color like purple?"]
-        tokens, max_len = tokenize(self.tokenizer, query)
+        tokens = tokenize(self.tokenizer, query)
+        _, token_len = tokens.shape
 
         self.torch_model = self.torch_model.to(torch.float32)
         nnx_cache = self._init_nnx_cache(len(query))
@@ -275,12 +275,12 @@
         torch_inputs = self._process_hf_tokens(query)
         torch_logits = self.torch_model(**torch_inputs).logits
         torch.testing.assert_close(
-            torch.tensor(nnx_logits)[:, :max_len, :], torch_logits, rtol=self.relaxed_tol, atol=self.relaxed_tol
+            torch.tensor(nnx_logits)[:, :token_len, :], torch_logits, rtol=self.relaxed_tol, atol=self.relaxed_tol
         )
 
     def test_full_batched(self):
         query = ["Why is the sky blue instead of any other color like purple?", "Who am I?"]
-        tokens, _ = tokenize(self.tokenizer, query)
+        tokens = tokenize(self.tokenizer, query)
 
         self.torch_model = self.torch_model.to(torch.float32)
         nnx_cache = self._init_nnx_cache(len(query))
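
Note (outside the patch): a minimal sketch of the cache-sizing rule the new init_cache applies, rounding token_len + generate_steps up to the next power of two. padded_cache_size is a hypothetical helper used only for illustration; it is not part of this change.

    import math

    def padded_cache_size(token_len: int, generate_steps: int) -> int:
        # Mirrors the sizing in modeling.init_cache: round the worst-case
        # sequence length (prompt + generated tokens) up to the next power of two.
        return 2 ** math.ceil(math.log2(max(token_len + generate_steps, 1)))

    assert padded_cache_size(10, 32) == 64    # values used by _init_nnx_cache in the tests
    assert padded_cache_size(19, 256) == 512  # e.g. a hypothetical 19-token prompt with generate_steps=256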