This is an illustrative example of preparing a PyTorch model for use from a JVM environment.
- The Problem and the data
- Exploration
- Train
- Predict
- Reduce the model size
- Interpret the model
- Optimizations
Predict a group of the Yeast gene
Multi-label classification with structured data in libsvm format.
- classes: 14
- features: 103
- data points: 1,500 (train) / 917 (test)
TODO: add a notebook checking the dataset for imbalanced classes.
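For reference, each line of the libsvm multi-label files is a comma-separated label list followed by sparse `index:value` features. Below is a minimal sketch of loading them into a PyTorch `DataLoader` with scikit-learn's `load_svmlight_file`; the real dataloader lives in `train.py` and may differ.

```python
# a minimal sketch, assuming scikit-learn is available; train.py's loader may differ
import torch
from sklearn.datasets import load_svmlight_file
from torch.utils.data import DataLoader, TensorDataset

# each line looks like "2,3 1:0.09 5:0.25 ...": label list, then index:value features
X, y = load_svmlight_file("yeast_train.svm", n_features=103, multilabel=True)

features = torch.tensor(X.toarray(), dtype=torch.float32)
labels = torch.zeros(len(y), 14)  # turn the label tuples into 14-dim multi-hot vectors
for i, labs in enumerate(y):
    # assumes 0-based label ids in the file; subtract 1 if they are 1-based
    labels[i, [int(l) for l in labs]] = 1.0

loader = DataLoader(TensorDataset(features, labels), batch_size=64, shuffle=True)
```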
# install dependencies
virtualenv .venv
source .venv/bin/activate
pip install -r requirements.txt
# get the data
wget 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel/yeast_train.svm.bz2'
wget 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel/yeast_test.svm.bz2'
bzip2 -d *.bz2
# test the dataloader
./train.py
# train the model
./train.py --model model.pt
The plain training script does not include logging, early stopping, model checkpointing, or lots of other nice goodies.
But PyTorch Lightning does include all that, and much more, for free 🎉
./train_ptl.py --model models/ptl/model.pt
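For orientation, here is a minimal sketch of what a LightningModule for this task can look like. The layer sizes (103 → 64 → 14) and `BCEWithLogitsLoss` for the multi-label output are assumptions; the real module is in `train_ptl.py`.

```python
# a minimal sketch; the real module lives in train_ptl.py
import pytorch_lightning as pl
import torch
from torch import nn

class YeastClassifier(pl.LightningModule):
    def __init__(self, in_features=103, hidden=64, num_classes=14):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )
        # multi-label: an independent sigmoid per class
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss(self(x), y)
        self.log("train_loss", loss)  # shows up in TensorBoard
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```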
Monitor the progress with TensorBoard
tensorboard --logdir lightning_logs/
open http://localhost:6006
3 epochs at ~400 it/s result in precision 0.768, while the original paper reports 0.762.
PyTorch inference in Python
./predict.py --model model.pt < single_example.txt
The correct answer is 2, 3.
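A rough sketch of the prediction step: parse one libsvm line from stdin, run the model, and report every class whose sigmoid output clears a threshold. Loading with `torch.jit.load` and the 0.5 threshold are assumptions; `predict.py` is the real implementation.

```python
# a rough sketch; see predict.py for the actual logic
import sys
import torch

model = torch.jit.load("model.pt")  # assumes the model was saved as TorchScript
model.eval()

# parse one libsvm line: skip the label field, fill a dense 103-dim vector
features = torch.zeros(1, 103)
for token in sys.stdin.read().split()[1:]:
    index, value = token.split(":")
    features[0, int(index) - 1] = float(value)  # libsvm feature indices are 1-based

with torch.no_grad():
    probs = torch.sigmoid(model(features))
print((probs > 0.5).nonzero(as_tuple=True)[1].tolist())  # e.g. [2, 3]
```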
TODO https://github.com/pytorch/java-demo/blob/master/src/main/java/demo/App.java
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
In: trained model.pt
Out: model.onnx
./onnx_export.py --model model.pt --out model.onnx
./onnx_predict.py --model model.onnx < single_example.txt
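In essence, the two scripts wrap `torch.onnx.export` and an `onnxruntime.InferenceSession`, roughly as sketched below; the input/output names and the dummy-input shape are assumptions.

```python
# a rough sketch of the export + inference round-trip; names are assumptions
import numpy as np
import onnxruntime as ort
import torch

model = torch.jit.load("model.pt")  # assumes a TorchScript model
model.eval()

dummy = torch.zeros(1, 103)  # one 103-feature example defines the graph input
torch.onnx.export(model, dummy, "model.onnx",
                  input_names=["features"], output_names=["logits"])

session = ort.InferenceSession("model.onnx")
logits = session.run(None, {"features": np.zeros((1, 103), dtype=np.float32)})[0]
print(logits.shape)  # (1, 14)
```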
Using the JNI-based Java API of ONNX Runtime
cp model.onnx onnx-predict-java/src/main/resources/
cd onnx-predict-java
./gradlew jar
java -jar ./build/libs/onnx-predict-java.jar < single_example.txt
- see this for a discussion of JNI and multiple classloader support
- the ONNX Runtime dependency is 92 MB
Explore different NN architectures
- Deep & Cross Network (DCN): paper, posts, tutorial, PyTorch impl
Architecture-neutral optimizations
- fp16 quantization-aware training with PTL (GPU-only)
- hyperparameter search with PTL for layer dimensions
- 8-bit dynamic quantization with `torch.quantization.quantize_dynamic` (tutorial; see the sketch after this list)
- pruning with `torch.nn.utils.prune` (tutorial; also covered in the sketch below)
- Bayesian hyperparameter optimization with Optuna, estimating importance
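As a rough sketch of the two torch utilities mentioned above, applied to a stand-in MLP of the assumed 103 → 64 → 14 shape:

```python
# a rough sketch on a stand-in MLP; the real model comes from train.py
import torch
from torch import nn
from torch.nn.utils import prune

mlp = nn.Sequential(nn.Linear(103, 64), nn.ReLU(), nn.Linear(64, 14))

# 8-bit dynamic quantization: int8 weights, activations quantized on the fly
quantized = torch.quantization.quantize_dynamic(mlp, {nn.Linear}, dtype=torch.qint8)

# L1 unstructured pruning: zero out 30% of the first layer's weights...
prune.l1_unstructured(mlp[0], name="weight", amount=0.3)
prune.remove(mlp[0], "weight")  # ...and bake the pruning mask into the tensor
```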
| Model | Params | On disk | Train time |
|---|---|---|---|
| fp32 mlp | | 52 KB | |
| onnx mlp | | 48 KB | |
| fp16 mlp | | ? | |
| 8bit mlp | | ? | |
| fp32 mlp+hyperopt | | ? | |
| fp32 dcn | | ? | |
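One quick way to fill in the Params and On disk columns for any of the variants (the stand-in MLP mirrors the assumed architecture):

```python
# count trainable parameters and measure the serialized size on disk
import os
import torch
from torch import nn

model = nn.Sequential(nn.Linear(103, 64), nn.ReLU(), nn.Linear(64, 14))
params = sum(p.numel() for p in model.parameters())
torch.save(model.state_dict(), "/tmp/model.pt")
print(f"{params} params, {os.path.getsize('/tmp/model.pt') / 1024:.1f} KB on disk")
```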
How important are individual features? Explain how the model's weights contribute to its final decision.
- Primary attribution with integrated gradients for feature importance using https://captum.ai (see the sketch below)
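A minimal Captum sketch, using a stand-in model and a random example; `target=3` is an illustrative class choice:

```python
# a minimal sketch with https://captum.ai; model and input are stand-ins
import torch
from torch import nn
from captum.attr import IntegratedGradients

model = nn.Sequential(nn.Linear(103, 64), nn.ReLU(), nn.Linear(64, 14))
model.eval()

x = torch.rand(1, 103)  # stand-in for a real libsvm example
ig = IntegratedGradients(model)
attributions, delta = ig.attribute(x, target=3, return_convergence_delta=True)
print(attributions.abs().squeeze().topk(10))  # the 10 most influential features
```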
- PTL profiler
- the new PyTorch profiler (see the profiling sketch below)
- is model execution time dominated by loading weights from memory, or by computing the matrix multiplications?
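To start answering that question, here is a sketch using the `torch.profiler` API on a stand-in MLP; the per-operator CPU-time table shows whether `addmm` (the matmuls) or memory-bound ops dominate:

```python
# a sketch: profile 100 forward passes of a stand-in MLP on CPU
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile

model = nn.Sequential(nn.Linear(103, 64), nn.ReLU(), nn.Linear(64, 14))
x = torch.rand(64, 103)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with torch.no_grad():
        for _ in range(100):
            model(x)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```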