# NN PRUNING RESEARCH REPORT (WIP)


Origin of the project
---

The project started with the observation that unstructured sparsity is hard to accelerate, because parallel machines
need some form of data regularity/locality to run at peak performance.

Using blocks of various shapes is the simplest form of regularity that can be introduced to improve performance.

Pytorch Block Sparse
---
Because that kind of feature was not available (or only for different software stacks), the first step I chose was to implement a library for fast block-sparse basic algebra to measure this.
[pytorch_block_sparse](https://github.com/huggingface/pytorch_block_sparse), using CUDA / CUTLASS kernels, is the result of this first step.

Creating such a library, and reaching good performance, was quite a challenge:
diving into the lower software layers and the hardware details takes significant time and effort.

Other tools could have been used, like the Triton compiler, but at the time I did not consider it stable enough.

The result of this work is a library that reaches parity with dense code when sparsity is >= 60%.
That is a decent result, but better levels of performance are probably possible.

The positive points are that it saves memory in proportion to sparsity, and that it is a drop-in replacement for PyTorch Linear layers.
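
As an illustration, here is a minimal sketch of the drop-in usage, assuming the `BlockSparseLinear` constructor and its `density` argument behave as described in the pytorch_block_sparse README (the exact signature may differ between versions):

```python
import torch
from pytorch_block_sparse import BlockSparseLinear

# A dense layer and its block-sparse replacement: same input/output dimensions,
# but only `density` (here 25%) of the weight blocks are actually allocated.
dense = torch.nn.Linear(1024, 256)
sparse = BlockSparseLinear(1024, 256, density=0.25)

x = torch.randn(8, 1024).cuda()
y = sparse.cuda()(x)  # used exactly like an nn.Linear, output shape (8, 256)
```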

Block Movement Pruning
---
The level of pruning Movement Pruning was able to reach on several fine-tuning tasks was very interesting.
This made it a good candidate for testing block pruning instead of unstructured sparsity.

A lot of different experiments were run (~250, producing more than 5000 models).
The parameter space is large, and only a subset of it could be tested. It is described at least by the following dimensions (a hypothetical configuration is sketched just after the list):
 - pruning algorithms (those proposed in Movement Pruning: magnitude, topK, threshold, sigmoied_threshold...)
 - block shape (1x1 up to 64x64, square or rectangular, entire heads for BERT = 64x768...)
 - attention vs FFNs (not using the same pruning algorithm for each part is actually a good idea)
 - pruning speed (the number of epochs used to transition from 100% to N% density)
 - pretrained network, and teacher (BERT-base/large...)

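To make those dimensions concrete, here is a purely hypothetical description of one point in the search space (plain Python, not the actual nn_pruning API; every key name is illustrative):

```python
# Hypothetical sweep point; key names and values are illustrative only.
experiment = {
    "attention_pruning_method": "sigmoied_threshold",  # Movement Pruning variants: magnitude, topK, ...
    "ffn_pruning_method": "topK",                       # attention and FFNs need not use the same method
    "attention_block_shape": (32, 32),                  # from (1, 1) (unstructured) up to whole heads
    "ffn_block_shape": (64, 64),
    "pruning_epochs": 20,                               # how fast we go from 100% density to the target
    "model": "bert-base-uncased",
    "teacher": "bert-large-uncased-whole-word-masking-finetuned-squad",  # distillation teacher
}
```
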
It was not economically feasible to sweep this parameter space, and of course some initial observations were instrumental
in introducing new dimensions into the search (like applying different pruning methods to attention and FFNs).
Most of the models were built on an RTX 3090 in a home server, running almost 24/7 for more than 4 months.
Some were created on Amazon SageMaker for simple sweeps (block shape tests, for example), and it could be useful to run
some more.

The main tasks that were tested were SQuAD v1 and MNLI.
Now that the library has been released, some external users are starting to test it on their own datasets (as of March 11th, 2021).

SQuADv1
---
The following experiments were run:

- Block pruning for both attention and FFNs

  Main observations:
  - 32x32 block pruning approaches 1x1 performance (1x1 = unstructured) when using twice the number of epochs

- Movement Pruning with twice the epochs (20 instead of 10)

  This resulted in a large improvement for 1x1 blocks, showing that the reachable accuracy at the same level of sparsity is much higher than initially reported in the Movement Pruning paper.

  Experiment conclusions:
  - 32x32 is good, but 1x1 is still better (which makes sense, as it is the finest granularity)
  - Searching for even better hyper-parameters for Movement Pruning may lead to further improvements

- 1D pruning

  One additional observation from the initial block pruning experiments was that the pruning patterns seem mostly one-dimensional.
  Victor Sanh had already mentioned this for the original Movement Pruning models, but with blocks the pattern is much
  more pronounced: full attention heads are quickly pruned when using 32x32 blocks.

  The next step was then to prune rows or columns (1D pruning).

  Experiment conclusions:
  - The results are not very good on attention.
    A hypothesis is that the head dimension is too small, so removing even a small part of a 64-dimensional key has too large an impact when
    computing the dot-product with the 64-dimensional query.
    To be tested: row/col pruning on the value and output layers only.
    But the computation structure makes this less efficient than full head pruning (extracting a subset of input dimensions is not very fast).
  - 1D pruning works well on FFNs, especially when the same mask is applied to the rows of the first FFN and the columns of the second FFN (which makes sense, as one feeds the other); see the sketch below.
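
  A minimal sketch of this shared-mask idea, assuming BERT-base shapes (hidden=768, intermediate=3072); this is illustrative, not the library's actual code:

  ```python
  import torch

  # One mask over the 3072 intermediate neurons is shared by both FFN matrices.
  mask = torch.ones(3072)
  mask[torch.randperm(3072)[:1024]] = 0.0      # e.g. prune 1/3 of the intermediate neurons

  ffn1_weight = torch.randn(3072, 768)         # intermediate.dense.weight, shape (out, in)
  ffn2_weight = torch.randn(768, 3072)         # output.dense.weight, shape (out, in)

  ffn1_pruned = ffn1_weight * mask[:, None]    # masks rows of the first FFN
  ffn2_pruned = ffn2_weight * mask[None, :]    # masks the matching cols of the second FFN
  ```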

- Hybrid pruning

  The next logical step was to use blocks for attention, and rows/cols for FFNs.
  This leads to a significant improvement compared to the "blocks everywhere" strategy.

  The resulting pruned network is faster because of:
  - head pruning -> smaller matrices (the remaining heads are still block sparse)
  - structured (rows/cols) pruning of FFNs -> smaller, but dense, matrices

- Optional final step: densification of attention heads

  As noted in the previous step, we removed some heads, but the remaining heads are block sparse.

  The problem is that we don't have an efficient block-sparse library for low sparsity, and even if we had one, it would probably not be available everywhere.

  To get around this, the following method works well (a minimal sketch is given at the end of this item):

  - fill the empty attention blocks with small random noise (to avoid null gradients)
  - perform some epochs of fine-tuning, using the same distillation parameters as previously

  As a consequence:
  - we get back some accuracy: F1 is higher
  - but the model speed is the same, as we keep the same heads and the FFN shapes are unchanged: the speed/F1 tradeoff is improved
  - we increase the number of non-zero parameters (but we observe that the sparsity/F1 tradeoff is actually not different, the curve is just translated)
  - we end up with a fully dense network, so no special runtime is needed
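
  A minimal sketch of the noise-filling step, assuming a plain weight tensor where pruned entries are exactly zero (illustrative only, not the library's implementation):

  ```python
  import torch

  def densify_(weight: torch.Tensor, noise_scale: float = 1e-4) -> torch.Tensor:
      # Pruned entries are exactly zero; give them a small random value so that
      # they receive non-zero gradients during the final fine-tuning/distillation epochs.
      zero_mask = weight == 0
      weight[zero_mask] = torch.randn(int(zero_mask.sum())) * noise_scale
      return weight

  # e.g. apply in place to every attention weight matrix of the pruned model:
  # for layer in model.bert.encoder.layer:
  #     densify_(layer.attention.self.query.weight.data)
  ```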

- Further improvements

  The previous experiments can be improved or modified with some orthogonal changes:

  - Using a large teacher: even for 75% sparse networks, the difference in accuracy is significant
  - Replacing LayerNorm with NoNorm: small drop in F1 (~0.5), but MobileBERT claims that it leads to a large latency gain

MNLI
---
The results are similar to TinyBERT and MobileBERT.
The latest improvements have not been tested on MNLI yet, so some gains are probably possible.

More testing is still needed.

To Be Tested
---
 - replacing GeLU with ReLU
 - Ampere sparsity (the code is 90% there, but needs some fixes)
 - quantization impact (the quantized BERT is now 40MB instead of the initial 400MB, with a 1% drop in F1)
 - real-life benchmarks (they will probably be better than the current numbers, which are quite conservative)
 - more tasks
 - improved integration with Transformers (varying sizes for FFNs)
 - original Movement Pruning with a large teacher


Block Pruning Pros and Cons
---

- Pros
  - Smaller
  - Faster
  - No special runtime needed

- Cons
  - Accuracy loss for large speedups. But the method offers a continuous tradeoff between speed/size and accuracy, contrary to optimized pretrained networks like MobileBERT.
  - Much longer fine-tuning time (14h instead of 1h when using all the tricks). But the idea is to train once, and then inference time matters much more.
  - Minor: the network needs to be patched after loading with Transformers to remove the empty parts (head pruning is natively supported, but not in the TF version).

Software & Packaging
---

#### Principles
The original Movement Pruning code used its own version of the BERT source code, with MaskedLinear layers.

One of the goals of the nn_pruning library was to make it possible to prune any network (including non-transformer architectures).
The library would hardly be attractive if the first step were to rewrite the source code of the original model.
Maintaining an up-to-date version of such rewritten source code should be avoided too.

Python and PyTorch flexibility makes it really easy to dynamically patch existing models.

To apply nn_pruning to a new model, you just have to write a small set of regular expressions defining the layers you want to target.
Using the PyTorch module iterator, it's then easy to replace every targeted nn.Linear with a MaskedLinear, without any source code modification.
Rewriting the network back to its original form is done in the same way, as is the final pruning of heads or parts of the
matrices to optimize the network for inference.
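
The sketch below illustrates this patching approach in plain PyTorch; the `MaskedLinear` wrapper and the regular expressions are hypothetical stand-ins, not the actual nn_pruning classes:

```python
import re
import torch
import torch.nn as nn

class MaskedLinear(nn.Module):
    """Hypothetical wrapper: a Linear layer whose weight is multiplied by a mask."""
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear
        self.register_buffer("mask", torch.ones_like(linear.weight))

    def forward(self, x):
        return nn.functional.linear(x, self.linear.weight * self.mask, self.linear.bias)

# Regular expressions selecting the layers to patch (example for the BERT FFNs).
PATTERNS = [r"encoder\.layer\.\d+\.intermediate\.dense$",
            r"encoder\.layer\.\d+\.output\.dense$"]

def patch_model(model: nn.Module) -> nn.Module:
    # First collect the fully-qualified names of the Linear layers to patch,
    # then swap them in place, without touching the model's source code.
    to_patch = [name for name, module in model.named_modules()
                if isinstance(module, nn.Linear) and any(re.search(p, name) for p in PATTERNS)]
    for name in to_patch:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, MaskedLinear(getattr(parent, child_name)))
    return model
```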

#### Encoding a large family of pruning algorithms
Using simple `module name` to `mask info` mapping rules, a large set of pruning strategies can be explored (and some still remain to be explored):
- 1D pruning: to prune the two BERT FFNs jointly, a single mask is created, and their MaskedLinear objects point to this mask
- 2D pruning: each MaskedLinear has its own mask
- 2D joint pruning: for attention, pruning the 4 layers (KQV + output) at once can easily be tested by mapping all these layers to the same mask information
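
For illustration, such mapping rules could look like the following (the regular expressions and group names are illustrative, not the library's actual configuration format):

```python
# Ordered (regex, mask group) rules: the first matching rule wins,
# and layers mapped to the same group share a single mask object.
MASK_RULES = [
    (r"attention\.self\.(query|key|value)$", "attention_joint"),  # 2D joint pruning of K, Q, V...
    (r"attention\.output\.dense$",           "attention_joint"),  # ...together with the attention output
    (r"intermediate\.dense$",                "ffn_joint"),        # 1D pruning: FFN1 rows and...
    (r"output\.dense$",                      "ffn_joint"),        # ...FFN2 cols share the same mask
]
```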

#### Packaging

The idea was not to produce just a research project.
The library is structured in three layers:
- Low level: the details of the pruning algorithms.
  It should be easy to extend the set of supported pruning algorithms, mostly by deriving new subclasses.
- PatchCoordinator: the minimum set of functions to call to patch the network, fine-prune-distill it, and compile it back to its original form.
  The PatchCoordinator is meant to be usable in any fine-tuning code, with or without a Trainer.
- SparseTrainer: a mixin to be used with the Hugging Face Trainer, which encapsulates a PatchCoordinator.
  The SparseTrainer should manage almost 100% of the pruning process, itself using a PatchCoordinator.
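
A rough usage sketch of the mixin pattern (the import path and the constructor arguments are assumptions; see the nn_pruning examples for the real API):

```python
from transformers import Trainer
# Import path assumed from the nn_pruning examples; adjust to the actual package layout.
from nn_pruning.sparse_trainer import SparseTrainer

# Hypothetical combination of the SparseTrainer mixin with the Hugging Face Trainer.
class PruningTrainer(SparseTrainer, Trainer):
    def __init__(self, sparse_args, *args, **kwargs):
        Trainer.__init__(self, *args, **kwargs)
        SparseTrainer.__init__(self, sparse_args)

# trainer = PruningTrainer(sparse_args=sparse_args, model=model, args=training_args, ...)
# trainer.train()
```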


#### Inference
The benchmarks presented on the nn_pruning main page are quite conservative: it's hard to know from the papers
what kind of setup was used to benchmark MobileBERT or TinyBERT.
I chose to measure the CUDA time of the PyTorch forward pass with large batches, but the displayed speedups may include some time that is not 100% related to model computation.
More benchmarks with ONNX/ORT may actually show even better performance.

The produced models are 100% compatible with ONNX, as they are dense networks, with just different shapes for different layers.
On this topic, Morgan Funtowicz proposed a [patch](https://github.com/microsoft/onnxruntime/pull/6850) for ONNX/ORT that was accepted, to support a different number of heads per layer when using the specially optimized attention module.
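
For example, exporting such a pruned model to ONNX works with the standard PyTorch exporter; in the sketch below, the checkpoint path is a placeholder, and the input/output names and opset are usual choices rather than requirements:

```python
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

model_name = "path/to/pruned-squad-model"  # placeholder for a pruned checkpoint
model = AutoModelForQuestionAnswering.from_pretrained(model_name).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("Who wrote this report?", "This is an example context.", return_tensors="pt")
torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["attention_mask"]),
    "pruned_bert.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["start_logits", "end_logits"],
    dynamic_axes={"input_ids": {0: "batch", 1: "sequence"},
                  "attention_mask": {0: "batch", 1: "sequence"}},
    opset_version=12,
)
```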

Further work on inference:
- Quantization: it is not yet known what kind of accuracy drop will be observed
- Replacing LayerNorm with NoNorm: the MobileBERT paper claims this reduces latency by 3x.
  Some models have already been produced with a very limited accuracy drop; we still have to test their speed.
  (A minimal sketch of NoNorm follows this list.)
- For the same reason, we can replace GeLUs with ReLUs, but this has not been implemented yet (it is much easier than the LayerNorm replacement)
- Fusion of NoNorm with Linear layers (if ORT doesn't fuse those layers automatically)
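
A minimal sketch of a NoNorm module as described in the MobileBERT paper (an element-wise affine transform with no mean/variance statistics); the class name and integration details are assumptions:

```python
import torch
import torch.nn as nn

class NoNorm(nn.Module):
    """Drop-in replacement for LayerNorm: just a learned element-wise scale and shift."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight + self.bias
```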