23 changes: 13 additions & 10 deletions bonsai/models/qwen3/modeling.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import dataclasses
+import math
 from functools import partial
 from typing import Tuple, TypeAlias
 
@@ -209,7 +210,10 @@ def __init__(self, cfg: ModelConfig, batch_size: int, cache_size: int, dtype: jn
 Cache: TypeAlias = list[LayerCache]
 
 
-def init_cache(cfg: ModelConfig, batch_size: int, cache_size: int, dtype: jnp.dtype = jnp.bfloat16) -> Cache:
+def init_cache(
+    cfg: ModelConfig, batch_size: int, token_len: int, generate_steps: int, dtype: jnp.dtype = jnp.bfloat16
+) -> Cache:
+    cache_size = 2 ** math.ceil(math.log2(max(token_len + generate_steps, 1))) # Pad for a sharding-friendly size.
     return [LayerCache(cfg, batch_size, cache_size, dtype) for _ in range(cfg.num_layers)]
 
 
@@ -308,7 +312,7 @@ def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         self.scale = self.head_dim**-0.5
 
     @jax.named_scope("attention")
-    def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, num_right_pads: int) -> Array:
+    def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array) -> Array:
         query_proj = shard(self.q_norm(self.q_proj(x)), self.shd_cfg.act_btnh)
         key_proj = shard(self.k_norm(self.k_proj(x)), self.shd_cfg.act_btnh)
         value_proj = shard(self.v_proj(x), self.shd_cfg.act_btnh)
@@ -341,7 +345,7 @@ def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, num_r
         attn_weights = jax.nn.softmax(attn_logits.astype(jnp.float32), axis=-1).astype(attn_logits.dtype)
         qkv = jnp.einsum("BHGTS,BSHD->BTHGD", attn_weights, cache.v_cache.value).reshape((b, t, qh, d))
 
-        cache.cur_ind.value = cache.cur_ind.value + t - num_right_pads
+        cache.cur_ind.value = cache.cur_ind.value + t
         return shard(self.o_proj(qkv), self.shd_cfg.act_btd)
 
     @property
@@ -380,9 +384,9 @@ def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         self.post_attention_layernorm = RMSNorm(cfg.emb_dim, cfg, rngs=rngs)
         self.mlp = MLP(cfg=cfg, rngs=rngs)
 
-    def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, num_right_pads: int) -> Array:
+    def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array) -> Array:
         inputs_normalized = self.input_layernorm(x)
-        attn_output = x + self.attn(inputs_normalized, cache, segment_ids, num_right_pads)
+        attn_output = x + self.attn(inputs_normalized, cache, segment_ids)
         outputs = attn_output + self.mlp(self.post_attention_layernorm(attn_output))
         return outputs
 
@@ -396,10 +400,10 @@ def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
             einsum_str="BTD,DV->BTV", shape=(cfg.emb_dim, cfg.vocab_size), shd=cfg.shd_cfg.emb_dv, rngs=rngs
         )
 
-    def __call__(self, tokens, segment_ids, num_right_pads, cache):
+    def __call__(self, tokens, segment_ids, cache):
         x = self.embedder.encode(tokens)
         for i, layer in enumerate(self.layers):
-            x = layer(x, cache[i], segment_ids, num_right_pads)
+            x = layer(x, cache[i], segment_ids)
         logits = self.lm_head(self.final_norm(x))
         return logits
 
@@ -410,7 +414,6 @@ def forward(
 ) -> tuple[Array, nnx.State]:
     model, cache = nnx.merge(graphdef, state)
     segment_ids = 1 * (tokens != pad_id)
-    num_right_pads = count_right_pads(segment_ids, out_sharding=P(None))
-    logits = model(tokens, segment_ids, num_right_pads, cache)
+    logits = model(tokens, segment_ids, cache)
    state = jax.tree.leaves(nnx.state((model, cache)))
-    return logits[:, -num_right_pads - 1], state
+    return logits[:, -1], state
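A note on the sizing change above: `init_cache` no longer takes a caller-chosen `cache_size`; it derives one from the prompt length plus the generation budget and rounds up to the next power of two. A minimal standalone sketch of that rule (the helper name `padded_cache_size` is illustrative, not part of this PR):

```python
import math


def padded_cache_size(token_len: int, generate_steps: int) -> int:
    # Round (prompt length + generation budget) up to the next power of two,
    # which gives a sharding-friendly, reusable buffer size.
    return 2 ** math.ceil(math.log2(max(token_len + generate_steps, 1)))


assert padded_cache_size(14, 256) == 512  # 270 rounds up to 2**9 = 512
assert padded_cache_size(10, 11) == 32    # 21 rounds up to 2**5 = 32
```

Because `tokenize` in run_model.py now only left-pads prompts (see below), the newest token always sits in the last position, so `forward` can simply return `logits[:, -1]` and the `num_right_pads` bookkeeping threaded through the layers is no longer needed.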
55 changes: 25 additions & 30 deletions bonsai/models/qwen3/tests/run_model.py
@@ -25,7 +25,7 @@
 from transformers import AutoTokenizer
 
 from bonsai.models.qwen3 import modeling, params
-from bonsai.utils import Sampler
+from bonsai.utils import GreedySampler, Sampler
 
 
 def tokenize(tokenizer, input: list[str]):
@@ -37,9 +37,8 @@ def tokenize(tokenizer, input: list[str]):
         for l in input
     ]
     lines = [tokenizer.encode(line) for line in lines]
-    max_l = max(len(line) for line in lines) # Right-align, left-padding to the max token length.
-    buffer_len = 2 ** math.ceil(math.log2(max(max_l, 1))) # Pad the sequence to power-of-two buffer length.
-    return jnp.array([np.pad(l, (max_l - len(l), buffer_len - max_l), constant_values=pad_idx) for l in lines]), max_l
+    max_len = max(len(line) for line in lines) # Right-align, left-padding to the max token length.
+    return jnp.array([np.pad(l, (max_len - len(l), 0), constant_values=pad_idx) for l in lines])
 
 
 def run_model():
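For reference, a condensed sketch of what the revised `tokenize` returns: every prompt is right-aligned by left-padding to the longest prompt in the batch, with no extra right padding. This sketch skips the chat-template step and uses an illustrative `pad_idx`:

```python
import numpy as np
import jax.numpy as jnp


def left_pad(token_lists: list[list[int]], pad_idx: int) -> jnp.ndarray:
    # Right-align each sequence by padding on the left up to the longest prompt.
    max_len = max(len(t) for t in token_lists)
    return jnp.array([np.pad(t, (max_len - len(t), 0), constant_values=pad_idx) for t in token_lists])


batch = left_pad([[5, 6, 7], [9]], pad_idx=0)
# batch == [[5, 6, 7],
#           [0, 0, 9]]  -> shape (2, 3); the last column holds each prompt's final token.
```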
@@ -50,17 +49,20 @@ def run_model():
     mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(AxisType.Explicit, AxisType.Explicit))
     jax.set_mesh(mesh)
 
-    query = ["Why is the sky blue instead of any other color like purple?", "Who am I?"]
+    query = [
+        "Why is the sky blue instead of any other color like purple?",
+        "Who am I?",
+        "Tell me 10 flavors of ice creams.",
+    ]
 
     tokenizer = AutoTokenizer.from_pretrained(model_ckpt_path)
-    tokens, max_len = tokenize(tokenizer, query)
-    batch_size, _ = tokens.shape
+    tokens = tokenize(tokenizer, query)
+    batch_size, token_len = tokens.shape
 
-    cache_size, gen_steps = 128, 11
-    assert cache_size >= max_len + gen_steps, f"Cache size ({cache_size}) must be >= 22,352 + {gen_steps}"
+    generate_steps = 256
+    cache = modeling.init_cache(config, batch_size, token_len, generate_steps)
 
     model = params.create_model_from_safe_tensors(model_ckpt_path, config, mesh)
-    cache = modeling.init_cache(config, batch_size, cache_size)
     graphdef, state = nnx.split((model, cache))
     state = jax.tree.leaves(state) # Better perf from flattened jax state due to no pytree trasversals.
 
@@ -71,33 +73,26 @@ def run_model():
     logits, state = modeling.forward(graphdef, state, tokens, tokenizer.pad_token_id)
     next_tokens = sampler(logits, key=key)
 
-    # decode - warmup
+    # decode
     tokens_list = [next_tokens]
-    logits, state = modeling.forward(graphdef, state, next_tokens, tokenizer.pad_token_id)
-    next_tokens = sampler(logits, key=key)
-    tokens_list.append(next_tokens)
-
-    # profile
-    jax.profiler.start_trace("/tmp/profile-data")
-    for i in range(5):
-        logits, state = modeling.forward(graphdef, state, next_tokens, tokenizer.pad_token_id)
-        next_tokens = sampler(logits, key=key)
-        tokens_list.append(next_tokens)
-    jax.block_until_ready(tokens_list)
-    jax.profiler.stop_trace()
-
-    decode_steps = 128
-    for i in range(decode_steps):
+    finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
+    for i in range(generate_steps):
         logits, state = modeling.forward(graphdef, state, next_tokens, tokenizer.pad_token_id)
         next_tokens = sampler(logits, key=key)
+        finished = finished | (next_tokens.squeeze(-1) == tokenizer.eos_token_id)
         tokens_list.append(next_tokens)
+        if finished.all():
+            break
 
-    tokens_list = jnp.concatenate(tokens_list, axis=-1)
+    all_output_tokens = jax.device_get(jnp.concatenate(tokens_list, axis=-1))
     for i, q in enumerate(query):
         print(f"User:\n {q}")
-        print(
-            f"Answer:\n {tokenizer.decode(tokens_list.at[i].get(out_sharding=P(None)), skip_special_tokens=True)}\n\n"
-        )
+        seq_tokens = all_output_tokens[i]
+        eos_idx = np.where(seq_tokens == tokenizer.eos_token_id)[0]
+        if eos_idx.size > 0:
+            seq_tokens = seq_tokens[: eos_idx[0]]
+        decoded = tokenizer.decode(seq_tokens, skip_special_tokens=True)
+        print(f"Answer:\n {decoded}\n\n")
 
 
 if __name__ == "__main__":
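The new decode loop stops as soon as every sequence in the batch has produced an EOS token, and the printing loop trims each sequence at its first EOS before decoding. A minimal sketch of the two pieces, assuming a hypothetical `step` callable that returns next token ids of shape `(batch, 1)`:

```python
import numpy as np
import jax.numpy as jnp


def generate(step, first_tokens: jnp.ndarray, eos_id: int, max_steps: int) -> jnp.ndarray:
    batch = first_tokens.shape[0]
    finished = jnp.zeros((batch,), dtype=jnp.bool_)
    tokens, outputs = first_tokens, [first_tokens]
    for _ in range(max_steps):
        tokens = step(tokens)                                 # (batch, 1) next token ids
        finished = finished | (tokens.squeeze(-1) == eos_id)  # mark sequences that hit EOS
        outputs.append(tokens)
        if finished.all():                                    # stop once the whole batch is done
            break
    return jnp.concatenate(outputs, axis=-1)


def trim_at_eos(seq: np.ndarray, eos_id: int) -> np.ndarray:
    # Keep only the tokens before the first EOS, if one was generated.
    eos_idx = np.where(seq == eos_id)[0]
    return seq[: eos_idx[0]] if eos_idx.size > 0 else seq
```

The profiler warm-up block was dropped in favor of this single loop; the early exit keeps short answers from paying for the full 256-step budget.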
18 changes: 9 additions & 9 deletions bonsai/models/qwen3/tests/test_outputs_qwen3.py
@@ -86,9 +86,8 @@ def _nnx_forward_logits(self, cache: modeling.Cache, tokens: jax.Array, dtype: D
         """Forward pass for the nnx model"""
         segment_ids = 1 * (tokens != self.tokenizer.pad_token_id)
         x = self.nnx_model.embedder.encode(tokens).astype(dtype)
-        right_pads = modeling.count_right_pads(segment_ids, out_sharding=P(None))
         for i, layer in enumerate(self.nnx_model.layers):
-            x = layer(x, cache[i], segment_ids, right_pads).astype(dtype)
+            x = layer(x, cache[i], segment_ids).astype(dtype)
         nnx_logits = self.nnx_model.lm_head(self.nnx_model.final_norm(x))
         return nnx_logits
 
@@ -112,7 +111,7 @@ def _process_hf_tokens(self, query: list[str]):
 
     def _init_nnx_cache(self, batch_size: int):
         return modeling.init_cache(
-            cfg=self.bonsai_config, batch_size=batch_size, cache_size=self.cache_size, dtype=jnp.float32
+            cfg=self.bonsai_config, batch_size=batch_size, token_len=10, generate_steps=32, dtype=jnp.float32
         )
 
     def test_embedder(self):
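With the updated signature, the test cache size can be checked by hand: assuming the power-of-two rule from modeling.py above, `token_len=10` plus `generate_steps=32` gives 42 slots, which rounds up to a 64-entry cache:

```python
import math

token_len, generate_steps = 10, 32
cache_size = 2 ** math.ceil(math.log2(max(token_len + generate_steps, 1)))
assert cache_size == 64  # 42 rounds up to 2**6
```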
@@ -135,7 +134,7 @@ def test_decoder_layer(self):
         nnx_cache = self._init_nnx_cache(self.batch_size)
         torch_inputs = self._setup_torch_attn(tx)
 
-        jy, ty = nm(jx, nnx_cache[0], jnp.ones((self.batch_size, self.num_input_tokens)), 0), tm(**torch_inputs)
+        jy, ty = nm(jx, nnx_cache[0], jnp.ones((self.batch_size, self.num_input_tokens))), tm(**torch_inputs)
         torch.testing.assert_close(torch.tensor(jy), ty)
 
     def test_all_decoder_layers(self):
@@ -146,7 +145,7 @@ def test_all_decoder_layers(self):
         jx = jax.random.normal(jax.random.key(0), shape=shape)
         tx = torch.tensor(jx)
 
-        jy = nm(jx, nc, jnp.ones((self.batch_size, self.num_input_tokens)), 0)
+        jy = nm(jx, nc, jnp.ones((self.batch_size, self.num_input_tokens)))
         torch_inputs = self._setup_torch_attn(tx)
         ty = tm.to(torch.float32)(**torch_inputs)
         torch.testing.assert_close(torch.tensor(jy), ty, atol=self.relaxed_tol, rtol=self.relaxed_tol)
@@ -172,7 +171,7 @@ def test_self_attn(self):
         torch_inputs = self._setup_torch_attn(tx)
         nnx_cache = self._init_nnx_cache(self.batch_size)
 
-        jy = nm(jx, nnx_cache[0], jnp.ones((self.batch_size, self.num_input_tokens), dtype=jnp.float32), 0)
+        jy = nm(jx, nnx_cache[0], jnp.ones((self.batch_size, self.num_input_tokens), dtype=jnp.float32))
         ty = tm(**torch_inputs)[0]
         torch.testing.assert_close(torch.tensor(jy), ty)
 
@@ -267,20 +266,21 @@ def test_sin_cos(self):
 
     def test_full(self):
         query = ["Why is the sky blue instead of any other color like purple?"]
-        tokens, max_len = tokenize(self.tokenizer, query)
+        tokens = tokenize(self.tokenizer, query)
+        _, token_len = tokens.shape
         self.torch_model = self.torch_model.to(torch.float32)
         nnx_cache = self._init_nnx_cache(len(query))
 
         nnx_logits = self._nnx_forward_logits(nnx_cache, tokens, jnp.float32)
         torch_inputs = self._process_hf_tokens(query)
         torch_logits = self.torch_model(**torch_inputs).logits
         torch.testing.assert_close(
-            torch.tensor(nnx_logits)[:, :max_len, :], torch_logits, rtol=self.relaxed_tol, atol=self.relaxed_tol
+            torch.tensor(nnx_logits)[:, :token_len, :], torch_logits, rtol=self.relaxed_tol, atol=self.relaxed_tol
         )
 
     def test_full_batched(self):
         query = ["Why is the sky blue instead of any other color like purple?", "Who am I?"]
-        tokens, _ = tokenize(self.tokenizer, query)
+        tokens = tokenize(self.tokenizer, query)
         self.torch_model = self.torch_model.to(torch.float32)
         nnx_cache = self._init_nnx_cache(len(query))
 