Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions bonsai/models/gemma3/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Qwen3 in JAX

This directory contains a pure JAX implementation of the [Gemma3 model](https://deepmind.google/models/gemma/gemma-3/), using the [Flax NNX](https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html) API.

Note that you need an access token to download the model weights. In order to run the scripts, make sure to save an environment variable `HF_TOKEN` with your huggingface access token.


## Model Configuration Support Status


### Running this model


```sh
python3 -m bonsai.models.gemma3.tests.run_model
```


## How to contribute to this model

### Remaining Tasks

1. Implement with batching. Need this for FSDP.
2. Optimize based on the profiling.
3. Clean up code (variable names, etc.). Simplify unused configs (marked these with TODO) or use them.
4. Update to include other model sizes and optimize parameter loading.
Loading