Conversation

@Aatman09 (Contributor) commented Jan 4, 2026

Resolves #107

Reference
This implementation is based on the following tutorial:
JAX Machine Translation Tutorial

Changes made

  • Added dataclass-based configuration for improved clarity and structure
  • Enhanced the tutorial with additional Markdown explanations for better readability

Notes

  • Key–value (KV) caching has been left out

Checklist

  • I have read the Contribution Guidelines and used pre-commit hooks to format this commit.
  • I have added all the necessary unit tests for my change. (run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality).
  • (If using an LLM) I have carefully reviewed and removed all superfluous comments or unneeded, commented-out code. Only necessary and functional code remains.
  • I have signed the Contributor License Agreement (CLA).

@chapman20j (Collaborator) commented:
Hi @Aatman09. Thank you for the nice commit. Could you please include a few pip installs at the beginning of the notebook for the additional dependencies, and pin their versions, e.g. `!pip install "grain==0.2.15"`? Also, please ensure that this notebook runs on Colab.
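A minimal sketch of such a setup cell. Only `grain==0.2.15` comes from the review comment; the other package names are illustrative placeholders and should be replaced with the notebook's actual dependencies, pinned to the versions it was tested with (in a Colab cell, each line would be prefixed with `!`):

```shell
# Pin dependency versions so the notebook stays reproducible on Colab.
# grain==0.2.15 is from the review comment; flax/optax pins are hypothetical.
pip install "grain==0.2.15"
pip install "flax" "optax"  # replace with exact tested versions, e.g. "flax==X.Y.Z"
```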

@chapman20j (Collaborator) commented:
For the KV cache, this would be nice to add in the Use Model For Inference section. Caching makes inference faster by allowing attention to reuse the previously computed k and v tensors. This gives you two options:

  1. Implement your own caching logic
  2. Change the flags for the attention layers

I think option 2 makes the most sense for this tutorial, so it doesn't get too far into the weeds on the cache. Implementing your own caching may also require writing your own attention layers. For more details, the nnx docs cover how to initialize a cache (https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html); this can be done with the `.init_cache` or `.set_mode` methods. Please let me know if you'd like any further clarification or more discussion around this.
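To make the idea concrete, here is a minimal NumPy sketch of what a KV cache does during autoregressive decoding (option 1's logic, stripped of any Flax API): each step stores its key/value vectors and attends over everything cached so far, instead of recomputing k and v for the whole prefix. All names here (`KVCache`, `step`) are hypothetical and for illustration only; the tutorial itself would use the nnx attention flags instead.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

class KVCache:
    """Append-only key/value cache for one attention head (illustrative)."""

    def __init__(self, max_len, d_model):
        self.k = np.zeros((max_len, d_model))
        self.v = np.zeros((max_len, d_model))
        self.idx = 0  # number of positions cached so far

    def step(self, q, k_new, v_new):
        """Cache this step's k/v, then attend the query over the full prefix."""
        self.k[self.idx] = k_new
        self.v[self.idx] = v_new
        self.idx += 1
        k = self.k[: self.idx]                    # (t, d) cached keys
        v = self.v[: self.idx]                    # (t, d) cached values
        scores = k @ q / np.sqrt(q.shape[-1])     # (t,) scaled dot products
        return softmax(scores) @ v                # (d,) attention output
```

Because the cache holds the exact k/v tensors a full forward pass would compute, the final decode step matches causal attention over the whole sequence, but each step only computes one new key/value pair.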

@Aatman09 (Contributor, Author) commented Jan 8, 2026

Thank you for the review. I will implement the changes as soon as possible.



Development

Successfully merging this pull request may close these issues.

Port Encoder-Decoder Example from Jax AI Stack

3 participants