
Conversation

@chapman20j
Collaborator

Resolves #100
Adds the Gemma3 model to the Bonsai repo. This first commit is a working version; I am still optimizing it.

Reference
Refer to Issue #100

Checklist

  • I have read the Contribution Guidelines and used pre-commit hooks to format this commit.
  • I have added all the necessary unit tests for my change. (run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality).
  • (If using an LLM) I have carefully reviewed and removed all superfluous comments or unneeded, commented-out code. Only necessary and functional code remains.
  • I have signed the Contributor License Agreement (CLA).

  • Updated configs
  • Moved embed_tokens to a more natural place
  • Updated run_model to use the sampler and stop at the end_of_turn token (see the sketch below)
  • Added test_sharding_gemma3
  • Added a batched forward test; more complex behavior and testing are still needed
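
For reference, a minimal sketch of that decode loop, stopping at the end_of_turn token. This uses greedy argmax for illustration; the actual sampler may differ, and model_apply, params, and end_of_turn_id are placeholder names, not the real run_model.py API:

    import jax.numpy as jnp

    def generate(model_apply, params, prompt_ids, end_of_turn_id, max_steps=128):
        """Greedy decode: append argmax tokens until end_of_turn_id or max_steps."""
        ids = list(prompt_ids)
        for _ in range(max_steps):
            logits = model_apply(params, jnp.asarray([ids]))  # [1, len(ids), vocab]
            next_id = int(jnp.argmax(logits[0, -1]))
            ids.append(next_id)
            if next_id == end_of_turn_id:  # stop at <end_of_turn>
                break
        return ids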
@jenriver
Member

Also, please make sure your selective tests are passing.

Comment on lines 538 to 544
def init_cache(
    cfg: ModelConfig, batch_size: int, token_len: int, generate_steps: int, dtype: jnp.dtype = jnp.bfloat16
) -> Cache:
    cache_size = 2 ** math.ceil(math.log2(max(token_len + generate_steps, 1)))  # Pad for a sharding-friendly size.
    return [
        LayerCache(cfg.text_config, batch_size, cache_size, dtype) for _ in range(cfg.text_config.num_hidden_layers)
    ]
Member

Currently this is a global KV cache implementation, so it doesn't realize the KV cache memory reduction that Gemma 3's sliding-window (local) attention layers allow.

Could we have something like this to account for local vs. global?

def init_cache(...) -> Cache:
    full_cache_size = 2 ** math.ceil(math.log2(max(token_len + generate_steps, 1)))
    window_size = cfg.text_config.sliding_window  # Typically 1024

    caches = []
    for layer_type in cfg.text_config.layer_types:
        size = full_cache_size if layer_type == AttentionMode.FULL else window_size
        caches.append(LayerCache(cfg.text_config, batch_size, size, dtype))
    return caches

We can also update the K/V cache write logic to index modulo window_size for the sliding-window layers.
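
A minimal sketch of what that modulo (ring-buffer) write could look like for a sliding-window layer; the cache layout and function name here are assumptions for illustration, not the existing Bonsai API:

    import jax
    import jax.numpy as jnp

    def update_window_cache(k_cache, v_cache, k_new, v_new, write_pos, window_size):
        """Write one decode step's K/V at slot write_pos % window_size.

        k_cache, v_cache: [batch, window_size, num_kv_heads, head_dim]
        k_new, v_new:     [batch, 1, num_kv_heads, head_dim]
        write_pos:        absolute token position (scalar int array)
        """
        slot = jnp.mod(write_pos, window_size)  # ring-buffer slot; wraps once past the window
        k_cache = jax.lax.dynamic_update_slice_in_dim(k_cache, k_new, slot, axis=1)
        v_cache = jax.lax.dynamic_update_slice_in_dim(v_cache, v_new, slot, axis=1)
        return k_cache, v_cache

The attention mask for these layers would then need to account for the wrapped ordering; full-attention layers keep the existing linear write (equivalently, the same code with window_size set to the full cache size).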

@jenriver previously approved these changes Jan 8, 2026
@jenriver merged commit 36896a5 into jax-ml:main Jan 8, 2026
4 of 5 checks passed