When Linear Attention Meets Autoregressive Decoding: Towards More Effective and Efficient Linearized Large Language Models
Haoran You, Yichao Fu, Zheng Wang, Amir Yazdanbakhsh, Yingyan (Celine) Lin
Accepted to ICML 2024. More info: [ Paper | GitHub ]
- [ ✅ New ] Jun. 11, 2024. 💥 Released our trained LLaMA-2-7B model checkpoints on Hugging Face!
- [ ✅ New ] Jun. 11, 2024. 💥 Released the PyTorch implementation of Linearized-LLM!
The main implementation can be found in autoregressive_wrapper.py and flash_pytorch.py. The code is adapted from FLASH.
Please set up the environment using the following commands and ensure that CUDA is included in your PATH:
export PATH=/PATH-TO-CUDA/:$PATH
conda create -n LinearLLM python==3.10
conda activate LinearLLM
pip install -r requirements.txt
pip install flash-attn
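After installation, you can optionally verify that both the CUDA toolchain and PyTorch see your GPU (these are standard checks, not repo-specific commands):

```bash
nvcc --version
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```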
We provide our trained model checkpoints in this Hugging Face repository. Use the commands below to download them:
# Linearized LLaMA-2 weights
huggingface-cli download LinearizedLLM/llama-2-7b-aug-linear --local-dir llama-2-7b-aug-linear
# Medusa Head for Linearized LLaMA-2 weights
huggingface-cli download LinearizedLLM/llama-2-7b-medusa-head-aug-linear --local-dir llama-2-7b-medusa-head-aug-linear
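Once downloaded, the linearized base model can in principle be loaded with the standard transformers API. The snippet below is a minimal sketch that assumes the released checkpoint follows the usual LLaMA-2 layout; if it instead depends on the repo's custom attention modules, use the scripts under experiments/. The dtype and device settings are illustrative, and the Medusa-head checkpoint is intended for the speculative-decoding experiment below rather than this loading path.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "llama-2-7b-aug-linear"  # local dir created by the download command above

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    torch_dtype=torch.float16,  # illustrative; pick what fits your hardware
    device_map="auto",
)

inputs = tokenizer("Linear attention makes autoregressive decoding", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```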
To reproduce Table 8 from the paper, which demonstrates the speedup of augmented linearized LLaMA-2 with speculative decoding, use the bash script below (a conceptual sketch of the draft-and-verify idea follows the commands). The code is adapted from the Medusa repository.
cd experiments
bash run_medusa.sh
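For context, Medusa-style speculative decoding proposes several candidate tokens per step and has the base model verify them in a single forward pass, so most steps emit more than one token. The sketch below illustrates only the generic greedy draft-and-verify loop: the separate draft model stands in for Medusa's extra decoding heads, and the tree-structured candidates used by Medusa are omitted. It is not the repo's implementation.

```python
import torch

@torch.no_grad()
def greedy_speculative_decode(base_model, draft_model, input_ids, max_new_tokens=64, k=4):
    """Toy draft-and-verify loop with greedy acceptance (illustration only).

    Both models are assumed to be Hugging Face causal LMs returning .logits.
    """
    ids = input_ids
    while ids.shape[1] - input_ids.shape[1] < max_new_tokens:
        # 1) Draft k candidate tokens cheaply, one at a time.
        draft_ids = ids
        for _ in range(k):
            logits = draft_model(draft_ids).logits[:, -1, :]
            draft_ids = torch.cat([draft_ids, logits.argmax(-1, keepdim=True)], dim=1)
        proposed = draft_ids[:, ids.shape[1]:]                       # (1, k)

        # 2) Verify the whole candidate block with one base-model forward pass.
        logits = base_model(draft_ids).logits
        preds = logits[:, ids.shape[1] - 1:-1, :].argmax(-1)         # base model's greedy token per position
        n_accept = int((preds == proposed).long().cumprod(dim=1).sum())

        # 3) Keep the matching prefix plus one token chosen by the base model.
        next_tok = preds[:, n_accept:n_accept + 1] if n_accept < k else logits[:, -1:, :].argmax(-1)
        ids = torch.cat([ids, proposed[:, :n_accept], next_tok], dim=1)
    return ids
```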
To reproduce Table 4, which shows the latency and memory improvements from our augmented linear attention, use the bash script below (a stand-alone measurement sketch follows the commands). Note that this experiment uses transformers==4.37.0.
pip install transformers==4.37.0
cd experiments
bash run_benchmark.sh
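If you want to spot-check decoding latency and peak GPU memory outside the provided script, the standard PyTorch utilities below are sufficient; the model and inputs are placeholders for whatever you loaded above.

```python
import time
import torch

def benchmark_generate(model, input_ids, n_warmup=3, n_iters=10, max_new_tokens=128):
    """Average per-call generation latency (s) and peak GPU memory (GiB)."""
    for _ in range(n_warmup):                       # warm-up to exclude one-time setup costs
        model.generate(input_ids, max_new_tokens=max_new_tokens)
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        model.generate(input_ids, max_new_tokens=max_new_tokens)
    torch.cuda.synchronize()
    latency = (time.perf_counter() - start) / n_iters
    peak_mem_gib = torch.cuda.max_memory_allocated() / 1024**3
    return latency, peak_mem_gib
```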
Use the bash script below to train a 24-layer FLASH model from scratch:
bash runall-125k.sh
Use the bash script below to finetune T5 with augmented linear attention. The code is adapted from the transformers repository.
cd experiments
bash tasks_run-t5.sh
Use the bash script below to finetune GPT-2 with augmented linear attention. The code is adapted from the transformers repository.
cd experiments
bash tasks_run-gpt2.sh
Use the bash script below to finetune LLaMA-2 with augmented linear attention. The code is adapted from the LongLoRA repository (a minimal adapter-setup sketch follows the commands).
cd experiments
bash tasks_run-llama2.sh
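Since the finetuning code is adapted from LongLoRA, it presumably relies on LoRA-style adapters. The snippet below is an illustrative setup using the peft library only; the actual ranks, target modules, and trainable components used by tasks_run-llama2.sh may differ.

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Path from the download step above; any LLaMA-2-style checkpoint works for this sketch.
model = AutoModelForCausalLM.from_pretrained("llama-2-7b-aug-linear")

# Illustrative LoRA configuration (not the repo's exact settings).
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights are trainable
```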
@inproceedings{you2024linear,
title={When Linear Attention Meets Autoregressive Decoding: Towards More Effective and Efficient Linearized Large Language Models},
author={You, Haoran and Fu, Yichao and Wang, Zheng and Yazdanbakhsh, Amir and Lin, Yingyan (Celine)},
booktitle={Proceedings of the 41st International Conference on Machine Learning (ICML 2024)},
year={2024},
}
Thanks to the developers of FLASH, transformers, LongLoRA, and Medusa for providing their codebases!