From-scratch implementations of core transformer concepts in JAX NumPy. For example, an implementation of Google's T5 model, all the way from a linear layer to autoregressive encoder-decoder text generation. Or OpenAI's GPT-2 (which agrees on the difficulty of formatting CLI printing).

See `examples/t5_generation.ipynb`:
```python
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer, T5ForConditionalGeneration
# to_jax_state and T5Model come from this repo (imports omitted here)

# Load pytorch model
torch_t5_model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")

# Copy weights to jax state
state = to_jax_state(torch_t5_model)
jax_t5_model = T5Model(vocab_size=32128, emb_size=512)

rng = jax.random.PRNGKey(42)
INPUT = "translate English to French: Hello!"
input_ids = jnp.array(tokenizer(INPUT, return_tensors="pt").input_ids)

pred_token_ids = jax_t5_model.generate(state, input_ids, rng, max_length=40)
pred_text = tokenizer.decode(pred_token_ids[0], skip_special_tokens=True)
print(pred_text)
# >> Bonjour!
```
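The GPT-2 path looks much the same. A minimal sketch, assuming `to_jax_state` also accepts the pytorch GPT-2 module and that `GPT2Model`'s constructor and `generate` mirror the T5 example above:

```python
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer, GPT2LMHeadModel

# Load pytorch model (gpt2 small: vocab size 50257, embedding size 768)
torch_gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Copy weights to jax state (assumed to work for GPT-2 as it does for T5)
state = to_jax_state(torch_gpt2_model)
jax_gpt2_model = GPT2Model(vocab_size=50257, emb_size=768)  # constructor args assumed

rng = jax.random.PRNGKey(42)
input_ids = jnp.array(tokenizer("Hello, my name is", return_tensors="pt").input_ids)

pred_token_ids = jax_gpt2_model.generate(state, input_ids, rng, max_length=40)
print(tokenizer.decode(pred_token_ids[0]))
```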
See `examples/example_dense.ipynb`.

Development:

- Python version: 3.11.6
- Run tests with `pytest`
- Filter for a specific test with `pytest -k test_name`
- Enable debug logs with `pytest --log-cli-level=DEBUG`
- Run formatting with `black .`
- Implement softmax (sketch after this list)
- Implement network
- Implement batch norm
- Implement vanilla attention
- Implement multi-headed attention
- Add tests to compare with pytorch implementations
- Simplify logging
- Switch to `__name__` logging
- Switch to using single Array type for shape hinting
- Switch to using `poetry` for dependencies
- Attention Masking
- Layer Norm
- Dropout
- Positional Encoding
- Tokenizer (tiktoken)
- Encoder
- Decoder
- Transformer
- Seq2Seq
- T5
  - Inference using random weights
  - Inference using t5-small weights
  - Beam search
  - CLI inference example
- Wav2Vec
- GPT2
  - Dense
  - Attention
  - Fix interfering test (fixed by limiting the scope of each test to function)
  - GPT2Block
  - GPT2Model
  - GPT2 Generation
- Use torch hooks to debug tests
- Implement hooks for basemodel
- Switch to uv for package management
- RMSNorm (sketch below)
- SwiGLU layer (sketch below)
- GeGLU (sketch below)
- Rotary Positional Encoding (sketch below)
- K/V Cache (sketch below)
- Multi-Query Attention (https://arxiv.org/abs/1911.02150)
- Grouped Query Attention (sketch below; covers multi-query as a special case)
- Vectorize stuff with vmap
- Windowed attention with recomputation
- Streaming LLM attention
- Snake activation
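A few of the checked-off pieces, sketched out. First softmax — a minimal version with the usual max-subtraction for numerical stability (function names here are illustrative, not necessarily the repo's):

```python
import jax.numpy as jnp

def softmax(x: jnp.ndarray, axis: int = -1) -> jnp.ndarray:
    # Subtracting the row max leaves the result unchanged but keeps exp() from overflowing
    x_max = jnp.max(x, axis=axis, keepdims=True)
    exps = jnp.exp(x - x_max)
    return exps / jnp.sum(exps, axis=axis, keepdims=True)
```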
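RMSNorm, as used by Llama-family models — like layer norm but with no mean subtraction or bias, normalizing only by the root mean square of the features:

```python
import jax.numpy as jnp

def rms_norm(x: jnp.ndarray, gain: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
    # Normalize over the feature axis by the RMS, then scale by a learned per-feature gain
    rms = jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * gain
```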
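The SwiGLU and GeGLU feed-forward variants from "GLU Variants Improve Transformer" (linked below) — two up-projections, one passed through an activation and used to gate the other, followed by a down-projection:

```python
import jax
import jax.numpy as jnp

def swiglu_ffn(x: jnp.ndarray, W: jnp.ndarray, V: jnp.ndarray, W2: jnp.ndarray) -> jnp.ndarray:
    # FFN_SwiGLU(x) = (Swish(xW) * xV) W2
    return (jax.nn.silu(x @ W) * (x @ V)) @ W2

def geglu_ffn(x: jnp.ndarray, W: jnp.ndarray, V: jnp.ndarray, W2: jnp.ndarray) -> jnp.ndarray:
    # Same layout, with a GELU gate instead of Swish
    return (jax.nn.gelu(x @ W) * (x @ V)) @ W2
```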
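Rotary positional encoding per RoFormer (linked below) — rotate feature pairs of the queries and keys by a position-dependent angle instead of adding a positional embedding. A sketch of the "rotate-half" variant:

```python
import jax.numpy as jnp

def rope(x: jnp.ndarray, base: float = 10000.0) -> jnp.ndarray:
    # x: (seq_len, num_heads, head_dim) with an even head_dim
    seq_len, _, head_dim = x.shape
    half = head_dim // 2
    freqs = base ** (-jnp.arange(half) / half)               # theta_i, shape (half,)
    angles = jnp.arange(seq_len)[:, None] * freqs[None, :]   # (seq_len, half)
    cos = jnp.cos(angles)[:, None, :]                        # broadcast over heads
    sin = jnp.sin(angles)[:, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
```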
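The K/V cache for autoregressive decoding — past keys and values are stored once and reused, so each new token only computes its own attention row. An eager-mode sketch for a single head (the Python slice on `pos` would need `jax.lax` ops to be jittable):

```python
import jax
import jax.numpy as jnp

def attend_with_kv_cache(q_t, k_t, v_t, cache_k, cache_v, pos):
    # q_t, k_t, v_t: (head_dim,) for the current token; caches: (max_len, head_dim)
    cache_k = cache_k.at[pos].set(k_t)
    cache_v = cache_v.at[pos].set(v_t)
    keys, values = cache_k[: pos + 1], cache_v[: pos + 1]  # all positions so far
    scores = keys @ q_t / jnp.sqrt(q_t.shape[-1])          # (pos + 1,)
    weights = jax.nn.softmax(scores)
    return weights @ values, cache_k, cache_v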
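Grouped-query attention — several query heads share each key/value head, shrinking the K/V cache. With one K/V head it degenerates to multi-query attention; with as many K/V heads as query heads it is standard multi-head attention:

```python
import jax
import jax.numpy as jnp

def grouped_query_attention(q, k, v, mask=None):
    # q: (seq, n_q_heads, d); k, v: (seq, n_kv_heads, d), n_q_heads % n_kv_heads == 0
    group = q.shape[1] // k.shape[1]
    k = jnp.repeat(k, group, axis=1)  # share each k/v head across its query group
    v = jnp.repeat(v, group, axis=1)
    scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(q.shape[-1])
    if mask is not None:
        scores = jnp.where(mask, scores, -jnp.inf)  # e.g. a (seq, seq) causal mask
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,khd->qhd", weights, v)
```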
Package management with `uv`:

```bash
source .venv/bin/activate
uv add <package>
uv run pytest
```
Resources:

- Tree Map: https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html
- Stateful computation: https://jax.readthedocs.io/en/latest/jax-101/07-state.html
- Efficient Streaming Attention: https://arxiv.org/abs/2309.17453
- Softmax off by one: https://www.evanmiller.org/attention-is-off-by-one.html
- ML Library in JAX: https://flax.readthedocs.io/en/latest/
- NN Example: https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html
- Illustrated transformer: https://jalammar.github.io/illustrated-transformer/
- Annotated transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html
- Pytorch implementation: https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L5101
- Annotated pytorch implementation: https://nn.labml.ai/transformers/mha.html and https://nn.labml.ai/transformers/models.html
- Attention explained: https://storrs.io/attention/
- 3-dimensional matrix-by-matrix product: https://math.stackexchange.com/questions/63074/is-there-a-3-dimensional-matrix-by-matrix-product
- Self-attention Does Not Need O(n^2) Memory: https://arxiv.org/abs/2112.05682
- Reformer: The Efficient Transformer: https://arxiv.org/abs/2001.04451
- RoFormer: Enhanced Transformer with Rotary Position Embedding: https://arxiv.org/abs/2104.09864
- Fast Transformer Decoding: One Write-Head is All You Need (Multi-Query Attention): https://arxiv.org/abs/1911.02150
- Llama 2: https://arxiv.org/abs/2307.09288
- GLU Variants Improve Transformer: https://arxiv.org/abs/2002.05202