Gemma3 initial commit #102
Conversation
- Updated configs
- Moved embed_tokens to a more natural place
- Updated run_model to use the sampler and stop at the end_of_turn token (see the sketch below)
- Added test_sharding_gemma3
- Added a batched forward test; still needs more complex behavior and testing
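For context, the stop-at-end_of_turn behavior could look roughly like this (a minimal sketch, not the PR's actual run_model code; the model call, the greedy stand-in for the sampler, and the END_OF_TURN_ID value are all assumptions):

```python
import jax.numpy as jnp

END_OF_TURN_ID = 106  # hypothetical id; read it from the tokenizer in practice

def decode(model, params, tokens, max_steps):
    # Append one token per step; stop early once end_of_turn is produced.
    for _ in range(max_steps):
        logits = model.apply(params, tokens)             # [batch, seq, vocab]
        next_token = jnp.argmax(logits[:, -1], axis=-1)  # greedy stand-in for the sampler
        tokens = jnp.concatenate([tokens, next_token[:, None]], axis=-1)
        if int(next_token[0]) == END_OF_TURN_ID:
            break
    return tokens
```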
Also, please make sure your selective tests are passing.
```python
def init_cache(
    cfg: ModelConfig, batch_size: int, token_len: int, generate_steps: int, dtype: jnp.dtype = jnp.bfloat16
) -> Cache:
    cache_size = 2 ** math.ceil(math.log2(max(token_len + generate_steps, 1)))  # Pad for a sharding-friendly size.
    return [
        LayerCache(cfg.text_config, batch_size, cache_size, dtype) for _ in range(cfg.text_config.num_hidden_layers)
    ]
```
Currently this is a global KV cache implementation, so it doesn't get the KV-cache memory-reduction benefits of Gemma 3's sliding-window layers.
Could we have something like this to account for local vs. global attention?
```python
def init_cache(...) -> Cache:
    full_cache_size = 2 ** math.ceil(math.log2(max(token_len + generate_steps, 1)))
    window_size = cfg.text_config.sliding_window  # Typically 1024
    caches = []
    for layer_type in cfg.text_config.layer_types:
        size = full_cache_size if layer_type == AttentionMode.FULL else window_size
        caches.append(LayerCache(cfg.text_config, batch_size, size, dtype))
    return caches
```
We can also change the K/V cache update logic to index modulo window_size for the sliding-window layers, as sketched below.
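A minimal sketch of that modulo-based update, assuming a mutable LayerCache with k/v buffers laid out as [batch, cache_len, heads, head_dim] (the field names and layout are assumptions, not the PR's actual API):

```python
import jax.numpy as jnp

def write_kv(layer_cache, key, value, write_pos):
    # Wrap the write position around the cache length: sliding-window layers
    # (cache_len == window_size) overwrite their oldest entries, while
    # full-attention layers (cache_len >= total tokens) never wrap.
    cache_len = layer_cache.k.shape[1]
    slot = write_pos % cache_len
    layer_cache.k = layer_cache.k.at[:, slot].set(key)
    layer_cache.v = layer_cache.v.at[:, slot].set(value)
    return layer_cache
```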
Resolves #100
Adds the Gemma3 model to the Bonsai repo. This first commit is a working version; I am still optimizing it.
Reference
Refer to Issue #100
Checklist
- (run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality).