These are my personal study notes on JAX and Flax.
Some code snippets are copied from the tutorials below, but I made them more complete and easier to understand, e.g., end-to-end code with a trivial training loop, a trivial dataset, etc.
I also added code snippets for areas the docs don't explain, e.g., writing a stateful LSTM, categorical focal cross-entropy, and so on.