diff --git a/README.md b/README.md
index 3c9f80c7..1475af26 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,7 @@ These are listed based on status and then alphabetically.
 | [ResNet50](bonsai/models/resnet50) | Image classification | ✅ | |
 | [VGG](bonsai/models/vgg19) | Image classification | ✅ | |
 | [Dinov3](bonsai/models/dinov3) | Vision FM | ⚙️ | |
+| [Gemma3](bonsai/models/gemma3) | VLM | ⚙️ | Local attention cache and TODOs in file |
 | [Mamba2](bonsai/models/mamba2) | Language SSM | ⚙️ | Caching and sharding |
 | [umT5](bonsai/models/umt5) | LLM | ⚙️ | Caching and sharding |
 | [ViT](bonsai/models/vit) | Image classification | ⚙️ | Sharding |
@@ -39,7 +40,6 @@ These are listed based on status and then alphabetically.
 | [UNet](bonsai/models/unet/) | Image | 🟡 | Need a reference implementation and numerical testing |
 | [VAE](bonsai/models/vae/) | Generative model | 🟡 | Need a reference implementation and numerical testing |
 | [Whisper](bonsai/models/whisper/) | Speech recognition | 🟡 | Need more numerical testing and not all call methods implemented |
-| Gemma3 | | ⏳ | |
 | CLIP | | ⏳ | |
diff --git a/bonsai/models/gemma3/README.md b/bonsai/models/gemma3/README.md
new file mode 100644
index 00000000..e89d227a
--- /dev/null
+++ b/bonsai/models/gemma3/README.md
@@ -0,0 +1,25 @@
+# Gemma3 in JAX
+
+This directory contains a pure JAX implementation of the [Gemma3 model](https://deepmind.google/models/gemma/gemma-3/), using the [Flax NNX](https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html) API.
+
+Note that you need an access token to download the model weights. Before running the scripts, set the `HF_TOKEN` environment variable to your Hugging Face access token.
+
+
+## Model Configuration Support Status
+
+Currently only the 4B instruction-tuned configuration (`google/gemma-3-4b-it`) is supported; see `ModelConfig.gemma3_4b_it` in `modeling.py`.
+
+### Running this model
+
+
+```sh
+python3 -m bonsai.models.gemma3.tests.run_model
+```
+
+
+## How to contribute to this model
+
+### Remaining Tasks
+
+1. Update the KV cache to exploit the memory savings of local (sliding-window) attention; decode generation is not yet performance optimized.
+2. Optimize parameter loading for larger models.
diff --git a/bonsai/models/gemma3/modeling.py b/bonsai/models/gemma3/modeling.py
new file mode 100644
index 00000000..982e8bc0
--- /dev/null
+++ b/bonsai/models/gemma3/modeling.py
@@ -0,0 +1,830 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
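+
+"""Gemma3 vision-language model in pure JAX / Flax NNX: SigLIP vision tower,
+multi-modal projector, and text decoder with a KV cache and optional FSDP/TP sharding."""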
+ +import math +from dataclasses import dataclass +from enum import Enum +from functools import partial +from typing import TypeAlias + +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx +from flax.nnx.nn.linear import default_embed_init +from jax import P + +# TODO: Would be better to rely on something not in jax._src +from jax._src.nn.functions import _apply_masks +from jax.sharding import PartitionSpec +from jaxtyping import Array + + +class AttentionMode(Enum): + FULL = "full_attention" + SLIDE = "sliding_attention" + + +class ShardMode(Enum): + FSDP = "fsdp" + TP = "tp" + + +def _set_attention_modes(global_attn_freq: int, layers: int) -> list[AttentionMode]: + """Returns a list of attention modes where every global_attn_freq layers uses global attention.""" + return [AttentionMode.FULL if i % global_attn_freq == 0 else AttentionMode.SLIDE for i in range(1, layers + 1)] + + +@dataclass(slots=True, frozen=True) +class VisionShardingCfg: + attn_kernel: PartitionSpec + attn_bias: PartitionSpec + attn_qk_activation: PartitionSpec + fc1_kernel: PartitionSpec + fc1_bias: PartitionSpec + fc2_kernel: PartitionSpec + fc2_bias: PartitionSpec + activation: PartitionSpec + layer_norm: PartitionSpec + emb_patch_kernel: PartitionSpec + emb_patch_bias: PartitionSpec + emb_pos_kernel: PartitionSpec + + @staticmethod + def no_sharding(): + return VisionShardingCfg.default(False, False) + + @staticmethod + def default(use_fsdp: bool, use_tp: bool): + fsdp = ShardMode.FSDP.value if use_fsdp else None + tp = ShardMode.TP.value if use_tp else None + return VisionShardingCfg( + attn_kernel=P(tp, fsdp), + attn_bias=P(tp), + attn_qk_activation=P(fsdp, tp), + fc1_kernel=P(fsdp, tp), + fc1_bias=P(tp), + fc2_kernel=P(tp, fsdp), + fc2_bias=P(tp), + activation=P(fsdp, None, tp), + layer_norm=P(tp), + emb_patch_kernel=P(None, None, None, tp), + emb_patch_bias=P(tp), + emb_pos_kernel=P(None, tp), + ) + + +@dataclass(slots=True, frozen=True) +class TextShardingCfg: + attn_kernel: PartitionSpec + attn_bias: PartitionSpec + attn_qk_activation: PartitionSpec + down_kernel: PartitionSpec + down_bias: PartitionSpec + up_gate_kernel: PartitionSpec + up_gate_bias: PartitionSpec + activation: PartitionSpec + decoder_norm: PartitionSpec + cache: PartitionSpec + emb_kernel: PartitionSpec + + @staticmethod + def no_sharding(): + return TextShardingCfg.default(False, False) + + @staticmethod + def default(use_fsdp: bool, use_tp: bool): + fsdp = ShardMode.FSDP.value if use_fsdp else None + tp = ShardMode.TP.value if use_tp else None + return TextShardingCfg( + attn_kernel=P(tp, fsdp), + attn_bias=P(tp), + attn_qk_activation=P(fsdp, tp), + down_kernel=P(tp, fsdp), + down_bias=P(tp), + up_gate_kernel=P(fsdp, tp), + up_gate_bias=P(tp), + activation=P(fsdp, None, tp), + decoder_norm=P(tp), + cache=P(fsdp, None, tp, None), + emb_kernel=P(None, tp), + ) + + +@dataclass(slots=True, frozen=True) +class ShardingCfg: + mmp_norm: PartitionSpec + mmp_weight: PartitionSpec + + @staticmethod + def default(use_tp: bool): + tp = ShardMode.TP.value if use_tp else None + return ShardingCfg(mmp_norm=P(tp), mmp_weight=P(tp)) + + +@dataclass(frozen=True) +class VisionConfig: + attention_dropout: float # TODO: unused + hidden_size: int + image_size: int + intermediate_size: int + layer_norm_eps: float + num_attention_heads: int + num_channels: int + num_hidden_layers: int + patch_size: int + vision_use_head: bool + shd_cfg: VisionShardingCfg + + @classmethod + def gemma3_4b_it(cls, use_fsdp: bool, use_tp: bool): + return 
cls( + attention_dropout=0.0, + hidden_size=1152, + image_size=896, + intermediate_size=4304, + layer_norm_eps=1e-6, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=27, + patch_size=14, + vision_use_head=False, + shd_cfg=VisionShardingCfg.default(use_fsdp, use_tp), + ) + + +@dataclass(frozen=True) +class TextConfig: + attention_bias: bool + attention_dropout: float # TODO: unused + head_dim: int + hidden_size: int + intermediate_size: int + layer_types: list[AttentionMode] + max_position_embeddings: int # TODO: unused + num_attention_heads: int + num_hidden_layers: int + num_key_value_heads: int + rms_norm_eps: float + rope_full_factor: float + rope_full_theta: float + rope_slide_factor: float + rope_slide_theta: float + sliding_window: int + vocab_size: int + norm_dtype: jnp.dtype + shd_cfg: TextShardingCfg + + @classmethod + def gemma3_4b_it(cls, use_fsdp: bool, use_tp: bool, *, norm_dtype: jnp.dtype): + num_hidden_layers = 34 + return cls( + attention_bias=False, + attention_dropout=0.0, # TODO: unused + head_dim=256, + hidden_size=2560, + intermediate_size=10240, + layer_types=_set_attention_modes(6, num_hidden_layers), + max_position_embeddings=131072, # TODO: unused + num_attention_heads=8, + num_hidden_layers=num_hidden_layers, + num_key_value_heads=4, + rms_norm_eps=1e-6, + rope_full_factor=8.0, + rope_full_theta=1000000.0, + rope_slide_factor=1.0, + rope_slide_theta=10000.0, + sliding_window=1024, + vocab_size=262208, + norm_dtype=norm_dtype, + shd_cfg=TextShardingCfg.default(use_fsdp, use_tp), + ) + + +@dataclass(frozen=True) +class ModelConfig: + vision_config: VisionConfig + text_config: TextConfig + mm_tokens_per_image: int + dtype: str # TODO: unused + final_logit_softcapping: float | None + shd_cfg: ShardingCfg + + @classmethod + def gemma3_4b_it(cls, use_fsdp: bool = False, use_tp: bool = False, *, norm_dtype: jnp.dtype): + return cls( + vision_config=VisionConfig.gemma3_4b_it(use_fsdp, use_tp), + text_config=TextConfig.gemma3_4b_it(use_fsdp, use_tp, norm_dtype=norm_dtype), + mm_tokens_per_image=256, + dtype="bfloat16", # TODO: unused + final_logit_softcapping=None, + shd_cfg=ShardingCfg.default(use_tp), + ) + + +# --- General Components --- # +# TODO: Replace with nnx.Linear once explicit sharding is supported. +class ShardedLinear(nnx.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + *, + use_bias: bool = True, + kernel_sharding, + bias_sharding, + dtype=None, + rngs, + ): + kernel_initializer = jax.nn.initializers.lecun_normal() + self.kernel = nnx.Param( + kernel_initializer(rngs.params(), (in_dim, out_dim), dtype=dtype, out_sharding=kernel_sharding) + ) + if use_bias: + self.bias = nnx.Param(jnp.zeros((out_dim,), dtype=dtype, out_sharding=bias_sharding)) + else: + self.bias = nnx.data(jnp.zeros((out_dim,), dtype=dtype, out_sharding=bias_sharding)) + + def __call__(self, x, *, out_sharding): + return jnp.matmul(x, self.kernel, out_sharding=out_sharding) + self.bias + + +# TODO: Replace with nnx.Embed once explicit sharding is supported. +class ShardedEmbedding(nnx.Embed): + def __call__(self, inputs: Array, *, out_sharding) -> Array: + # Modified from Flax NNX + if not jnp.issubdtype(inputs.dtype, jnp.integer): + raise ValueError("Input type must be an integer or unsigned integer.") + # Use take because fancy indexing numpy arrays with JAX indices does not + # work correctly. 
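+        # Here we gather with embedding.at[inputs].get(...) rather than jnp.take so that an
+        # explicit out_sharding can be passed for the gathered activations.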
+ (embedding,) = self.promote_dtype((self.embedding.value,), dtype=self.dtype, inexact=False) + if self.num_embeddings == 1: + return jnp.broadcast_to(embedding, (*inputs.shape, self.features)) + return embedding.at[inputs].get(out_sharding=out_sharding) + + def attend(self, query: Array, *, out_sharding) -> Array: + query, embedding = self.promote_dtype((query, self.embedding.value), dtype=self.dtype) + return jnp.dot(query, embedding.T, out_sharding=out_sharding) + + +# adapted from the jax.nn.dot_product_attention implementation +def sharded_attention(q, k, v, mask, scale=None, *, attn_logit_sharding: PartitionSpec, out_sharding: PartitionSpec): + logits = jnp.einsum("BTNH,BSNH->BNTS", q, k, out_sharding=attn_logit_sharding) + scale_val = (1.0 / np.sqrt(k.shape[-1])) if scale is None else scale + logits *= jnp.array(scale_val, dtype=logits.dtype) + + is_causal = False + local_window_size, q_seqlen, kv_seqlen = None, None, None + padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen, local_window_size) + + padded_logits = padded_logits.astype(np.float32) + probs = jax.nn.softmax(padded_logits, axis=-1).astype(k.dtype) + # TODO: Add dropout here + + attn_out = jnp.einsum("BNTS,BSNH->BTNH", probs, v, out_sharding=out_sharding) + return attn_out + + +# --- Vision Components --- # +# TODO: update to include interpolate_pos_encoding +class SiglipVisionEmbeddings(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.num_patches = (config.image_size // config.patch_size) ** 2 + + ki = partial(jax.nn.initializers.lecun_normal(), out_sharding=config.shd_cfg.emb_patch_kernel) + bi = partial(jax.nn.initializers.zeros, out_sharding=config.shd_cfg.emb_patch_bias) + self.patch_embedding = nnx.Conv( + config.num_channels, + config.hidden_size, + kernel_size=(config.patch_size,) * 2, + strides=(config.patch_size,) * 2, + padding="valid", + kernel_init=ki, + bias_init=bi, + rngs=rngs, + ) + + ei = partial(default_embed_init, out_sharding=config.shd_cfg.emb_pos_kernel) + self.position_embedding = ShardedEmbedding(self.num_patches, config.hidden_size, embedding_init=ei, rngs=rngs) + + shd = P(config.shd_cfg.activation[0]) + self.position_ids = jax.device_put(jnp.expand_dims(jnp.arange(self.num_patches), 0), shd) + + def __call__(self, pixel_values: Array): + patch_embeds = self.patch_embedding(pixel_values) + b, h, w, c = patch_embeds.shape + embeddings = patch_embeds.reshape((b, h * w, c)) + shd = self.config.shd_cfg.activation + out = embeddings + self.position_embedding(self.position_ids, out_sharding=shd) + return out + + +class SiglipAttention(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + hs, shd = config.hidden_size, config.shd_cfg + self.k_proj = ShardedLinear(hs, hs, kernel_sharding=shd.attn_kernel, bias_sharding=shd.attn_bias, rngs=rngs) + self.v_proj = ShardedLinear(hs, hs, kernel_sharding=shd.attn_kernel, bias_sharding=shd.attn_bias, rngs=rngs) + self.q_proj = ShardedLinear(hs, hs, kernel_sharding=shd.attn_kernel, bias_sharding=shd.attn_bias, rngs=rngs) + self.out_proj = ShardedLinear(hs, hs, kernel_sharding=shd.attn_kernel, bias_sharding=shd.attn_bias, rngs=rngs) + + def __call__(self, x: Array, attn_mask: Array | None): + batch_size, seq_length, _ = x.shape + shape = (batch_size, seq_length, self.num_heads, self.head_dim) + shd = 
self.config.shd_cfg.activation + + q = self.q_proj(x, out_sharding=shd).reshape(shape) + k = self.k_proj(x, out_sharding=shd).reshape(shape) + v = self.v_proj(x, out_sharding=shd).reshape(shape) + + intermediate_shd = self.config.shd_cfg.attn_qk_activation + attn = sharded_attention( + q, k, v, mask=attn_mask, attn_logit_sharding=intermediate_shd, out_sharding=shd + ).reshape(x.shape) + return self.out_proj(attn, out_sharding=shd) + + +class SiglipMLP(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + shd = config.shd_cfg + self.fc1 = ShardedLinear( + config.hidden_size, + config.intermediate_size, + kernel_sharding=shd.fc1_kernel, + bias_sharding=shd.fc1_bias, + rngs=rngs, + ) + self.fc2 = ShardedLinear( + config.intermediate_size, + config.hidden_size, + kernel_sharding=shd.fc2_kernel, + bias_sharding=shd.fc2_bias, + rngs=rngs, + ) + + def __call__(self, x: Array): + x = self.fc1(x, out_sharding=self.config.shd_cfg.activation) + x = jax.nn.gelu(x) + x = self.fc2(x, out_sharding=self.config.shd_cfg.activation) + return x + + +class SiglipEncoderLayer(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + shd = config.shd_cfg.layer_norm + si = partial(jax.nn.initializers.ones, out_sharding=shd) + bi = partial(jax.nn.initializers.zeros, out_sharding=shd) + self.layer_norm1 = nnx.LayerNorm( + config.hidden_size, epsilon=config.layer_norm_eps, scale_init=si, bias_init=bi, rngs=rngs + ) + self.layer_norm2 = nnx.LayerNorm( + config.hidden_size, epsilon=config.layer_norm_eps, scale_init=si, bias_init=bi, rngs=rngs + ) + self.self_attn = SiglipAttention(config, rngs=rngs) + self.mlp = SiglipMLP(config, rngs=rngs) + + def __call__(self, x: Array, attn_mask: Array | None): + hidden = self.layer_norm1(x) + hidden = self.self_attn(hidden, attn_mask) + hidden = x + hidden + x = hidden + hidden = self.layer_norm2(hidden) + hidden = self.mlp(hidden) + return hidden + x + + +class SiglipEncoder(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.layers = nnx.List([SiglipEncoderLayer(config, rngs=rngs) for _ in range(config.num_hidden_layers)]) + + def __call__(self, x: Array, attn_mask: Array | None): + for l in self.layers: + x = l(x, attn_mask) + return x + + +# TODO: Skip for now since not in 4b, but test later +class SiglipMultiheadAttentionPoolingHead(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.probe = nnx.Param(nnx.initializers.normal(stddev=0.02)(rngs.params(), (1, 1, config.hidden_size))) + self.attention = nnx.MultiHeadAttention(config.num_attention_heads, config.hidden_size, rngs=rngs) + self.layernorm = nnx.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps, rngs=rngs) + self.mlp = SiglipMLP(config, rngs=rngs) + + def __call__(self, *args, **kwargs): + raise NotImplementedError("Not yet implemented") + + +class SiglipVisionTransformer(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.embeddings = SiglipVisionEmbeddings(config, rngs=rngs) + self.encoder = SiglipEncoder(config, rngs=rngs) + shd = config.shd_cfg.layer_norm + si = partial(jax.nn.initializers.ones, out_sharding=shd) + bi = partial(jax.nn.initializers.zeros, out_sharding=shd) + self.post_layernorm = nnx.LayerNorm( + config.hidden_size, epsilon=config.layer_norm_eps, scale_init=si, bias_init=bi, rngs=rngs + ) + + self.use_head = True if not hasattr(config, 
"vision_use_head") else config.vision_use_head + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead(config) + + def __call__(self, pixel_values: Array): + x = self.embeddings(pixel_values) + x = self.encoder(x, attn_mask=None) + x = self.post_layernorm(x) + if self.use_head: + x = self.head(x) + return x + + +# --- Language Components --- # +# TODO: Update to use a more efficient cache for local attention. +class LayerCache(nnx.Module): + def __init__(self, cfg: TextConfig, layer_idx: int, batch_size: int, cache_size: int, dtype: jnp.dtype): + cache_shape = (batch_size, cache_size, cfg.num_key_value_heads, cfg.head_dim) + kv_shd = cfg.shd_cfg.cache + self.k_cache = nnx.Cache(jnp.zeros(cache_shape, dtype=dtype, out_sharding=kv_shd)) + self.v_cache = nnx.Cache(jnp.zeros(cache_shape, dtype=dtype, out_sharding=kv_shd)) + self.size = self.k_cache.shape[1] + self.start_ind = nnx.Variable(-1 * jnp.ones((batch_size,), dtype=jnp.int32, out_sharding=P(kv_shd[0]))) + self.cur_ind = nnx.Variable(jnp.zeros((), dtype=jnp.int32)) + + +Cache: TypeAlias = list[LayerCache] + + +# TODO: Update to have a memory efficient cache for sliding window. +def init_cache( + cfg: ModelConfig, batch_size: int, token_len: int, generate_steps: int, dtype: jnp.dtype = jnp.bfloat16 +) -> Cache: + cache_size = 2 ** math.ceil(math.log2(max(token_len + generate_steps, 1))) # Pad for a sharding-friendly size. + return [ + LayerCache(cfg.text_config, i, batch_size, cache_size, dtype) for i in range(cfg.text_config.num_hidden_layers) + ] + + +class Gemma3RMSNorm(nnx.Module): + def __init__(self, dim: int, eps: float, *, dtype: jnp.dtype, shd: PartitionSpec, rngs: nnx.Rngs): + self.scale = nnx.Param(jax.nn.initializers.zeros(rngs.params(), dim, dtype=dtype, out_sharding=shd)) + self.eps = eps + + @jax.named_scope("rms_norm") + def __call__(self, x: Array) -> Array: + dtype = x.dtype + xf32 = x.astype(jnp.float32) + out = xf32 * jax.lax.rsqrt(jnp.square(xf32).mean(-1, keepdims=True) + self.eps) + out = out * (1.0 + self.scale.value.astype(jnp.float32)) + return out.astype(dtype) + + +class Gemma3TextScaledWordEmbedding(nnx.Module): + def __init__(self, cfg: TextConfig, *, rngs: nnx.Rngs): + self.cfg = cfg + ei = partial(default_embed_init, out_sharding=cfg.shd_cfg.emb_kernel) + self.weight = ShardedEmbedding(cfg.vocab_size, cfg.hidden_size, embedding_init=ei, rngs=rngs) + self.embed_scale = jnp.array(cfg.hidden_size**0.5, dtype=jnp.bfloat16).astype(jnp.float32) + + def __call__(self, input_ids: Array): + shd = self.cfg.shd_cfg.activation + x = self.weight(input_ids, out_sharding=shd) * self.embed_scale + return x + + +def _generate_pos_embeddings( + positions: jax.Array, + head_dim: int, + rope_theta: int = 1_000_000, + factor: float = 1.0, +) -> tuple[jax.Array, jax.Array]: + # Forked from: jax-llm-examples/qwen3/qwen3_jax/model.py;l=571 + fraction = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim + timescale = rope_theta**fraction + rotational_frequency = 1.0 / timescale + rotational_frequency /= factor + # Use high-precision einsum to prevent catastrophic bfloat16 rounding (ex: 257→256), as sin(257) differs from sin(256). 
+ sinusoid_inp = jnp.einsum("BT,k->BTk", positions, rotational_frequency, precision=jax.lax.Precision.HIGHEST) + return jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp) + + +def apply_rope(x: jax.Array, sin: jax.Array, cos: jax.Array) -> jax.Array: + assert x.ndim == 4 and sin.ndim == 3 and cos.ndim == 3 + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + # [B, T, head_dim] -> [B, h, T, head_dim] + sin, cos = sin[:, :, None, :], cos[:, :, None, :] + out = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1).astype(x.dtype) + return out + + +def count_left_pads(x: jax.Array) -> int: + """Count left padding tokens.""" + return jnp.sum(jnp.cumsum(x != 0, axis=-1) == 0, -1) + + +# TODO: Not used right now +def count_right_pads(x: jax.Array, pad_id) -> int: + result = jnp.where( + jnp.all(x == pad_id, axis=1), x.shape[1], jnp.argmin(jnp.flip(x == pad_id, axis=1).astype(jnp.int32), axis=1) + ) + return jnp.max(result) + + +def compute_positions_from_segment_ids(seg_ids: Array): + return jax.vmap(lambda row: jnp.where(row != 0, jnp.arange(seg_ids.shape[1]) - jnp.argmax(row), 2**30))(seg_ids) + + +def repeat_kv(hidden_states: Array, n_rep: int): + b, t, kv_heads, head_dim = hidden_states.shape + hidden_states = jnp.expand_dims(hidden_states, axis=3) + hidden_states = jnp.repeat(hidden_states, repeats=n_rep, axis=3) + return hidden_states.reshape(b, t, kv_heads * n_rep, head_dim) + + +class Gemma3Attention(nnx.Module): + def __init__(self, config: TextConfig, layer_idx: int, *, rngs: nnx.Rngs): + self.config = config + self.layer_idx = layer_idx + self.use_sliding = config.layer_types[layer_idx] == AttentionMode.SLIDE + self.num_kv_heads = self.config.num_key_value_heads + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + shd = config.shd_cfg + self.q_proj = ShardedLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + use_bias=config.attention_bias, + kernel_sharding=shd.attn_kernel, + bias_sharding=shd.attn_bias, + rngs=rngs, + ) + self.k_proj = ShardedLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + use_bias=config.attention_bias, + kernel_sharding=shd.attn_kernel, + bias_sharding=shd.attn_bias, + rngs=rngs, + ) + self.v_proj = ShardedLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + use_bias=config.attention_bias, + kernel_sharding=shd.attn_kernel, + bias_sharding=shd.attn_bias, + rngs=rngs, + ) + self.o_proj = ShardedLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + use_bias=config.attention_bias, + kernel_sharding=shd.attn_kernel, + bias_sharding=shd.attn_bias, + rngs=rngs, + ) + self.q_norm = Gemma3RMSNorm(config.head_dim, config.rms_norm_eps, dtype=config.norm_dtype, shd=P(), rngs=rngs) + self.k_norm = Gemma3RMSNorm(config.head_dim, config.rms_norm_eps, dtype=config.norm_dtype, shd=P(), rngs=rngs) + + self.rope_theta = config.rope_slide_theta if self.use_sliding else config.rope_full_theta + self.factor = config.rope_slide_factor if self.use_sliding else config.rope_full_factor + + self.n_rep = config.num_attention_heads // config.num_key_value_heads + self.scale = config.head_dim**-0.5 + + def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, mask: Array | None) -> Array: + # Get projections + new_shape = (*x.shape[:-1], -1, self.head_dim) + shd = self.config.shd_cfg.activation + q = self.q_norm(self.q_proj(x, out_sharding=shd).reshape(new_shape)) + k = self.k_norm(self.k_proj(x, 
out_sharding=shd).reshape(new_shape)) + v = self.v_proj(x, out_sharding=shd).reshape(new_shape) + + # Apply rope + left_pads = count_left_pads(segment_ids) + cache.start_ind.value = jnp.where(cache.start_ind.value < 0, left_pads, cache.start_ind.value) + position_ids = compute_positions_from_segment_ids(segment_ids) + cache.cur_ind.value + sin, cos = _generate_pos_embeddings(position_ids, self.head_dim, self.rope_theta, factor=self.factor) + q = apply_rope(q, sin, cos) + k = apply_rope(k, sin, cos) + + # Update cache + slice_indices = (0, cache.cur_ind.value, 0, 0) + cache.k_cache.value = jax.lax.dynamic_update_slice(cache.k_cache.value, k, slice_indices) + cache.v_cache.value = jax.lax.dynamic_update_slice(cache.v_cache.value, v, slice_indices) + + k, v = repeat_kv(cache.k_cache.value, self.n_rep), repeat_kv(cache.v_cache.value, self.n_rep) + intermediate_shd = self.config.shd_cfg.attn_qk_activation + qkv = sharded_attention( + q, k, v, mask=mask, scale=self.scale, attn_logit_sharding=intermediate_shd, out_sharding=shd + ) + t = x.shape[1] + cache.cur_ind.value = cache.cur_ind.value + t + return self.o_proj(qkv.reshape(*x.shape[:-1], -1), out_sharding=shd) + + +class Gemma3MLP(nnx.Module): + def __init__(self, config: TextConfig, *, rngs: nnx.Rngs): + self.config = config + hsize, isize, shd = config.hidden_size, config.intermediate_size, config.shd_cfg + self.gate_proj = ShardedLinear( + hsize, isize, use_bias=False, kernel_sharding=shd.up_gate_kernel, bias_sharding=shd.up_gate_bias, rngs=rngs + ) + self.up_proj = ShardedLinear( + hsize, isize, use_bias=False, kernel_sharding=shd.up_gate_kernel, bias_sharding=shd.up_gate_bias, rngs=rngs + ) + self.down_proj = ShardedLinear( + isize, hsize, use_bias=False, kernel_sharding=shd.down_kernel, bias_sharding=shd.down_bias, rngs=rngs + ) + + def __call__(self, x: Array): + ux = self.up_proj(x, out_sharding=self.config.shd_cfg.activation) + gx = jax.nn.gelu(self.gate_proj(x, out_sharding=self.config.shd_cfg.activation)) + out = self.down_proj(gx * ux, out_sharding=self.config.shd_cfg.activation) + return out + + +class Gemma3DecoderLayer(nnx.Module): + def __init__(self, config: TextConfig, layer_idx: int, *, rngs: nnx.Rngs): + super().__init__() + self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx, rngs=rngs) + self.mlp = Gemma3MLP(config, rngs=rngs) + + norm_shd = config.shd_cfg.decoder_norm + norm_kwargs = dict(dim=config.hidden_size, eps=config.rms_norm_eps, dtype=config.norm_dtype, shd=norm_shd) + self.input_layernorm = Gemma3RMSNorm(**norm_kwargs, rngs=rngs) + self.post_attention_layernorm = Gemma3RMSNorm(**norm_kwargs, rngs=rngs) + self.pre_feedforward_layernorm = Gemma3RMSNorm(**norm_kwargs, rngs=rngs) + self.post_feedforward_layernorm = Gemma3RMSNorm(**norm_kwargs, rngs=rngs) + + def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, mask: Array | None) -> Array: + res = x + x = self.input_layernorm(x) + x = self.self_attn(x, cache, segment_ids, mask=mask) + x = self.post_attention_layernorm(x) + x = res + x + res = x + x = self.pre_feedforward_layernorm(x) + x = self.mlp(x) + x = self.post_feedforward_layernorm(x) + return x + res + + @property + def head_dim(self): + return self.o_proj.shape[1] + + +class Gemma3TextModel(nnx.Module): + def __init__(self, config: TextConfig, *, rngs: nnx.Rngs): + self.config = config + self.layers = nnx.List( + [Gemma3DecoderLayer(config, layer_idx, rngs=rngs) for layer_idx in range(config.num_hidden_layers)] + ) + norm_shd = config.shd_cfg.decoder_norm + self.norm = 
Gemma3RMSNorm( + config.hidden_size, config.rms_norm_eps, dtype=config.norm_dtype, shd=norm_shd, rngs=rngs + ) + + def __call__(self, x, cache: Cache, segment_ids: Array, sliding_mask: Array | None, causal_mask: Array | None): + for lt, c, layer in zip(self.config.layer_types, cache, self.layers): + mask = sliding_mask if lt == AttentionMode.SLIDE else causal_mask + x = layer(x, c, segment_ids, mask) + x = self.norm(x) + return x + + +class Gemma3MultiModalProjector(nnx.Module): + def __init__(self, config: ModelConfig, *, rngs: nnx.Rngs): + self.config = config + vhs, ths = config.vision_config.hidden_size, config.text_config.hidden_size + eps = config.vision_config.layer_norm_eps + self.patches_per_img = int(config.vision_config.image_size // config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_img // self.tokens_per_side + + mmp_w_shd, mmp_norm_shd = config.shd_cfg.mmp_weight, config.shd_cfg.mmp_norm + self.mm_input_projection_weight = nnx.Param(jnp.zeros((vhs, ths), out_sharding=mmp_w_shd), rngs=rngs) + self.mm_soft_emb_norm = Gemma3RMSNorm( + vhs, eps=eps, dtype=config.text_config.norm_dtype, shd=mmp_norm_shd, rngs=rngs + ) + + def __call__(self, vision_outputs: Array) -> Array: + b, _, t = vision_outputs.shape + vision_outputs = vision_outputs.swapaxes(1, 2).reshape(b, t, self.patches_per_img, self.patches_per_img) + + x = nnx.avg_pool( + vision_outputs[:, :, :, :, None], + window_shape=(1, 1, self.kernel_size, self.kernel_size), + strides=(1, 1, self.kernel_size, self.kernel_size), + )[:, :, :, :, 0] + x = x.reshape(b, t, -1).swapaxes(1, 2) + x = self.mm_soft_emb_norm(x) + x = jnp.matmul( + x, self.mm_input_projection_weight.value, out_sharding=self.config.vision_config.shd_cfg.activation + ) + return x.astype(vision_outputs.dtype) + + +def make_causal_mask(layer_cache: LayerCache, token_type_ids: Array, *, out_sharding: PartitionSpec): + b, t = token_type_ids.shape + c = layer_cache.size + seq_arange = jnp.arange(t) + cache_arange = jnp.arange(c) + causal_mask = seq_arange[:, None] - cache_arange[None, :] >= -layer_cache.cur_ind + tti = token_type_ids.astype(jnp.bool_) + cache_padded_tti = jnp.concat([tti, jnp.zeros((b, c - t), dtype=jnp.bool_, out_sharding=out_sharding)], axis=-1) + image_or_mask = tti[:, None, :, None] & cache_padded_tti[:, None, None, :] + causal_mask = causal_mask.astype(jnp.bool_) | image_or_mask + return causal_mask + + +def make_window_mask(layer_cache: LayerCache, token_type_ids: Array, slide_size: int, *, out_sharding: PartitionSpec): + causal_mask = make_causal_mask(layer_cache, token_type_ids, out_sharding=out_sharding) + *_, t, c = causal_mask.shape + seq_arange = jnp.arange(t) + cache_arange = jnp.arange(c) + slide = seq_arange[:, None] - cache_arange[None, :] < slide_size + return causal_mask & slide + + +def merge_modalities(img_emb: Array, text_emb: Array, token_mask: Array) -> Array: + # This function fills the image tokens into the text_emb sequence + # The token_mask tells us where the image tokens are (0 for text, 1 for image) + # image_emb is (Li, D) + # text_emb is (Lt, D) + # token_mask is (Lt) + # We have Li < Lt + img_indices = jnp.cumsum(token_mask) - 1 + safe_indices = jnp.clip(img_indices, 0, img_emb.shape[0] - 1) + aligned_images = img_emb[safe_indices] + return jnp.where(token_mask[:, None], aligned_images, text_emb) + + +def batched_merge_modalities(img_emb: Array, text_emb: Array, token_mask: Array) -> Array: + # image_emb is (B, Li, D) + # text_emb is 
(B, Lt, D) + # token_mask is (B, Lt) + # We have Li < Lt + return jax.vmap(merge_modalities)(img_emb, text_emb, token_mask) + + +class Gemma3Model(nnx.Module): + def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs): + self.config = cfg + self.sliding_window_size = cfg.text_config.sliding_window + self.embed_tokens = Gemma3TextScaledWordEmbedding(cfg.text_config, rngs=rngs) + self.vision_tower = SiglipVisionTransformer(cfg.vision_config, rngs=rngs) + self.multi_modal_projector = Gemma3MultiModalProjector(cfg, rngs=rngs) + self.language_model = Gemma3TextModel(cfg.text_config, rngs=rngs) + self.final_logit_softcapping = cfg.final_logit_softcapping + + def __call__( + self, input_ids: Array, pixel_values: Array, cache: Cache, segment_ids: Array, token_type_ids: Array + ) -> Array: + assert input_ids.shape == token_type_ids.shape + shd = P(self.config.text_config.shd_cfg.activation[0]) + causal_mask = make_causal_mask(cache[0], token_type_ids, out_sharding=shd) + sliding_mask = make_window_mask(cache[0], token_type_ids, slide_size=self.sliding_window_size, out_sharding=shd) + inputs_embeds = self.embed_tokens(input_ids) + + # Merge text and images + if pixel_values is not None: + vision_outputs = self.vision_tower(pixel_values) + image_features = self.multi_modal_projector(vision_outputs).astype(inputs_embeds.dtype) + inputs_embeds = batched_merge_modalities(image_features, inputs_embeds, token_type_ids) + + out = self.language_model(inputs_embeds, cache, segment_ids, sliding_mask, causal_mask) + shd = P(self.config.text_config.shd_cfg.activation[0]) + out = self.embed_tokens.weight.attend(out, out_sharding=shd) + + if self.config.final_logit_softcapping is not None: + out = out / self.final_logit_softcapping + out = jax.nn.tanh(out) + out = out * self.final_logit_softcapping + + return out + + +@jax.jit +def forward( + model: nnx.Module, cache: Cache, input_ids: Array, pixel_values: Array, segment_ids: Array, token_type_ids +) -> tuple[Array, nnx.Cache]: + logits = model(input_ids, pixel_values, cache, segment_ids, token_type_ids) + return logits[:, -1, :], cache diff --git a/bonsai/models/gemma3/params.py b/bonsai/models/gemma3/params.py new file mode 100644 index 00000000..f6d6f5fb --- /dev/null +++ b/bonsai/models/gemma3/params.py @@ -0,0 +1,266 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Parameter helpers for bonsai.models.gemma3. + +Add functions to load or convert pretrained checkpoints and to return +default configuration values used by the model implementation. +""" + +import logging +import re +from enum import Enum + +import jax +import safetensors.flax as safetensors +from etils import epath +from flax import nnx + +from bonsai.models.gemma3 import modeling as model_lib + + +class Transform(Enum): + """ + Specifies default transformation types for model parameter names. 
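+
+    Each value is either None (use the tensor as-is) or a (permute_rule, reshape_rule) tuple;
+    the permutation is applied via transpose and the reshape via reshape before assignment
+    (e.g. LINEAR transposes HF (out, in) linear weights to (in, out) kernels, and CONV2D
+    maps OIHW conv weights to HWIO).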
+ """ + + DEFAULT = None + BIAS = None + LINEAR = ((1, 0), None) + CONV2D = ((2, 3, 1, 0), None) + EMBED = None + + +def _get_key_and_transform_mapping(): + # Mapping st_keys -> (nnx_keys, (permute_rule, reshape_rule)). + return { + r"^language_model\.model\.embed_tokens\.weight$": ( + r"embed_tokens\.weight\.embedding", + Transform.EMBED, + ), + r"^language_model\.model\.layers\.(\d+)\.input_layernorm\.weight$": ( + r"language_model\.layers\.\1\.input_layernorm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.weight$": ( + r"language_model\.layers\.\1\.mlp\.down_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.weight$": ( + r"language_model\.layers\.\1\.mlp\.gate_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.weight$": ( + r"language_model\.layers\.\1\.mlp\.up_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm\.weight$": ( + r"language_model\.layers\.\1\.post_attention_layernorm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.post_feedforward_layernorm\.weight$": ( + r"language_model\.layers\.\1\.post_feedforward_layernorm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.pre_feedforward_layernorm\.weight$": ( + r"language_model\.layers\.\1\.pre_feedforward_layernorm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.k_norm\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.k_norm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.k_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.o_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.q_norm\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.q_norm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.q_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.v_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.norm\.weight$": (r"language_model\.norm\.scale", Transform.DEFAULT), + r"^multi_modal_projector\.mm_input_projection_weight$": ( + r"multi_modal_projector\.mm_input_projection_weight", + Transform.DEFAULT, + ), + r"^multi_modal_projector\.mm_soft_emb_norm\.weight$": ( + r"multi_modal_projector\.mm_soft_emb_norm\.scale", + Transform.DEFAULT, + ), + r"^vision_tower\.vision_model\.embeddings\.patch_embedding\.bias$": ( + r"vision_tower\.embeddings\.patch_embedding\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.embeddings\.patch_embedding\.weight$": ( + r"vision_tower\.embeddings\.patch_embedding\.kernel", + Transform.CONV2D, + ), + r"^vision_tower\.vision_model\.embeddings\.position_embedding\.weight$": ( + r"vision_tower\.embeddings\.position_embedding\.embedding", + Transform.EMBED, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.layer_norm(\d+)\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.layer_norm\2\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.layer_norm(\d+)\.weight$": ( + 
r"vision_tower\.encoder\.layers\.\1\.layer_norm\2\.scale", + Transform.DEFAULT, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc(\d+)\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.mlp\.fc\2\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc(\d+)\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.mlp\.fc\2\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.k_proj\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.k_proj\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.out_proj\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.out_proj\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.q_proj\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.q_proj\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.v_proj\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.v_proj\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.post_layernorm\.bias$": ( + r"vision_tower\.post_layernorm\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.post_layernorm\.weight$": ( + r"vision_tower\.post_layernorm\.scale", + Transform.DEFAULT, + ), + } + + +def _st_key_to_jax_key(mapping, source_key): + """Map a safetensors key to exactly one JAX key & transform, else warn/error.""" + subs = [ + (re.sub(pat, repl, source_key), transform) + for pat, (repl, transform) in mapping.items() + if re.match(pat, source_key) + ] + if not subs: + logging.warning(f"No mapping found for key: {source_key!r}") + return None, None + if len(subs) > 1: + keys = [s for s, _ in subs] + raise ValueError(f"Multiple mappings found for {source_key!r}: {keys}") + return subs[0] + + +def _assign_weights(keys, tensor, state_dict, st_key, transform, sharding_dict): + """Recursively descend into state_dict and assign the (possibly permuted/reshaped) tensor.""" + key, *rest = keys + if not rest: + if transform is not None: + permute, reshape = transform + if permute: + tensor = tensor.transpose(permute) + if reshape: + tensor = tensor.reshape(reshape) + if tensor.shape != state_dict[key].shape: + raise ValueError(f"Shape mismatch for {st_key}: {tensor.shape} vs {state_dict[key].shape}") + if sharding_dict is not None: + state_dict[key] = jax.device_put(tensor, sharding_dict[key]) + else: + state_dict[key] = jax.device_put(tensor) + else: + next_sharding = sharding_dict[key] if sharding_dict is not None else None + _assign_weights(rest, tensor, state_dict[key], st_key, transform, next_sharding) + + +def _stoi(s): + try: + return int(s) + except ValueError: + return s + + +# TODO: Update to optimize parameter loading 
for larger models +def create_gemma3_from_pretrained(file_dir: str, cfg: model_lib.ModelConfig, *, mesh: jax.sharding.Mesh | None = None): + """ + Load safetensor weights from a file, then convert & merge into a flax.nnx ViT model. + + Returns: + A flax.nnx.Model instance with loaded parameters. + """ + files = list(epath.Path(file_dir).expanduser().glob("*.safetensors")) + if not files: + raise ValueError(f"No safetensors found in {file_dir}") + + tensor_dict = {} + for f in files: + tensor_dict |= safetensors.load_file(f) + + gemma3 = model_lib.Gemma3Model(cfg, rngs=nnx.Rngs(0)) + graph_def, abs_state = nnx.split(gemma3) + jax_state = abs_state.to_pure_dict() + sharding = nnx.get_named_sharding(abs_state, mesh).to_pure_dict() if mesh is not None else None + + mapping = _get_key_and_transform_mapping() + for st_key, tensor in tensor_dict.items(): + jax_key, transform = _st_key_to_jax_key(mapping, st_key) + if jax_key is None: + continue + keys = [_stoi(k) for k in jax_key.split(r"\.")] + try: + _assign_weights(keys, tensor, jax_state, st_key, transform.value, sharding) + except KeyError as e: + print(f"Key error: {keys} at {e}") + except ValueError as e: + print(e) + except Exception as e: + print(keys) + raise e + + return nnx.merge(graph_def, jax_state) diff --git a/bonsai/models/gemma3/tests/run_model.py b/bonsai/models/gemma3/tests/run_model.py new file mode 100644 index 00000000..0e4e3499 --- /dev/null +++ b/bonsai/models/gemma3/tests/run_model.py @@ -0,0 +1,119 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run a small inference example for Gemma3.""" + +import os + +import jax +import jax.numpy as jnp +import numpy as np +import torch +import tqdm +from huggingface_hub import snapshot_download +from jax.sharding import AxisType +from transformers import Gemma3Processor + +from bonsai.models.gemma3 import modeling, params +from bonsai.utils import Sampler + + +def make_input(processor, dtype=torch.float32, msg1=True): + url_prefix = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main" + url = "pipeline-cat-chonk.jpeg" if msg1 else "bee.jpg" + prompt = "What is shown in this image?" if msg1 else "Describe this image in detail." 
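+    # The two variants differ in the prompt, the image, and the content key ("url" vs "image")
+    # passed to the chat template.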
+ img_key = "url" if msg1 else "image" + + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + {"type": "image", img_key: f"{url_prefix}/{url}"}, + {"type": "text", "text": prompt}, + ], + }, + ] + + t_inputs = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + t_inputs["pixel_values"] = t_inputs["pixel_values"].to(dtype=dtype) + + n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + + return n_text, n_img, n_tti + + +def run_model(): + model_name: str = "google/gemma-3-4b-it" + access_token = os.environ["HF_TOKEN"] + processor = Gemma3Processor.from_pretrained(model_name, token=access_token, use_fast=False) + + fsdp, tp = modeling.ShardMode.FSDP.value, modeling.ShardMode.TP.value + + mesh = jax.make_mesh(((1, 1)), (fsdp, tp), axis_types=(AxisType.Explicit, AxisType.Explicit)) + jax.set_mesh(mesh) + + model_ckpt_path = snapshot_download(model_name, token=access_token) + + bonsai_config = modeling.ModelConfig.gemma3_4b_it(norm_dtype=jnp.float32) + bonsai_model = params.create_gemma3_from_pretrained(model_ckpt_path, bonsai_config, mesh=mesh) + eot_token_id = processor.tokenizer.convert_tokens_to_ids("") + + # Make inputs + n_text, n_img, n_tti = make_input(processor) + gen_steps = 256 + batch_size, num_tokens = n_text.shape + cache = modeling.init_cache(bonsai_config, batch_size, num_tokens, gen_steps, jnp.float32) + + source_key = jax.random.key(0) + sampler = jax.jit(Sampler(temperature=1.0, top_p=0.8, top_k=10)) + + all_tokens = [n_text] + pbar = tqdm.trange(gen_steps, desc="Generating output") + + # Prefill + segment_ids = jnp.ones((batch_size, num_tokens)) + out, cache = modeling.forward(bonsai_model, cache, n_text, n_img, segment_ids, n_tti) + + source_key, key = jax.random.split(source_key) + n_text = sampler(out, key=key) + pbar.update(1) + all_tokens.append(n_text) + + # Decode + n_tti = jnp.zeros((batch_size, 1), dtype=jnp.int32) + n_img, num_tokens = None, 1 + segment_ids = jnp.ones((batch_size, num_tokens)) + + for _ in pbar: + out, cache = modeling.forward(bonsai_model, cache, n_text, n_img, segment_ids, n_tti) + source_key, key = jax.random.split(source_key) + n_text = sampler(out, key=key) + if jnp.all(n_text == eot_token_id): + pbar.close() + print("Hit end of turn.") + break + all_tokens.append(n_text) + + full_tokens = torch.tensor(jnp.concat(all_tokens, axis=1)) + out_tokens = processor.decode(full_tokens[0], skip_special_tokens=True) + print(out_tokens) + + +if __name__ == "__main__": + run_model() diff --git a/bonsai/models/gemma3/tests/test_outputs_gemma3.py b/bonsai/models/gemma3/tests/test_outputs_gemma3.py new file mode 100644 index 00000000..102a0629 --- /dev/null +++ b/bonsai/models/gemma3/tests/test_outputs_gemma3.py @@ -0,0 +1,671 @@ +import os +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +import torch +from absl.testing import absltest +from huggingface_hub import snapshot_download +from jax import P +from jax.sharding import AxisType +from tqdm import trange +from transformers import AutoProcessor +from transformers.cache_utils import DynamicCache +from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask +from transformers.models.gemma3 import 
Gemma3ForConditionalGeneration +from transformers.models.gemma3.modeling_gemma3 import token_type_ids_mask_function + +from bonsai.models.gemma3 import modeling, params + +# used for skipping smaller tests +SKIP_INTERMEDIATE_TESTS: bool = False + +# used to set highest precision on matrix multiplication for testing +jax.config.update("jax_default_matmul_precision", "highest") + + +def check_hf_token(): + try: + access_token = os.environ["HF_TOKEN"] + AutoProcessor.from_pretrained("google/gemma-3-4b-it", token=access_token, use_fast=False) + except Exception as e: + print("Failed to access HF_TOKEN or download Processor:") + print(e) + return True + return False + + +@unittest.skipIf(check_hf_token(), "Skipping TestModuleForwardPasses due to HF_TOKEN failure.") +class TestModuleForwardPasses(absltest.TestCase): + # Using this for faster testing. This way we can avoid reloading the model. + # Make sure not to modify the Gemma3 model in inconsistent ways between tests. + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model_name: str = "google/gemma-3-4b-it" + cls.torch_device = "cpu" + access_token = os.environ["HF_TOKEN"] + + # attempt model download + cls.processor = AutoProcessor.from_pretrained(cls.model_name, token=access_token, use_fast=False) + cls.torch_model = ( + Gemma3ForConditionalGeneration.from_pretrained(cls.model_name, dtype="auto") + .to(device=cls.torch_device, dtype=torch.float32) + .eval() + ) + cls.torch_config = cls.torch_model.config + + cls.mesh = jax.make_mesh(((1, 1)), ("fsdp", "tp"), axis_types=(AxisType.Explicit, AxisType.Explicit)) + jax.set_mesh(cls.mesh) + cls.bonsai_config = modeling.ModelConfig.gemma3_4b_it(norm_dtype=jnp.float32) + model_ckpt_path = snapshot_download(cls.model_name, token=access_token) + cls.bonsai_model = params.create_gemma3_from_pretrained(model_ckpt_path, cls.bonsai_config, mesh=cls.mesh) + + def _upgrade_dtypes(self): + self.bonsai_model.embed_tokens.weight.embedding.value = ( + self.bonsai_model.embed_tokens.weight.embedding.value.astype(jnp.float32) + ) + return + + def _make_torch_input(self): + # returns model inputs: + # KEY SHAPE DTYPE + # input_ids torch.Size([1, 281]) int64 + # attention_mask torch.Size([1, 281]) int64 + # token_type_ids torch.Size([1, 281]) int64 + # pixel_values torch.Size([1, 3, 896, 896]) bfloat16 -> float32 + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + }, + {"type": "text", "text": "Describe this image in detail."}, + ], + }, + ] + + out = self.processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + out["pixel_values"] = out["pixel_values"].to(dtype=torch.float32) + + return {k: v.to(device=self.torch_device) for k, v in out.items()} + + def _make_bonsai_input(self, torch_inputs): + out = dict() + for k, v in torch_inputs.items(): + tmp = v.detach().cpu().numpy() + if k == "pixel_values": + tmp = np.permute_dims(tmp, (0, 2, 3, 1)) + out[k] = tmp + return out + + # This should be correct for unbatched inputs + # Adapted from transformers/models/gemma3/modeling_gemma3.py + def _process_torch_inputs( + self, + input_ids=None, + pixel_values=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + token_type_ids=None, + cache_position=None, + inputs_embeds=None, + labels=None, + 
use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + **lm_kwargs, + ): + # Replace image id with PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.torch_config.image_token_id >= self.torch_config.text_config.vocab_size: + special_image_mask = input_ids == self.torch_config.image_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.torch_model.model.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # Merge text and images + if pixel_values is not None: + image_features = self.torch_model.model.get_image_features(pixel_values) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.torch_model.model.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.torch_config.get_text_config(), + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + is_prefill = ( + not use_cache + or past_key_values is None + or not past_key_values.is_initialized + or pixel_values is not None + ) + if token_type_ids is not None and is_prefill: + is_image = (token_type_ids == 1).to(cache_position.device) + new_image_start = is_image & ~torch.nn.functional.pad(is_image, (1, 0), value=0)[:, :-1] + image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1 + image_group_ids = torch.where( + is_image, image_group_ids, torch.full_like(token_type_ids, -1, device=is_image.device) + ) + mask_kwargs["or_mask_function"] = token_type_ids_mask_function( + token_type_ids.to(cache_position.device), image_group_ids, self.torch_config.mm_tokens_per_image + ) + + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + return dict( + attention_mask=causal_mask_mapping, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + # This should be correct for unbatched inputs + # Adapted from transformers/models/gemma3/modeling_gemma3.py + def _process_torch_inputs_for_decoder_text_model( + self, + attn_type, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + **kwargs, + ): + training = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None and not training: + past_key_values = DynamicCache(config=self.torch_model.model.config.text_config) + + if 
cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + sliding_mask_kwargs = mask_kwargs.copy() + + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs), + } + position_embeddings_global = self.torch_model.model.language_model.rotary_emb(inputs_embeds, position_ids) + position_embeddings_local = self.torch_model.model.language_model.rotary_emb_local(inputs_embeds, position_ids) + return dict( + hidden_states=inputs_embeds, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=causal_mask_mapping[attn_type], + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + # Vision tests + @unittest.skipIf(SKIP_INTERMEDIATE_TESTS, "Done") + def test_image_emb(self): + tm = self.torch_model.model.vision_tower.vision_model.embeddings + nm = self.bonsai_model.vision_tower.embeddings + + t_inputs = self._make_torch_input() + n_inputs = self._make_bonsai_input(t_inputs) + tx = t_inputs["pixel_values"] + nx = n_inputs["pixel_values"] + + with torch.no_grad(): + ty = tm(tx) + ny = nm(nx) + + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-5, atol=1e-5) + + @unittest.skipIf(SKIP_INTERMEDIATE_TESTS, "Done") + def test_siglip_encoder_layer(self): + tm = self.torch_model.model.vision_tower.vision_model.encoder.layers[0] + nm = self.bonsai_model.vision_tower.encoder.layers[0] + + tx = torch.randn((1, 4096, 1152), device=self.torch_device) + nx = tx.detach().cpu().numpy() + + with torch.no_grad(): + ty = tm(tx, None) + ny = nm(nx, None) + + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4) + + @unittest.skipIf(SKIP_INTERMEDIATE_TESTS, "Done") + def test_vision_model(self): + # only have deviations on .0567% of the entries and on order 7e-3 + tm = self.torch_model.model.vision_tower + nm = self.bonsai_model.vision_tower + + t_inputs = self._make_torch_input() + n_inputs = self._make_bonsai_input(t_inputs) + tx = t_inputs["pixel_values"] + nx = n_inputs["pixel_values"] + + with torch.no_grad(): + ty = tm(tx).last_hidden_state + ny = nm(nx) + + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-2, atol=1e-2) + + # Language tests + @unittest.skipIf(SKIP_INTERMEDIATE_TESTS, "Done") + def test_text_embedding(self): + self._upgrade_dtypes() + tm = self.torch_model.model.language_model.embed_tokens + nm = self.bonsai_model.embed_tokens + + torch.testing.assert_close(torch.tensor(nm.weight.embedding.value), tm.weight.cpu()) + torch.testing.assert_close(torch.tensor(nm.embed_scale), tm.embed_scale.cpu()) + + t_inputs = self._make_torch_input() + n_inputs = self._make_bonsai_input(t_inputs) + tx = 
t_inputs["input_ids"] + nx = n_inputs["input_ids"] + + with torch.no_grad(): + ty = tm(tx) + ny = nm(nx) + + np.testing.assert_allclose(ny, ty.detach().cpu().numpy()) + + @unittest.skipIf(SKIP_INTERMEDIATE_TESTS, "Done") + def test_attn_projs(self): + tm = self.torch_model.model.language_model.layers[0].self_attn + nm = self.bonsai_model.language_model.layers[0].self_attn + + tx = torch.randn((1, 281, 2560), device=self.torch_device) + nx = tx.detach().cpu().numpy() + + ty = tm.q_proj(tx) + ny = nm.q_proj(nx, out_sharding=P()) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4, err_msg="q") + + ty = tm.k_proj(tx) + ny = nm.k_proj(nx, out_sharding=P()) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4, err_msg="k") + + ty = tm.v_proj(tx) + ny = nm.v_proj(nx, out_sharding=P()) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4, err_msg="v") + + tx = torch.randn((1, 281, 2048), device=self.torch_device) + nx = tx.detach().cpu().numpy() + ty = tm.o_proj(tx) + ny = nm.o_proj(nx, out_sharding=P()) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4, err_msg="o") + + @unittest.skipIf(SKIP_INTERMEDIATE_TESTS, "Done") + def test_attn_norms(self): + tm = self.torch_model.model.language_model.layers[0].self_attn + nm = self.bonsai_model.language_model.layers[0].self_attn + + tx = torch.randn((1, 281, 2048), device=self.torch_device).reshape(1, 281, -1, 256) + nx = tx.detach().cpu().numpy() + + np.testing.assert_allclose( + nm.q_norm.scale.value, tm.q_norm.weight.detach().cpu().numpy(), err_msg="q_norm weights" + ) + + ty = tm.q_norm(tx) + ny = nm.q_norm(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-5, atol=1e-5, err_msg="q") + + tx = torch.randn((1, 281, 1024), device=self.torch_device).reshape(1, 281, -1, 256) + nx = tx.detach().cpu().numpy() + + ty = tm.k_norm(tx) + ny = nm.k_norm(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-5, atol=1e-5, err_msg="k") + + @unittest.skipIf(SKIP_INTERMEDIATE_TESTS, "Done") + def test_sin_cos(self): + batch_size, seq_len, dim = 2, 10, 256 + hidden_states = torch.ones((batch_size, seq_len, dim)) + jp = jnp.stack([jnp.arange(seq_len), jnp.arange(seq_len)]) + + # local uses default + rt = self.bonsai_config.text_config.rope_slide_theta + js, jc = modeling._generate_pos_embeddings(jp, dim, rope_theta=rt, factor=1.0) + rot_emb = self.torch_model.model.language_model.rotary_emb_local + tc, ts = rot_emb(hidden_states, torch.tensor(jp)) + tc, ts = tc[:, :, : dim // 2], ts[:, :, : dim // 2] + torch.testing.assert_close(torch.tensor(js), ts) + torch.testing.assert_close(torch.tensor(jc), tc) + + # global uses linear + rt = self.bonsai_config.text_config.rope_full_theta + js, jc = modeling._generate_pos_embeddings(jp, dim, rope_theta=rt, factor=8.0) + rot_emb = self.torch_model.model.language_model.rotary_emb + tc, ts = rot_emb(hidden_states, torch.tensor(jp)) + tc, ts = tc[:, :, : dim // 2], ts[:, :, : dim // 2] + torch.testing.assert_close(torch.tensor(js), ts) + torch.testing.assert_close(torch.tensor(jc), tc) + + @unittest.skipIf(SKIP_INTERMEDIATE_TESTS, "Done") + def test_text_decoder_layers(self): + first_t_inputs = self._make_torch_input() + start_t_inputs = self._process_torch_inputs(**first_t_inputs) + + for test_layer in trange(34): + # Models + tm = self.torch_model.model.language_model.layers[test_layer] + nm = self.bonsai_model.language_model.layers[test_layer] + attn_type = tm.attention_type + + 
# Inputs + t_inputs = self._process_torch_inputs_for_decoder_text_model(attn_type, **start_t_inputs) + nx = t_inputs["hidden_states"].detach().cpu().numpy() + batch_size, num_tokens, _ = nx.shape + nnx_cache = modeling.init_cache( + cfg=self.bonsai_config, batch_size=batch_size, token_len=num_tokens, generate_steps=1, dtype=jnp.float32 + ) + n_tti = first_t_inputs["token_type_ids"].detach().cpu().numpy() + + if attn_type == "full_attention": + mask = modeling.make_causal_mask(nnx_cache[test_layer], n_tti, out_sharding=P()) + else: + mask = modeling.make_window_mask(nnx_cache[test_layer], n_tti, 1024, out_sharding=P()) + + # run models + ty = tm(**t_inputs) + ny = nm(nx, nnx_cache[test_layer], jnp.ones((batch_size, num_tokens)), mask=mask) + + t_inputs["hidden_states"] = ty[0] + + found_exception = False + try: + np.testing.assert_allclose( + ny, ty[0].detach().cpu().numpy(), rtol=5e-3, atol=5e-3, err_msg=f"{test_layer}" + ) + except Exception as e: + print(e) + found_exception = True + assert not found_exception, "FOUND EXCEPTION" + + # multi modal tests + + @unittest.skipIf(SKIP_INTERMEDIATE_TESTS, "Done") + def test_multi_modal_projector(self): + t_inputs = self._make_torch_input() + tm = self.torch_model.model + nm = self.bonsai_model.multi_modal_projector + + tx = tm.vision_tower(t_inputs["pixel_values"]).last_hidden_state + nx = tx.detach().cpu().numpy() + + ty = tm.multi_modal_projector(tx) + ny = nm(nx) + + torch.testing.assert_close(torch.tensor(ny), ty, rtol=1e-4, atol=1e-4) + + @unittest.skipIf(SKIP_INTERMEDIATE_TESTS, "Done") + def test_text_image_merge(self): + nm = self.bonsai_model + t_inputs = self._make_torch_input() + t_out = self._process_torch_inputs(**t_inputs) + + # answer is input_embeds + t_ans = t_out["inputs_embeds"] + + tmp = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_text = nm.embed_tokens(tmp) + + # return + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_img = nm.vision_tower(n_img) + n_img = nm.multi_modal_projector(n_img) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + + n_ans = modeling.batched_merge_modalities(n_img, n_text, n_tti) + + np.testing.assert_allclose(n_ans, t_ans.detach().cpu().numpy(), rtol=1e-3, atol=1e-3) + + @unittest.skipIf(SKIP_INTERMEDIATE_TESTS, "Done") + def test_masks(self): + # Make a really long input so we can test the sliding window + # This only tests for the pre-fill stage + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + }, + {"type": "text", "text": "Describe this image in detail." 
+ "hello " * 1500}, + ], + }, + ] + + t_inputs = self.processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + t_inputs["pixel_values"] = t_inputs["pixel_values"].to(dtype=torch.float32) + + batch_size, num_tokens = t_inputs["input_ids"].shape + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + gen_steps = 10 + cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, gen_steps) + n_mask = modeling.make_causal_mask(cache[0], n_tti, out_sharding=P()) + + # Full attention + t_inputs = self._process_torch_inputs(**t_inputs) + t_mask = t_inputs["attention_mask"]["full_attention"] + size_for_comp = t_mask.shape[-1] + + np.testing.assert_allclose(n_mask[:, :, :, :size_for_comp], t_mask.detach().cpu().numpy()) + + # Sliding attention + t_mask = t_inputs["attention_mask"]["sliding_attention"] + n_mask = modeling.make_window_mask( + cache[0], n_tti, self.bonsai_config.text_config.sliding_window, out_sharding=P() + ) + + np.testing.assert_allclose(n_mask[:, :, :, :size_for_comp], t_mask.detach().cpu().numpy()) + + @unittest.skip("Skipping - this test is just to observe errors over full model evaluation") + def test_full_in_order(self): + tm = self.torch_model.model + nm = self.bonsai_model + + # Torch inputs + t_inputs = self._make_torch_input() + + # NNX inputs + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + batch_size, num_tokens = n_text.shape + segment_ids = jnp.ones((batch_size, num_tokens)) + cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32) + + # Get masks + n_causal_mask = modeling.make_causal_mask(cache[0], n_tti, out_sharding=P()) + n_sliding_mask = modeling.make_window_mask( + cache[0], n_tti, self.bonsai_config.text_config.sliding_window, out_sharding=P() + ) + + # text embeds + t_inputs_embeds = tm.language_model.embed_tokens(t_inputs["input_ids"]) + n_inputs_embeds = nm.embed_tokens(n_text) + np.testing.assert_allclose(n_inputs_embeds, t_inputs_embeds.detach().cpu().numpy(), err_msg="text emb") + + # Vision part + t_vis = tm.vision_tower(t_inputs["pixel_values"]).last_hidden_state + n_vis = nm.vision_tower(n_img) + # Mismatched elements: 4608354 / 4718592 (97.7%) + # Max absolute difference among violations: 0.00756264 + # Max relative difference among violations: 15.521739 + np.testing.assert_allclose(n_vis, t_vis.detach().cpu().numpy(), rtol=1e-3, atol=1e-3, err_msg="vis tower") + + # MM Proj part + t_img_feat = tm.multi_modal_projector(t_vis) + n_img_feat = nm.multi_modal_projector(n_vis) + # Mismatched elements: 648574 / 655360 (99%) + # Max absolute difference among violations: 0.00063944 + # Max relative difference among violations: 20.392141 + np.testing.assert_allclose( + n_img_feat, t_img_feat.detach().cpu().numpy(), rtol=1e-3, atol=1e-3, err_msg="mm proj" + ) + + # Merging part + special_image_mask = tm.get_placeholder_mask( + t_inputs["input_ids"], inputs_embeds=t_inputs_embeds, image_features=t_img_feat + ) + t_inputs_embeds = t_inputs_embeds.masked_scatter(special_image_mask, t_img_feat) + n_inputs_embeds = modeling.batched_merge_modalities(n_img_feat, n_inputs_embeds, n_tti) + # Mismatched elements: 648574 / 719360 (90.2%) + # Max absolute difference among violations: 0.00063944 + # Max relative difference among violations: 20.392141 + 
np.testing.assert_allclose(
+            n_inputs_embeds, t_inputs_embeds.detach().cpu().numpy(), rtol=1e-3, atol=1e-3, err_msg="merge"
+        )
+
+        # Text part in order
+        t_inputs["output_hidden_states"] = True
+        t_text_inputs = self._process_torch_inputs(**t_inputs)
+        t_hidden_states = tm.language_model(**t_text_inputs).hidden_states
+        assert len(t_hidden_states) - 1 == len(nm.language_model.layers), (
+            f"{len(t_hidden_states)} vs {len(nm.language_model.layers)}"
+        )
+
+        # check inputs
+        nx = n_inputs_embeds
+
+        n_hidden_states = []
+        for i, layer in enumerate(nm.language_model.layers):
+            attn_type = tm.language_model.layers[i].attention_type
+            n_mask = n_causal_mask if attn_type == "full_attention" else n_sliding_mask
+            n_hidden_states.append(nx)
+            nx = layer(nx, cache[i], segment_ids, n_mask)
+        nx = nm.language_model.norm(nx)
+        n_hidden_states.append(nx)
+
+        found_error = False
+        for i, (nval, tval) in enumerate(zip(n_hidden_states, t_hidden_states)):
+            try:
+                np.testing.assert_allclose(nval, tval.detach().cpu().numpy(), err_msg=f"text {i}")
+            except Exception as e:
+                print(e)
+                found_error = True
+        assert not found_error, "Found errors in text decoder layers"
+        # NOTE: some errors are expected here since errors compound across layers
+
+    def test_full(self):
+        tm = self.torch_model
+        nm = self.bonsai_model
+
+        t_inputs = self._make_torch_input()
+
+        n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1)))
+        n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy())
+        n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy())
+        batch_size, num_tokens = n_text.shape
+        segment_ids = jnp.ones((batch_size, num_tokens))
+        cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32)
+
+        ny = nm(n_text, n_img, cache, segment_ids, n_tti)
+        ty = tm(**t_inputs)
+
+        torch.testing.assert_close(torch.tensor(ny), ty.logits, rtol=5e-2, atol=5e-2)
+
+    @unittest.skip("TODO")
+    def test_full_batched(self):
+        tm = self.torch_model
+        nm = self.bonsai_model
+
+        t_inputs = self._make_torch_input()
+
+        n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1)))
+        n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy())
+        n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy())
+
+        # Test simple batching
+        n_img = jnp.concat([n_img, n_img])
+        n_text = jnp.concat([n_text, n_text])
+        n_tti = jnp.concat([n_tti, n_tti])
+
+        batch_size, num_tokens = n_text.shape
+        segment_ids = jnp.ones((batch_size, num_tokens))
+        cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32)
+
+        ny = nm(n_text, n_img, cache, segment_ids, n_tti)
+        ty = tm(**t_inputs)
+
+        torch.testing.assert_close(torch.tensor(ny)[0:1], ty.logits, rtol=5e-2, atol=5e-2)
+        torch.testing.assert_close(torch.tensor(ny)[1:2], ty.logits, rtol=5e-2, atol=5e-2)
+
+        raise NotImplementedError("Need to get more complex batched inputs working")
+        # When doing batching, prompts have >= 0 images (not all the same) -> change batched_merge_modalities
+        # For this, we might also need to keep track of where images came from
+        # We also need to update the left padding to deal with different padding for each prompt
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/bonsai/models/gemma3/tests/test_sharding_gemma3.py b/bonsai/models/gemma3/tests/test_sharding_gemma3.py
new file mode 100644
index 00000000..f5b9d0aa
--- /dev/null
+++ b/bonsai/models/gemma3/tests/test_sharding_gemma3.py
@@ -0,0 +1,89 @@
+import os
+import unittest
+
+import 
jax +import jax.numpy as jnp +import numpy as np +import torch +from absl.testing import absltest +from flax import nnx +from jax import P +from jax.sharding import AxisType +from transformers import AutoProcessor + +from bonsai.models.gemma3 import modeling +from bonsai.models.gemma3.tests.test_outputs_gemma3 import check_hf_token + + +@unittest.skipIf(check_hf_token(), "Skipping TestSharding due to HF_TOKEN failure.") +class TestSharding(absltest.TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model_name: str = "google/gemma-3-4b-it" + access_token = os.environ["HF_TOKEN"] + cls.processor = AutoProcessor.from_pretrained(cls.model_name, token=access_token, use_fast=False) + cls.torch_device = "cpu" + + fsdp, tp = modeling.ShardMode.FSDP.value, modeling.ShardMode.TP.value + + cls.mesh = jax.make_mesh(((1, 1)), (fsdp, tp), axis_types=(AxisType.Explicit, AxisType.Explicit)) + jax.set_mesh(cls.mesh) + + cls.bonsai_config = modeling.ModelConfig.gemma3_4b_it(True, True, norm_dtype=jnp.float32) + cls.bonsai_model = modeling.Gemma3Model(cls.bonsai_config, rngs=nnx.Rngs(0)) + + def _make_torch_input(self): + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + }, + {"type": "text", "text": "Describe this image in detail."}, + ], + }, + ] + + out = self.processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + out["pixel_values"] = out["pixel_values"].to(dtype=torch.float32) + + return {k: v.to(device=self.torch_device) for k, v in out.items()} + + def test_full(self): + nm = self.bonsai_model + fsdp = modeling.ShardMode.FSDP.value + + t_inputs = self._make_torch_input() + + n_img = jnp.array( + np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1)), out_sharding=P(fsdp) + ) + n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy(), out_sharding=P(fsdp)) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy(), out_sharding=P(fsdp)) + + batch_size, num_tokens = n_text.shape + segment_ids = jnp.ones((batch_size, num_tokens), out_sharding=P(fsdp)) + cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32) + + nm(n_text, n_img, cache, segment_ids, n_tti) + + @unittest.skip("Only for viewing purposes") + def test_view_model(self): + state = nnx.state(self.bonsai_model) + out = jax.tree_util.tree_map(lambda x: jax.typeof(x), state) + + # print(out) + # print(out.vision_tower) + # print(out.language_model) + # print(out.embed_tokens) + print(out.multi_modal_projector) + + +if __name__ == "__main__": + absltest.main() diff --git a/pyproject.toml b/pyproject.toml index 02b96692..2061c1bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,13 +97,6 @@ select = [ "YTT", "ASYNC", "E101", - "E112", - "E113", - "E115", - "E117", - "E225", - "E227", - "E228", ] [tool.ruff.format]
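
Note for reviewers: the sketch below strings together the call signatures exercised by `test_full` and `TestSharding` above (`ModelConfig.gemma3_4b_it`, `init_cache`, and the `Gemma3Model` call order) into a single stand-alone prefill pass. It is only a sketch under the assumptions visible in these tests: `HF_TOKEN` must be set, NumPy >= 2.0 is needed for `np.permute_dims`, and checkpoint loading is elided, so the model runs with randomly initialized parameters and the logits are useful for shape checks only.

```python
import os

import jax.numpy as jnp
import numpy as np
import torch
from flax import nnx
from transformers import AutoProcessor

from bonsai.models.gemma3 import modeling

# Build a multimodal prompt with the HF processor, mirroring _make_torch_input above.
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it", token=os.environ["HF_TOKEN"], use_fast=False)
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
            },
            {"type": "text", "text": "Describe this image in detail."},
        ],
    },
]
t_inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
)

# Convert to JAX: pixels go NCHW -> NHWC; input_ids and token_type_ids are used as-is.
n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].to(dtype=torch.float32).detach().cpu().numpy(), (0, 2, 3, 1)))
n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy())
n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy())
batch_size, num_tokens = n_text.shape
segment_ids = jnp.ones((batch_size, num_tokens))

# Unsharded config, randomly initialized weights (checkpoint loading elided),
# and a cache sized for this prefill plus one decode step.
cfg = modeling.ModelConfig.gemma3_4b_it(False, False, norm_dtype=jnp.float32)
model = modeling.Gemma3Model(cfg, rngs=nnx.Rngs(0))
cache = modeling.init_cache(cfg, batch_size, num_tokens, 1, jnp.float32)

logits = model(n_text, n_img, cache, segment_ids, n_tti)  # same call as test_full
print(logits.shape)  # expected: (batch_size, num_tokens, vocab_size)
```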