This is an example Jupyter notebook that trains a simple neural network to learn a sine function using JAX. It uses the Equinox library to define the network, Optax for optimisation, PyTorch data loaders to batch the training data, and JAX's JIT compilation to speed up training.
This material is heavily inspired by these excellent tutorials: