|
1 | 1 | ### FastBert, 复现ACL2020论文 [FastBERT: a Self-distilling BERT with Adaptive Inference Time](https://arxiv.org/pdf/2004.02178.pdf) |
2 | 2 |
|
3 | | -### 简介 |
4 | | - 相比一众BERT蒸馏方法,FastBERT的蒸馏过程和具体任务一起进行, 且主干BERT网络不变,有如下优点: |
5 | | - 1. 准确率损失较小, 加速比约1~10倍, |
6 | | - 2. 和BERT预训练模型兼容,可灵活替换各种预训练模型,如ernie |
7 | | - 3. 使用简单, 较为实用 |
8 | | - 适用于文本分类任务,想要用BERT提升效果但受限于机器资源的场景 |
9 | | - |
10 | | -### 使用方法 |
11 | | - 下载pretrain的bert模型: |
12 | | - <br> |
13 | | - [bert中文(微云)](https://share.weiyun.com/5goxygS),[bert中文(百度网盘),提取码:mxfb](https://pan.baidu.com/s/15YpZB7uxuiCjEFZh1zx38w) |
14 | | - 放至目录./pretrained_model/bert-chinese/bert-pytorch-google/下 |
15 | | - <br> |
16 | | - [百度ernie(微云)](https://share.weiyun.com/5rYpEBs),[百度ernie(百度网盘),提取码:vui3](https://pan.baidu.com/s/1nD_hIYfis6Qn2JTXvX7dGg) |
17 | | - 放至目录./pretrained_model/ernie/ERNIE_stable-1.0.1-pytorch/下 |
18 | | -```bash |
19 | | - 1. 初始训练: |
20 | | - sh run_scripts/script_train_stage0.sh |
21 | | - |
22 | | - 2. 蒸馏训练: |
23 | | - sh run_scripts/script_train_stage1.sh |
24 | | - **注意** :蒸馏阶段输入数据为无监督数据,可依据需要引入更多数据提升鲁棒性 |
25 | | - |
26 | | - 3. 推理: |
27 | | - sh run_scripts/script_infer.sh |
28 | | - 其中 inference_speed参数(0.0~1.0)控制加速程度 |
29 | | - |
30 | | - 4. 部署使用 |
31 | | - python3 predict.py |
32 | | -``` |
33 | | - |
34 | | -```bash |
35 | | -若想替换为ernie模型,只需要把 |
36 | | ---model_config_file='config/fastbert_cls.json' |
37 | | -替换为 |
38 | | ---model_config_file='config/fastbert_cls_ernie.json' |
39 | | -``` |
40 | | - |
41 | | -### 模型思路 |
42 | | -##### 加速思路 |
43 | | - 1. 每一层transformer都接一个子分类器 |
44 | | - 2. 根据样本输入,自适应12层transformer的推理深度,子分类器置信度高则提前返回 |
45 | | -##### 训练思路 |
46 | | - 1. load 预训练好的bert |
47 | | - 2. 跟普通训练一样finetune模型 |
48 | | - 3. freeze主干网络和最后层的teacher分类器,每层的子模型拟合teacher分类器(KL散度为loss) |
49 | | - 4. inference阶段,根据样本输入,子分类器置信度高则提前返回 |
50 | | - |
51 | | -<img src="img/arch.png" alt="arch" width="800"/> |
52 | | - |
53 | | -### 结果对比 |
54 | | -##### 论文结果: |
55 | | -<img src="img/compare.png" alt="compare" width="800"/> |
56 | | - |
57 | | -##### 实测结果: |
58 | | -```bash |
59 | | -ChnSentiCorp: |
60 | | -speed_arg:0.0, time_per_record:0.14725032741416672, acc:0.9400, 基准 |
61 | | -speed_arg:0.1, time_per_record:0.10302954971909761, acc:0.9420, 1.42倍 |
62 | | -speed_arg:0.5, time_per_record:0.03420266199111938, acc:0.9340, 4.29倍 |
63 | | -speed_arg:0.8, time_per_record:0.019530397139952513, acc:0.9160, 7.54倍 |
64 | | -注:speed=0.1的情况下比基准的准确率还高,是有可能的,正则之类的效应 |
65 | | -``` |
66 | | - |
67 | | -1. 对于实际场景任务中,测试了语义理解要求较高的模型,经测试在加速1~6倍时,精度和原始BERT差距很小 |
68 | | -2. 对于稀疏场景模型,线上正例占比很小,则可进一步提高加速比, 某模型负例约17%走到第12层,正例约81%走到第12层 |
69 | | - |
70 | 3 |
|
0 commit comments