diff --git a/README.md b/README.md
index f33ebaf..619d980 100644
--- a/README.md
+++ b/README.md
@@ -46,14 +46,19 @@
## News
+🔥🔥🔥 [2024/05/20] We released **MFTCoder v0.4.2**, mainly for MFTCoder-accelerate. It supports **QLoRA + DeepSpeed Zero3** and **QLoRA + FSDP** as options allowing you training very large models. It now supports new models like Qwen2, Qwen2-MoE, Starcoder2, Gemma, etc.
-🔥🔥🔥 [2024/01/30] The model [CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) fine-tuned with MFTCoder ranks first in HuggingFace [Big Code Models LeaderBoard](https://huggingface.co/spaces/bigcode/bigcode-models-leaderboard)
+🔥🔥🔥 [2024/05/20] Our paper [MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning](https://arxiv.org/abs/2311.02303) has been accepted by KDD2024.
-🔥🔥🔥 [2024/01/17] We released MFTCoder v0.3.0, mainly for MFTCoder-accelerate. It now supports new models like Mixtral(MoE), DeepSeek-coder, chatglm3. It supports FSDP as an option. It also supports Self-paced Loss as a solution for convergence balance in Multitask Fine-tuning.
+🔥🔥🔥 [2024/05/20] [CodeFuse-StarCoder2-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B) has been released, achieving a pass@1 (greedy decoding) score of 73.2% on HumanEval.
-🔥🔥🔥 [2024/01/17] [CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) has been released, achieving a pass@1 (greedy decoding) score of 78.7% on HumanEval. It lists as top-1 LLM on Bigcode Leardboard in terms of win-rate, the official result is going to be published later.
+🔥🔥 [2024/01/30] The model [CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) fine-tuned with MFTCoder ranks first in HuggingFace [Big Code Models LeaderBoard](https://huggingface.co/spaces/bigcode/bigcode-models-leaderboard)
-🔥🔥🔥 [2024/01/17] [CodeFuse-Mixtral-8x7B](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8X7B) has been released, achieving a pass@1 (greedy decoding) score of 56.1% on HumanEval.
+🔥🔥 [2024/01/17] We released MFTCoder v0.3.0, mainly for MFTCoder-accelerate. It now supports new models like Mixtral(MoE), DeepSeek-coder, chatglm3. It supports FSDP as an option. It also supports Self-paced Loss as a solution for convergence balance in Multitask Fine-tuning.
+
+🔥🔥 [2024/01/17] [CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) has been released, achieving a pass@1 (greedy decoding) score of 78.7% on HumanEval. It lists as top-1 LLM on Bigcode Leardboard in terms of win-rate, the official result is going to be published later.
+
+🔥🔥 [2024/01/17] [CodeFuse-Mixtral-8x7B](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8X7B) has been released, achieving a pass@1 (greedy decoding) score of 56.1% on HumanEval.
🔥🔥 [2023/11/07] [MFTCoder Paper](https://arxiv.org/abs/2311.02303) has been released on Arxiv, which discloses technique details of multi-task-fine-tuning.
@@ -73,6 +78,7 @@
| **CodeFuse-DeepSeek-33B** | **78.7%** | 2024/01 |
| **CodeFuse-CodeLlama-34B** | **74.4%** | 2023/09 |
| **CodeFuse-CodeLlama-34B-4bits** | **73.8%** | 2023/09 |
+| **CodeFuse-StarCoder2-15B** | **73.2%** | 2023/05 |
| WizardCoder-Python-34B-V1.0 | 73.2% | 2023/08 |
| GPT-4(zero-shot) | 67.0% | 2023/03 |
| PanGu-Coder2 15B | 61.6% | 2023/08 |
@@ -88,7 +94,7 @@
## Articles
-[MFT Arxiv paper](https://arxiv.org/abs/2311.02303)
+[MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning (KDD2024)](https://arxiv.org/abs/2311.02303)
## Introduction
@@ -125,13 +131,13 @@ The main components of this project include:
## Requirements
-To begin, ensure that you have successfully installed CUDA (version >= 11.4, preferably 11.7) along with the necessary drivers. Additionally, make sure you have installed torch (version 2.0.1).
+To begin, ensure that you have successfully installed CUDA (version >= 11.4, preferably 12.1) along with the necessary drivers. Additionally, make sure you have installed torch (version >= 2.1.0).
Next, we have provided an init_env.sh script to simplify the installation of required packages. Execute the following command to run the script:
```bash
sh init_env.sh
```
-We highly recommend training with flash attention(version >= 2.1.0, preferably 2.3.6), please refer to the following link for installation instructions: https://github.com/Dao-AILab/flash-attention
+We highly recommend training with flash attention(version >= 2.3.0), please refer to the following link for installation instructions: https://github.com/Dao-AILab/flash-attention
## Training
@@ -152,16 +158,16 @@ If you want to explore some new framework like atorch, you could check:
We are excited to release the following two CodeLLMs trained by MFTCoder, now available on both HuggingFace and ModelScope:
-| Model | HuggingFace Links | ModelScope Links | Base Model | Num of examples trained | Batch Size | Seq Length |
-|--------------------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------------|----------------------|------|------------|------------|
-| 🔥 CodeFuse-DeepSeek-33B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-DeepSeek-33B) | DeepSeek-coder-33B | 60万 | 80 | 4096 |
-| 🔥 CodeFuse-Mixtral-8x7B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-Mixtral-8x7B) | Mixtral-8x7B | 60万 | 80 | 4096 |
-| 🔥 CodeFuse-CodeLlama-34B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 60万 | 80 | 4096 |
-| 🔥 CodeFuse-CodeLlama-34B-4bits | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 |
-| 🔥 CodeFuse-StarCoder-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder-15B) | StarCoder-15B | 60万 | 80 | 4096 |
-| 🔥 CodeFuse-QWen-14B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 110万 | 256 | 4096 |
-| 🔥 CodeFuse-CodeGeex2-6B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeex2-6B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeex2-6B) | CodeGeex2-6B | 110万 | 256 | 4096 |
-
+| Model | HuggingFace Links | ModelScope Links | Base Model | Num of examples trained | Batch Size | Seq Length |
+|----------------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------------|----------------------|-------------------------|------------|------------|
+| 🔥 CodeFuse-DeepSeek-33B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-DeepSeek-33B) | DeepSeek-coder-33B | 600K | 80 | 4096 |
+| 🔥 CodeFuse-Mixtral-8x7B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-Mixtral-8x7B) | Mixtral-8x7B | 600K | 80 | 4096 |
+| 🔥 CodeFuse-CodeLlama-34B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 600K | 80 | 4096 |
+| 🔥 CodeFuse-CodeLlama-34B-4bits | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 |
+| 🔥 CodeFuse-StarCoder-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder-15B) | StarCoder-15B | 600K | 80 | 4096 |
+| 🔥 CodeFuse-QWen-14B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 1.1 Million | 256 | 4096 |
+| 🔥 CodeFuse-CodeGeex2-6B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeex2-6B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeex2-6B) | CodeGeex2-6B | 1.1 Million | 256 | 4096 |
+| 🔥 CodeFuse-StarCoder2-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder2-15B) | Starcoder2-15B | 700K | 128 | 4096 |
## Datasets
We are also pleased to release two code-related instruction datasets, meticulously selected from a range of datasets to facilitate multitask training. Moving forward, we are committed to releasing additional instruction datasets covering various code-related tasks.
diff --git a/README_cn.md b/README_cn.md
index fff697e..143db35 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -45,11 +45,17 @@
## 新闻
-🔥🔥🔥 [2024/01/17] **MFTCoder-v0.3.0**发布。新增对Mixtral(MoE), DeepSeek等模型的支持;新增支持FSDP(Fully Sharded Data Parallel);新增Self-paced Loss, 支持多任务收敛均衡。 感兴趣详见微信公众号CodeFuse的文章[MFTCoder 重磅升级v0.3.0发布](https://mp.weixin.qq.com/s/xI3f0iUKq9TIIKZ_kMtcQg)
+🔥🔥🔥 [2024/05/20] **MFTCoder-v0.4.2**发布。新增支持**QLoRA+ DeepSpeed Zero3**, **QLoRA + FSDP**训练模式,可以更好的支持微调更大的模型,比如Qwen1.5-70B等。新增对Qwen2, Qwen2-MoE, Starcoder2, Gemma等模型的支持。
-🔥🔥🔥 [2024/01/17] 开源了[CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B)模型,在HumanEval pass@1(greedy decoding)上可以达到78.7%。该模型在Big Code榜单的结果近期发布,请关注公众号获取最新信息。
+🔥🔥🔥 [2024/05/20] 我们的论文 [MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning](https://arxiv.org/abs/2311.02303) 已被 KDD 2024 接收.
-🔥🔥🔥 [2024/01/17] 开源了[CodeFuse-Mixtral-8x7B](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B)模型,在HumanEval pass@1(greedy decoding)上可以达到56.1%。感兴趣详见微信公众号CodeFuse的文章[MFTCoder提升Mixtral-8x7B混合专家模型的代码能力实践](https://mp.weixin.qq.com/s/xI3f0iUKq9TIIKZ_kMtcQg)
+🔥🔥🔥 开源了[CodeFuse-StarCoder2-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B)模型,在HumanEval上可以达到73.2%,多代码语言能力均衡.
+
+🔥🔥 [2024/01/17] **MFTCoder-v0.3.0**发布。新增对Mixtral(MoE), DeepSeek等模型的支持;新增支持FSDP(Fully Sharded Data Parallel);新增Self-paced Loss, 支持多任务收敛均衡。 感兴趣详见微信公众号CodeFuse的文章[MFTCoder 重磅升级v0.3.0发布](https://mp.weixin.qq.com/s/xI3f0iUKq9TIIKZ_kMtcQg)
+
+🔥🔥 [2024/01/17] 开源了[CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B)模型,在HumanEval pass@1(greedy decoding)上可以达到78.7%。该模型在Big Code榜单的结果近期发布,请关注公众号获取最新信息。
+
+🔥🔥 [2024/01/17] 开源了[CodeFuse-Mixtral-8x7B](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B)模型,在HumanEval pass@1(greedy decoding)上可以达到56.1%。感兴趣详见微信公众号CodeFuse的文章[MFTCoder提升Mixtral-8x7B混合专家模型的代码能力实践](https://mp.weixin.qq.com/s/xI3f0iUKq9TIIKZ_kMtcQg)
🔥🔥 [2023/11/07] [MFTCoder论文](https://arxiv.org/abs/2311.02303)在Arxiv公布,介绍了多任务微调的技术细节。
@@ -69,6 +75,7 @@
| **CodeFuse-DeepSeek-33B** | **78.7%** | 2024/01 |
| **CodeFuse-CodeLlama-34B** | **74.4%** | 2023/09 |
| **CodeFuse-CodeLlama-34B-4bits** | **73.8%** | 2023/09 |
+| **CodeFuse-StarCoder2-15B** | **73.2%** | 2023/05 |
| WizardCoder-Python-34B-V1.0 | 73.2% | 2023/08 |
| GPT-4(zero-shot) | 67.0% | 2023/03 |
| PanGu-Coder2 15B | 61.6% | 2023/08 |
@@ -118,12 +125,12 @@
## 环境
-首先, 你需要将CUDA(>=11.4, 推荐11.7)及其相关驱动安装成功,并确保其工作正常, 并且安装基本的torch(>=2.0.0)
+首先, 你需要将CUDA(>=11.4, 推荐12.1)及其相关驱动安装成功,并确保其工作正常, 并且安装基本的torch(>=2.1.0)
在requirements.txt下固定了几个主要的python包的版本,执行如下脚本即可:
```bash
sh init_env.sh
```
-我们强烈建议您安装flash attention(>=2.1.0, 推荐2.3.6), 安装请参考 https://github.com/Dao-AILab/flash-attention
+我们强烈建议您安装flash attention(>=2.3.0), 安装请参考 https://github.com/Dao-AILab/flash-attention
## 训练
如果你熟悉大模型训练的各种主流开源资源,例如 ```transformers```, ```DeepSpeed```, ```FSDP```等, 为了用开源项目快速上手高性能微调,我们建议您尝试:
@@ -145,11 +152,11 @@ sh init_env.sh
| 🔥🔥🔥 CodeFuse-DeepSeek-33B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-DeepSeek-33B) | DeepSeek-coder-33B | 60万 | 80 | 4096 |
| 🔥🔥🔥 CodeFuse-Mixtral-8x7B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-Mixtral-8x7B) | Mixtral-8x7B | 60万 | 80 | 4096 |
| 🔥🔥🔥 CodeFuse-CodeLlama-34B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 60万 | 80 | 4096 |
-| 🔥🔥🔥 CodeFuse-CodeLlama-34B-4bits | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 |
+| 🔥🔥🔥 CodeFuse-CodeLlama-34B-4bits | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 |
| 🔥🔥🔥 CodeFuse-StarCoder-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder-15B) | StarCoder-15B | 60万 | 80 | 4096 |
-| 🔥🔥🔥 CodeFuse-QWen-14B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 110万 | 256 | 4096 |
-| 🔥🔥🔥 CodeFuse-CodeGeex2-6B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeex2-6B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeex2-6B) | CodeGeex2-6B | 110万 | 256 | 4096 |
-
+| 🔥🔥🔥 CodeFuse-QWen-14B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 110万 | 256 | 4096 |
+| 🔥🔥🔥 CodeFuse-CodeGeex2-6B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeex2-6B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeex2-6B) | CodeGeex2-6B | 110万 | 256 | 4096 |
+| 🔥🔥🔥 CodeFuse-StarCoder2-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder2-15B) | Starcoder2-15B | 70万 | 128 | 4096 |
diff --git a/mftcoder_accelerate/README.md b/mftcoder_accelerate/README.md
index 649811d..f65a21e 100644
--- a/mftcoder_accelerate/README.md
+++ b/mftcoder_accelerate/README.md
@@ -7,8 +7,9 @@
[[中文]](README_cn.md) [**English**]
## 1. Updates
+🔥 MFTCoder-accelerate now support these modes: QLoRA/LoRA + DeepSpeed ZeRO2, QLoRA + DeepSpeed ZeRO3, Full-parameter + DeepSpeed ZeRO3, QLoRA + FSDP, Full-parameter + FSDP.
-🔥 MFTCoder-accelerate supports Full-parameters/LoRA using accelerate + FSDP Framework;
+🔥 MFTCoder-accelerate supports QLoRA + DeepSpeed ZeRO3 and QLoRA + FSDP, which both work for larger models;
🔥 MFTCoder-accelerate supports MFT/SFT on more new mainstream open-source base models: mistral, mixtral-8x7b(Mixture of Experts), deepseek, chatglm3;
@@ -175,10 +176,14 @@ DeepSpeed config in accelerate_ds_config.yaml.
accelerate launch --config_file accelerate_ds_config.yaml pefts/mft_accelerate.py --train_config configs/xxx_train_config.json --distributed_type "DeepSpeed"
```
or
-DeepSpeed config in command line arguments
+DeepSpeed Zero2 config in command line arguments
```bash
sh ds_single_launch.sh
```
+DeepSpeed Zero3 config in command line arguments
+```bash
+sh ds_zero3_single_launch.sh
+```
#### Launch via FSDP
FSDP config in accelerate_fsdp_config.yaml.
@@ -188,7 +193,13 @@ accelerate launch --config_file accelerate_fsdp_config.yaml pefts/mft_accelerate
or
FSDP config in command line arguments
```bash
-sh ds_single_launch.sh
+sh fsdp_single_launch.sh
+```
+
+#### MultiNode Launch
+Refer to the deepspeed multi-node launch script below.
+```bash
+sh ds_multinode_launch.sh
```
#### Traing Arguments
@@ -328,6 +339,8 @@ beam_num: Set a smaller value such as 1 or 3. ```beam_num=1``` represents greedy
If OOM happened,you can reduce parameters such as per_device_train_batch_size and seq_length. Since you are dealing with large models (6B, 13B, 34B, 70B, etc.), you are already using gradient checkpointing technology by default, which significantly reduces GPU memory consumption.
However, this may slightly slow down the training speed.
+QLoRA + DeepSpeed Zero3 is recommended for larger models to avoid OOM.
+
#### Q2:install packages
Please refer to init_env.sh and requirements.txt
We highly recommend you install Flash Attention 2 (flash_attn>=2.1.0, 2.3.6 used by us) first to get memory-efficient and fast training.
@@ -339,7 +352,8 @@ CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file accelerate_ds_config.ya
```
#### Q4:Whats is a recommended Distributed Training?
-For LoRA/QLoRA, we recommend DeepSpeed(ZeRO2) as the underlying framework, because it is easy and stable to use, moreover it is more compatable for different settings.
-And FSDP does not support Quantization(integer type in training).
+For LoRA, we recommend DeepSpeed ZeRO2 as the underlying framework, because it is easy and stable to use, moreover it is more compatable for different settings.
+
+For QLoRA, DeepSpeed ZeRO2 and DeepSpeed ZeRO3 are both good, moreover DeepSpeed ZeRO3 is a good choice for very large models.
-For Full-parameter finetuning, FSDP is usually faster, and may help you with very large models by sharding parameters and gradients.
\ No newline at end of file
+For Full-parameter finetuning, DeepSpeed ZeRO3 and FSDP are faster, and may help you with very large models by sharding parameters and gradients.
\ No newline at end of file
diff --git a/mftcoder_accelerate/README_cn.md b/mftcoder_accelerate/README_cn.md
index c208a4a..0acb8d8 100644
--- a/mftcoder_accelerate/README_cn.md
+++ b/mftcoder_accelerate/README_cn.md
@@ -7,6 +7,10 @@
[**中文**] [[English]](README.md)
## 1. 更新
+🔥 MFTCoder-accelerate 最新支持的训练模式包括: QLoRA/LoRA + DeepSpeed ZeRO2, QLoRA + DeepSpeed ZeRO3, 全量 + DeepSpeed ZeRO3, QLoRA + FSDP, 全量 + FSDP。
+
+🔥 MFTCoder-accelerate 新增支持QLoRA + DeepSpeed ZeRO3, 支持QLoRA + FSDP, 可以训练更大的模型;
+
🔥 MFTCoder-accelerate 新增支持accelerate + FSDP框架, 支持全量微调和LoRA;
🔥 MFTCoder-accelerate 支持最新更多主流开源模型: mistral, mixtral-8x7b(Mixture of Experts), deepseek, chatglm3;
@@ -142,19 +146,24 @@ QLoRA通过4-bit的nf4量化,且加入更多adapter,在大幅减少显存消
QLoRA论文指出,该方法可以在一张V100上对33B的模型进行微调,并且性能逼近全量参数微调。
执行如下命令即可进行 Lora/QLora/全量 微调:
-#### Launch via Deepspeed
+#### Deepspeed 单机启动
DeepSpeed配置在accelerate_ds_config.yaml中。
```bash
accelerate launch --config_file accelerate_ds_config.yaml pefts/mft_accelerate.py --train_config configs/xxx_train_config.json --distributed_type "DeepSpeed"
```
或者
-DeepSpeed配置在脚本中通过命令行输入。
+DeepSpeed Zero2 配置在脚本中通过命令行输入。
```bash
sh ds_single_launch.sh
```
-#### Launch via FSDP
+DeepSpeed Zero3 配置在脚本中通过命令行输入
+```bash
+sh ds_zero3_single_launch.sh
+```
+
+#### FSDP 单机启动
FSDP配置在accelerate_fsdp_config.yaml中。
```bash
accelerate launch --config_file accelerate_fsdp_config.yaml pefts/mft_accelerate.py --train_config configs/xxx_train_config.json --distributed_type "FSDP"
@@ -166,6 +175,12 @@ FSDP配置在脚本中通过命令行输入。
sh fsdp_single_launch.sh
```
+#### 多机启动
+多机启动请参考如下deepspeed多机启动脚本
+```bash
+sh ds_multinode_launch.sh
+```
+
#### 训练参数
_**训练需要的参数配置在```configs/*_train_config```中,主要参数说明如下:**_
@@ -256,7 +271,7 @@ print(gen_text)
## 5. FAQ
#### 问题1:OOM如何解决?
如果发生OOM,可以缩小per_device_train_batch_size、seq_length等参数来缓解。由于面对的模型普遍较大(6b, 13b, 34b, 70b等)我们已经默认使用gradient_checkpointing技术,可以大幅降低显存占用,但训练速度会稍慢一些。
-
+如果是模型太大,可以使用QLoRA + DeepSpeed ZeRO3(配置 zero stage = 3),这个方案可以在卡数足够的情况下,微调更大的模型。
#### 问题2:安装包错误
参考init_env.sh和requirements.txt
@@ -276,14 +291,14 @@ CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file pefts/accelerate_ds_con
如果你可以自行安装环境并使用torch>=2.1.1,可以尝试设置参数"attn_implementation"为 "sdpa"。这样会尝试使用transformers兼容的torch.nn.functional.scaled_dot_product_attention。支持的模型还不全面。
#### 问题5:推荐的分布式框架是怎样的?
-对于LoRA/QLoRA, 我们推荐使用DeepSpeed作为底层分布式框架,它具有易用性和兼容性好的特点,并且速度很快。
-FSDP 不支持QLoRA, 因为bitsandbytes暂不支持FSDP。
+对于LoRA, 我们推荐使用DeepSpeed Zero2作为底层分布式框架,它具有易用性和兼容性好的特点,并且速度很快, 模型加载模式上类似DDP。
+对于QLoRA, DeepSpeed Zero2 适合中小模型, DeepSpeed Zero3 适合很大的模型。
-对于全量微调,我们推荐使用FSDP, 因为它在全量训练时可以发挥fully sharding的优势,达到更快的训练速度。
+对于全量微调,可以使用DeepSpeed Zero3, 或者FSDP。二者都是Fully Sharding模式,即模型加载平分在每张卡。
#### 问题6:当前支持的模型中,有什么区别
国产大模型比如chatglm2, chatglm3, baichuan2, qwen, aquila2等,使用的是和模型共同发布的modeling_xxx.py.
-其它被transformers官方支持的大模型,由于已经升级支持flash attention等,所以全面切换到官方的modeling支持训练,之前的自定义modeling会被deprecated
+其它被transformers官方支持的大模型,比如llama, qwen2, starcoder2, mistral等,全面切换到官方的modeling支持训练,之前的自定义modeling会被deprecated。
diff --git a/mftcoder_accelerate/src/data/helpers.cpython-38-x86_64-linux-gnu.so b/mftcoder_accelerate/src/data/helpers.cpython-38-x86_64-linux-gnu.so
deleted file mode 100755
index 6fbc1b7..0000000
Binary files a/mftcoder_accelerate/src/data/helpers.cpython-38-x86_64-linux-gnu.so and /dev/null differ
diff --git a/mftcoder_accelerate/src/data/multi_task_dataset.py b/mftcoder_accelerate/src/data/multi_task_dataset.py
index 1a8612a..fde298b 100644
--- a/mftcoder_accelerate/src/data/multi_task_dataset.py
+++ b/mftcoder_accelerate/src/data/multi_task_dataset.py
@@ -27,12 +27,12 @@ def __init__(
self.name = name
self.input_dataset = input_dataset
- self.num_samples = len(self.input_dataset['input_ids'])
+ self.num_samples = len(self.input_dataset["input_ids"])
self.seq_length = seq_length
self.weighted_loss_mode = weighted_loss_mode
self.ds_weight = ds_weight
- self.task_name = data_prefix.split('/')[-1]
+ self.task_name = data_prefix.split("/")[-1]
self.task_id = TASK2ID[self.task_name]
# Checks
@@ -47,8 +47,7 @@ def __getitem__(self, idx):
try:
# Get the shuffled index.
idx = idx % self.num_samples
- idx_data = {key: self.input_dataset[key][idx]
- for key in self.input_dataset}
+ idx_data = {key: self.input_dataset[key][idx] for key in self.input_dataset}
if self.weighted_loss_mode:
idx_data["weight"] = np.array([self.ds_weight], dtype=np.float32)
@@ -115,9 +114,7 @@ def __init__(self, datasets, weights, global_num_samples, local_num_samples):
print(
"> RANK {} elapsed time for building blendable dataset indices: "
- "{:.2f} (sec)".format(
- torch.distributed.get_rank(), time.time() - start_time
- )
+ "{:.2f} (sec)".format(torch.distributed.get_rank(), time.time() - start_time)
)
def calc_weights(self):
@@ -166,7 +163,7 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
encoder = UniformEncoder(args, args.tokenize_mode)
encoder.initializer()
- data_prefixes = list(args.data_paths[1:-1].split(','))
+ data_prefixes = list(args.data_paths[1:-1].split(","))
splits = []
splits_string = args.data_split
@@ -179,7 +176,7 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
while len(splits) < 3:
splits.append(0.0)
splits = splits[:3]
- print(f'data splits: {splits}')
+ print(f"data splits: {splits}")
all_train_datasets = []
all_valid_datasets = []
@@ -200,40 +197,40 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
cur_dataset_loss_mask = []
# support multiple jsonl files under task dir
for file in files:
- file_name = data_prefixes[dataset_index] + '/' + file
+ file_name = data_prefixes[dataset_index] + "/" + file
if os.path.isdir(file_name):
continue
- fin = open(file_name, 'r')
- print(f'[Global Rank {global_rank}] open file {file_name}')
+ fin = open(file_name, "r")
+ print(f"[Global Rank {global_rank}] open file {file_name}")
- if args.padding_mode == 'padding' or args.padding_mode == 'pack':
+ if args.padding_mode == "padding" or args.padding_mode == "pack":
for i, line in enumerate(fin):
# pre-sharding
if shard_data and i % world_size != global_rank:
continue
- data = json.loads(line.rstrip('\n\r'))
+ data = json.loads(line.rstrip("\n\r"))
features, length = encoder.encode(data, verbose=(i < 1))
# features, length = encoder.encode(data)
# may have more samples
- for idx in range(len(features['input_ids'])):
- cur_dataset_input_ids.append(features['input_ids'][idx])
- cur_dataset_loss_mask.append(features['loss_mask'][idx])
+ for idx in range(len(features["input_ids"])):
+ cur_dataset_input_ids.append(features["input_ids"][idx])
+ cur_dataset_loss_mask.append(features["loss_mask"][idx])
fin.close()
else:
i = 0
for line in fin:
- data = json.loads(line.rstrip('\n\r'))
+ data = json.loads(line.rstrip("\n\r"))
features, length = encoder.encode(data)
# 一个document可能编码不出sample,可能编码出多个sample
- for idx in range(len(features['input_ids'])):
+ for idx in range(len(features["input_ids"])):
# post-sharding
if shard_data and i % world_size != global_rank:
i += 1
continue
i += 1
- cur_dataset_input_ids.append(features['input_ids'][idx])
- cur_dataset_loss_mask.append(features['loss_mask'][idx])
+ cur_dataset_input_ids.append(features["input_ids"][idx])
+ cur_dataset_loss_mask.append(features["loss_mask"][idx])
fin.close()
cur_dataset_input_ids = np.array(cur_dataset_input_ids, dtype=np.float32)
@@ -249,46 +246,40 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
train_ratio = splits[0] / 100.0
train_num = int(math.ceil(train_ratio * cur_dataset_sample_num))
# split train/valid
- cur_train_input_ids, cur_valid_input_ids = cur_dataset_input_ids[: train_num], cur_dataset_input_ids[train_num:]
- cur_train_loss_mask, cur_valid_loss_mask = cur_dataset_loss_mask[: train_num], cur_dataset_loss_mask[train_num:]
+ cur_train_input_ids, cur_valid_input_ids = cur_dataset_input_ids[:train_num], cur_dataset_input_ids[train_num:]
+ cur_train_loss_mask, cur_valid_loss_mask = cur_dataset_loss_mask[:train_num], cur_dataset_loss_mask[train_num:]
local_train_num += train_num
- local_valid_num += (cur_dataset_sample_num - train_num)
-
- cur_train_dataset = {
- 'input_ids': cur_train_input_ids,
- 'loss_mask': cur_train_loss_mask
- }
- cur_valid_dataset = {
- 'input_ids': cur_valid_input_ids,
- 'loss_mask': cur_valid_loss_mask
- }
+ local_valid_num += cur_dataset_sample_num - train_num
+
+ cur_train_dataset = {"input_ids": cur_train_input_ids, "loss_mask": cur_train_loss_mask}
+ cur_valid_dataset = {"input_ids": cur_valid_input_ids, "loss_mask": cur_valid_loss_mask}
print(f"[Global Rank {global_rank}]shape of cur train dataset: {cur_train_dataset['input_ids'].shape}")
print(f"[Global Rank {global_rank}]shape of cur valid dataset: {cur_valid_dataset['input_ids'].shape}")
cur_train_ds = GPT2FromRawDataset(
- 'train',
+ "train",
data_prefixes[dataset_index],
cur_train_dataset,
args.seq_length,
weighted_loss_mode=args.weighted_loss_mode,
- ds_weight=splits[0]
+ ds_weight=splits[0],
)
cur_valid_ds = GPT2FromRawDataset(
- 'valid',
+ "valid",
data_prefixes[dataset_index],
cur_valid_dataset,
args.seq_length,
weighted_loss_mode=args.weighted_loss_mode,
- ds_weight=splits[1]
+ ds_weight=splits[1],
)
-
+
all_train_datasets.append(cur_train_ds)
all_valid_datasets.append(cur_valid_ds)
all_train_datasets_length.append(len(cur_train_ds))
all_valid_datasets_length.append(len(cur_valid_ds))
-
- print(f'[Global Rank {global_rank}]num tokens: {num_tokens}')
- print(f'[Global Rank {global_rank}]effective token rate: {effective_token_rate}')
+
+ print(f"[Global Rank {global_rank}]num tokens: {num_tokens}")
+ print(f"[Global Rank {global_rank}]effective token rate: {effective_token_rate}")
num_tokens = []
ds_fn = partial(ds_weights_by_num_docs_sft)
@@ -296,7 +287,7 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
ds_fn(all_train_datasets_length),
ds_fn(all_valid_datasets_length),
)
-
+
print(f"> train loss weights in rank {global_rank}: {train_loss_weights}")
print(f"> valid loss weights in rank {global_rank}: {valid_loss_weights}")
@@ -306,51 +297,63 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
factor = sum(num_tokens) / (sum(total_sample_cnt) * args.seq_length)
factor /= sum([1.0 / w for w in train_loss_weights]) / len(train_loss_weights)
print(f"> common denomination factor for CE loss in rank {global_rank}: {factor}")
-
+
train_sample_weights = [x / sum(all_train_datasets_length) for x in all_train_datasets_length]
valid_sample_weights = [x / sum(all_valid_datasets_length) for x in all_valid_datasets_length]
print(f"> train sample weights in rank {global_rank}: {train_sample_weights}")
print(f"> valid sample weights in rank {global_rank}: {valid_sample_weights}")
# recompute global_train_num and global_valid_num
-
+
torch.distributed.barrier()
device = f"cuda:{local_rank}"
-
+
global_train_num_samples_tensor = torch.tensor(local_train_num, dtype=torch.int32)
global_train_num_samples_tensor = global_train_num_samples_tensor.to(device)
torch.distributed.all_reduce(global_train_num_samples_tensor, op=torch.distributed.ReduceOp.SUM)
global_train_num = global_train_num_samples_tensor.item()
-
+
global_valid_num_samples_tensor = torch.tensor(local_valid_num, dtype=torch.int32)
global_valid_num_samples_tensor = global_valid_num_samples_tensor.to(device)
torch.distributed.all_reduce(global_valid_num_samples_tensor, op=torch.distributed.ReduceOp.SUM)
global_valid_num = global_valid_num_samples_tensor.item()
print(f"> global train num in rank {global_rank}: {global_train_num}")
print(f"> global valid num in rank {global_rank}: {global_valid_num}")
-
+
torch.distributed.barrier()
for i in range(len(all_train_datasets)):
- print(f'loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}')
+ print(
+ f"loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}"
+ )
blending_train_dataset = None
if all_train_datasets:
args.do_train = True
for i in range(len(all_train_datasets)):
all_train_datasets[i].update_ds_weight(train_loss_weights[i] / factor)
- print(f'loss weight of train dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}')
- blending_train_dataset = GPT2BlendableDataset(all_train_datasets, train_sample_weights, global_train_num, local_train_num)
+ print(
+ f"loss weight of train dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}"
+ )
+ blending_train_dataset = GPT2BlendableDataset(
+ all_train_datasets, train_sample_weights, global_train_num, local_train_num
+ )
- for i in range(len(all_train_datasets)):
- print(f'loss weight of valid dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}')
+ for i in range(len(all_valid_datasets)):
+ print(
+ f"loss weight of valid dataset {i} before update in rank {global_rank}: {all_valid_datasets[i].ds_weight}"
+ )
blending_valid_dataset = None
if all_valid_datasets:
args.do_valid = True
for i in range(len(all_valid_datasets)):
all_valid_datasets[i].update_ds_weight(valid_loss_weights[i] / factor)
- print(f'loss weight of valid dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}')
- blending_valid_dataset = GPT2BlendableDataset(all_valid_datasets, valid_sample_weights, global_valid_num, local_valid_num)
-
+ print(
+ f"loss weight of valid dataset {i} after update in rank {global_rank}: {all_valid_datasets[i].ds_weight}"
+ )
+ blending_valid_dataset = GPT2BlendableDataset(
+ all_valid_datasets, valid_sample_weights, global_valid_num, local_valid_num
+ )
+
return blending_train_dataset, blending_valid_dataset
@@ -359,11 +362,13 @@ def compile_helper():
is invoked on a single process."""
import os
import subprocess
+
path = os.path.abspath(os.path.dirname(__file__))
ret = subprocess.run(["make", "-C", path])
if ret.returncode != 0:
print("Making C++ dataset helpers module failed, exiting.")
import sys
+
sys.exit(1)
else:
print("Making C++ dataset helpers module successfully.")
diff --git a/mftcoder_accelerate/src/data/preprocess_data.py b/mftcoder_accelerate/src/data/preprocess_data.py
index 5f6d286..3c912e6 100644
--- a/mftcoder_accelerate/src/data/preprocess_data.py
+++ b/mftcoder_accelerate/src/data/preprocess_data.py
@@ -8,33 +8,32 @@
import sys
import ftfy
import glob
-print("In preprocess_data.py, sys path:", sys.path)
+
+# print("In preprocess_data.py, sys path:", sys.path)
from tokenizer import build_tokenizer
-CHAT_COL = 'chat_rounds'
-ROLE_COL = 'role'
-CONTENT_COL = 'content'
+CHAT_COL = "chat_rounds"
+ROLE_COL = "role"
+CONTENT_COL = "content"
-SYSTEM_COL = 'system'
-PROMPT_COL = 'prompt'
-ANSWER_COL = 'answer'
+SYSTEM_COL = "system"
+PROMPT_COL = "prompt"
+ANSWER_COL = "answer"
-TEXT_COL = 'text'
+TEXT_COL = "text"
-table = {ord(f): ord(t) for f, t in zip(
- u',。!?:【】()%#@&1234567890',
- u',.!?:[]()%#@&1234567890')}
+table = {ord(f): ord(t) for f, t in zip(",。!?:【】()%#@&1234567890", ",.!?:[]()%#@&1234567890")}
def content_format(content: str):
# Replace non-breaking space with space
- content = content.replace('\u202f', ' ').replace('\xa0', ' ')
+ content = content.replace("\u202f", " ").replace("\xa0", " ")
# change chinese punctuation to english ones
# text = text.translate(table)
- content += '\n'
+ content += "\n"
return content
@@ -112,7 +111,7 @@ def initializer(self):
# Use Encoder class as a container for global data
Encoder.tokenizer = build_tokenizer(self.args)
# self.tokenizer = build_tokenizer(self.args)
-
+
def pure_encode(self, content):
return Encoder.tokenizer.encode(content, add_special_tokens=False)
@@ -132,8 +131,7 @@ def encode(self, text):
class UniformEncoder(Encoder):
-
- def __init__(self, args, mode='sft'):
+ def __init__(self, args, mode="sft"):
super().__init__(args)
self.verbose = False
self.mode = mode
@@ -149,47 +147,45 @@ def __init__(self, args, mode='sft'):
def encode(self, data, verbose=False):
self.verbose = verbose
- encode_res = {
- "input_ids": [],
- "loss_mask": []
- }
+ encode_res = {"input_ids": [], "loss_mask": []}
if is_prompt_answer_format(data):
- data_type = 'prompt_answer'
+ data_type = "prompt_answer"
elif is_prompt_response_format(data):
- data_type = 'prompt_response'
+ data_type = "prompt_response"
elif is_input_output_format(data):
- data_type = 'input_output'
+ data_type = "input_output"
elif is_instruction_output_format(data):
- data_type = 'instruction_output'
+ data_type = "instruction_output"
elif is_instruction_response_format(data):
- data_type = 'instruction_response'
+ data_type = "instruction_response"
elif is_question_response_format(data):
- data_type = 'question_response'
+ data_type = "question_response"
elif is_question_answer_format(data):
- data_type = 'question_answer'
+ data_type = "question_answer"
elif is_chatml_format(data):
- data_type = 'chatML'
+ data_type = "chatML"
elif is_text_format(data):
- data_type = 'text'
+ data_type = "text"
else:
raise ValueError(
f"data_type does not support"
f"please use chatML or prompt/answer, prompt/response, question/response, "
- f"instruction/output, input/output, instruction/output or text(only for pretrain)")
-
+ f"instruction/output, input/output, instruction/output or text(only for pretrain)"
+ )
+
length = 0
- if data_type == 'chatML':
- for chat in data['chat_rounds']:
- length += len(chat['content'])
- elif data_type == 'text':
- length += len(data['text'])
+ if data_type == "chatML":
+ for chat in data["chat_rounds"]:
+ length += len(chat["content"])
+ elif data_type == "text":
+ length += len(data["text"])
else:
- # update key
+ # update key
global PROMPT_COL, ANSWER_COL
- PROMPT_COL, ANSWER_COL = tuple(data_type.split('_'))
+ PROMPT_COL, ANSWER_COL = tuple(data_type.split("_"))
length = len(data[PROMPT_COL]) + len(data[ANSWER_COL])
-
+
for token_res in self._tokenize_fields(data, data_type=data_type):
for k, v in token_res.items():
encode_res[k].append(v)
@@ -197,19 +193,19 @@ def encode(self, data, verbose=False):
return encode_res, length
def _tokenize_fields(self, data, data_type):
- if self.mode == 'sft':
+ if self.mode == "sft":
if self.args.role_markers:
system_marker = self.args.role_markers["system"]
user_marker = self.args.role_markers["user"]
assistant_marker = self.args.role_markers["assistant"]
else:
- system_marker = 'system\n'
- user_marker = 'human\n'
- assistant_marker = 'bot\n'
- elif self.mode == 'pretrain':
- system_marker = ''
- user_marker = ''
- assistant_marker = ''
+ system_marker = "system\n"
+ user_marker = "human\n"
+ assistant_marker = "bot\n"
+ elif self.mode == "pretrain":
+ system_marker = ""
+ user_marker = ""
+ assistant_marker = ""
else:
raise ValueError(f"tokenize_mode does not support {self.mode}, please use sft or pretrain")
@@ -218,9 +214,9 @@ def _tokenize_fields(self, data, data_type):
input_ids = []
loss_mask = []
- if data_type == 'chatML':
+ if data_type == "chatML":
chat = data[CHAT_COL]
- if chat[0][ROLE_COL] == 'system':
+ if chat[0][ROLE_COL] == "system":
sys_content_ids = self.pure_encode(system_marker + content_format(chat[0][CONTENT_COL]))
chat = chat[1:]
input_ids += sys_content_ids
@@ -230,15 +226,17 @@ def _tokenize_fields(self, data, data_type):
role = r[ROLE_COL]
content = r[CONTENT_COL]
content = content_format(content)
- if (role == 'human' or role == 'user') != (i % 2 == 0):
- raise ValueError("Conversation roles must alternate user/assistant/user/assistant/... or human/bot/human/bot/...')")
-
+ if (role == "human" or role == "user") != (i % 2 == 0):
+ raise ValueError(
+ "Conversation roles must alternate user/assistant/user/assistant/... or human/bot/human/bot/...')"
+ )
+
# compute loss only for assistant's content and eos token afterward
- if role == 'human' or role == 'user':
+ if role == "human" or role == "user":
content_ids = self.pure_encode(user_marker + content + assistant_marker)
input_ids += content_ids
loss_mask += [0] * len(content_ids)
- elif role == 'bot' or role == 'assistant':
+ elif role == "bot" or role == "assistant":
content_ids = self.pure_encode(content) + sft_end_marker_ids
input_ids += content_ids
loss_mask += [1] * len(content_ids)
@@ -255,7 +253,7 @@ def _tokenize_fields(self, data, data_type):
input_ids += text_ids
loss_mask += [1] * len(text_ids)
else:
- system = data.get(SYSTEM_COL, '')
+ system = data.get(SYSTEM_COL, "")
prompt = data[PROMPT_COL]
answer = data[ANSWER_COL]
@@ -270,28 +268,28 @@ def _tokenize_fields(self, data, data_type):
loss_mask += [0] * len(prompt_ids) + [1] * len(answer_ids)
# print(self.mode)
- if self.mode == 'pretrain':
+ if self.mode == "pretrain":
# change loss mask to all 1s
input_ids = input_ids
loss_mask = [1] * len(loss_mask)
- elif self.mode == 'sft':
+ elif self.mode == "sft":
# do nothing
input_ids = input_ids
loss_mask = loss_mask
-
+
if self.verbose:
print(f"original data:\n{data}")
print(f"decoding back:\n{Encoder.tokenizer.decode(input_ids)}")
assert len(input_ids) == len(loss_mask)
- if self.args.padding_mode == 'padding':
+ if self.args.padding_mode == "padding":
if len(input_ids) <= self.seq_length:
yield self.padding(input_ids, loss_mask)
# drop if too long
else:
yield {}
- elif self.args.padding_mode == 'concat':
+ elif self.args.padding_mode == "concat":
input_ids = self.remain_input_ids + input_ids
loss_mask = self.remain_loss_mask + loss_mask
if len(input_ids) < self.seq_length:
@@ -303,15 +301,15 @@ def _tokenize_fields(self, data, data_type):
cursor = 0
while cursor + self.seq_length <= len(input_ids):
yield {
- "input_ids": input_ids[cursor: cursor + self.seq_length],
- "loss_mask": loss_mask[cursor: cursor + self.seq_length]
+ "input_ids": input_ids[cursor : cursor + self.seq_length],
+ "loss_mask": loss_mask[cursor : cursor + self.seq_length],
}
cursor = cursor + self.stride
self.remain_input_ids = input_ids[cursor:]
self.remain_loss_mask = loss_mask[cursor:]
assert len(self.remain_input_ids) == len(self.remain_loss_mask)
yield {}
- elif self.args.padding_mode == 'pack':
+ elif self.args.padding_mode == "pack":
if len(input_ids) > self.seq_length:
yield {}
elif len(self.remain_input_ids) + len(input_ids) > self.seq_length:
@@ -330,10 +328,7 @@ def padding(self, input_ids, loss_mask):
assert len(input_ids) <= self.seq_length, f"padding sequence: {len(input_ids)} > {self.seq_length}"
input_ids += [pad_id] * (self.seq_length - len(input_ids))
loss_mask += [0] * (self.seq_length - len(loss_mask))
- return {
- "input_ids": input_ids,
- "loss_mask": loss_mask
- }
+ return {"input_ids": input_ids, "loss_mask": loss_mask}
def find_jsonl_fnames(inputs):
diff --git a/mftcoder_accelerate/src/ds_multinode_launch.sh b/mftcoder_accelerate/src/ds_multinode_launch.sh
new file mode 100755
index 0000000..dca0670
--- /dev/null
+++ b/mftcoder_accelerate/src/ds_multinode_launch.sh
@@ -0,0 +1,44 @@
+#!/bin/sh
+# Author: Chaoyu Chen
+# Last Modified: 2024/5/20
+# Description: # Launch script on Multiple Nodes
+
+# Run this script on all Nodes.
+
+# You need to export your number of nodes and number of GPUs per node first.
+N_NODE=4
+N_GPU_PER_NODE=8
+
+# You need to export $MACHINE_RANK, $MASTER_ADDR, $MASTER_PORT automatically for each Node.
+
+# config path
+CONFIG="configs/xxx_train_config.json"
+
+# envs used inside training
+export OMP_NUM_THREADS=4
+export TOKENIZERS_PARALLELISM=False
+
+TODAY=$(date +%Y-%m%d-%H%M)
+
+# accelerate launch --config_file accelerate_ds_config.yaml \
+accelerate launch \
+ --num_machines $N_NODE \
+ --num_processes $(($N_NODE*$N_GPU_PER_NODE)) \
+ --use_deepspeed \
+ --deepspeed_multinode_launcher 'standard' \
+ --zero_stage 2 \
+ --offload_optimizer_device 'cpu' \
+ --offload_param_device 'none' \
+ --gradient_accumulation_steps 1 \
+ --gradient_clipping 1.0 \
+ --zero3_init_flag false \
+ --zero3_save_16bit_model false \
+ --main_training_function 'main' \
+ --mixed_precision 'bf16' \
+ --dynamo_backend 'no' \
+ --same_network \
+ --machine_rank $MACHINE_RANK \
+ --main_process_ip $MASTER_ADDR \
+ --main_process_port $MASTER_PORT \
+ --rdzv_backend 'static' \
+ pefts/mft_accelerate.py --train_config "$CONFIG" --distributed_type "deepspeed"
\ No newline at end of file
diff --git a/mftcoder_accelerate/src/ds_single_launch.sh b/mftcoder_accelerate/src/ds_single_launch.sh
index 54f1528..d6c84bb 100755
--- a/mftcoder_accelerate/src/ds_single_launch.sh
+++ b/mftcoder_accelerate/src/ds_single_launch.sh
@@ -1,11 +1,14 @@
#!/bin/sh
# Author: Chaoyu Chen
-# Last Modified: 2024/12/11
+# Last Modified: 2023/12/11
# Description: An alternative(Command line) way to launch DeepSpeed training
# Launch script on single node
N_GPU_PER_NODE=8
+# config path
+CONFIG="configs/xxx_train_config.json"
+
# envs used inside training
export OMP_NUM_THREADS=4
export TOKENIZERS_PARALLELISM=False
@@ -30,6 +33,6 @@ accelerate launch \
--same_network \
--machine_rank 0 \
--rdzv_backend 'static' \
- pefts/mft_accelerate.py --train_config configs/"xxx_train_config.json" \
+ pefts/mft_accelerate.py --train_config "$CONFIG" \
--distributed_type "deepspeed" \
> MFTCoder-training-"$TODAY".log 2>&1 &
diff --git a/mftcoder_accelerate/src/ds_zero3_single_launch.sh b/mftcoder_accelerate/src/ds_zero3_single_launch.sh
new file mode 100755
index 0000000..5f581c9
--- /dev/null
+++ b/mftcoder_accelerate/src/ds_zero3_single_launch.sh
@@ -0,0 +1,38 @@
+#!/bin/sh
+# Author: Chaoyu Chen
+# Last Modified: 2024/5/20
+# Description: An alternative(Command line) way to launch DeepSpeed training
+
+# Launch script on single node
+N_GPU_PER_NODE=8
+
+# config path
+CONFIG="configs/xxx_train_config.json"
+
+# envs used inside training
+export OMP_NUM_THREADS=4
+export TOKENIZERS_PARALLELISM=False
+
+TODAY=$(date +%Y-%m%d-%H%M)
+
+# accelerate launch --config_file accelerate_ds_config.yaml \
+accelerate launch \
+ --num_machines 1 \
+ --num_processes $N_GPU_PER_NODE \
+ --use_deepspeed \
+ --zero_stage 3 \
+ --offload_optimizer_device 'cpu' \
+ --offload_param_device 'cpu' \
+ --gradient_accumulation_steps 1 \
+ --gradient_clipping 1.0 \
+ --zero3_init_flag true \
+ --zero3_save_16bit_model true \
+ --main_training_function 'main' \
+ --mixed_precision 'bf16' \
+ --dynamo_backend 'no' \
+ --same_network \
+ --machine_rank 0 \
+ --rdzv_backend 'static' \
+ pefts/mft_accelerate.py --train_config "$CONFIG" \
+ --distributed_type "deepspeed" \
+ > MFTCoder-training-"$TODAY".log 2>&1 &
diff --git a/mftcoder_accelerate/src/fsdp_single_launch.sh b/mftcoder_accelerate/src/fsdp_single_launch.sh
index 1959274..2dc8f89 100755
--- a/mftcoder_accelerate/src/fsdp_single_launch.sh
+++ b/mftcoder_accelerate/src/fsdp_single_launch.sh
@@ -1,11 +1,19 @@
#!/bin/sh
# Author: Chaoyu Chen
-# Last Modified: 2024/12/11
+# Last Modified: 2023/12/11
# Description: An alternative(command line) way to launch FSDP training
# Launch script on single node
N_GPU_PER_NODE=8
+# config path
+CONFIG="configs/xxx_train_config.json"
+
+# fsdp_transformer_layer_cls_to_wrap, choose the DecoderLayer
+WRAP_MODULE="LlamaDecoderLayer"
+
+
+
# envs used inside training
export OMP_NUM_THREADS=4
export TOKENIZERS_PARALLELISM=False
@@ -21,7 +29,7 @@ accelerate launch \
--fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \
--fsdp_state_dict_type=FULL_STATE_DICT \
--fsdp_backward_prefetch_policy=BACKWARD_PRE \
- --fsdp_transformer_layer_cls_to_wrap=LlamaDecoderLayer \
+ --fsdp_transformer_layer_cls_to_wrap=$WRAP_MODULE \
--fsdp_offload_params=false \
--main_training_function=main \
--mixed_precision=bf16 \
@@ -29,7 +37,7 @@ accelerate launch \
--same_network \
--machine_rank=0 \
--rdzv_backend=static \
- pefts/mft_accelerate.py --train_config configs/"xxx_train_config.json" \
+ pefts/mft_accelerate.py --train_config "$CONFIG" \
--distributed_type "fsdp" \
> MFTCoder-training-"$TODAY".log 2>&1 &
diff --git a/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_256.cpp b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_256.cpp
new file mode 100644
index 0000000..8458a9b
--- /dev/null
+++ b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_256.cpp
@@ -0,0 +1,198 @@
+#include
+#include
+#include
+
+// adapted from https://github.com/PanQiWei/AutoGPTQ/blob/main/autogptq_extension/cuda_256/autogptq_cuda_256.cpp
+void vecquant8matmul_cuda(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros,
+ torch::Tensor g_idx
+);
+
+void vecquant8matmul(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros,
+ torch::Tensor g_idx
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+ vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+void vecquant8matmul_batched_cuda(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+ vecquant8matmul_batched_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant8matmul_batched_column_compression_cuda(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_column_compression(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+ vecquant8matmul_batched_column_compression_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant4matmul_batched_cuda(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant4matmul_batched(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+ vecquant4matmul_batched_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant4matmul_batched_column_compression_cuda(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant4matmul_batched_column_compression(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+ vecquant4matmul_batched_column_compression_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant8matmul_batched_old_cuda(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_old(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+ vecquant8matmul_batched_old_cuda(vec, mat, mul, scales, zeros);
+}
+
+
+void vecquant4matmul_batched_old_cuda(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant4matmul_batched_old(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+ vecquant4matmul_batched_old_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant8matmul_batched_column_compression_old_cuda(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_column_compression_old(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+ vecquant8matmul_batched_column_compression_old_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant4matmul_batched_column_compression_old_cuda(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant4matmul_batched_column_compression_old(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+ vecquant4matmul_batched_column_compression_old_cuda(vec, mat, mul, scales, zeros);
+}
+
+
+
+void vecquant8matmul_batched_faster_cuda(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_faster(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+ vecquant8matmul_batched_faster_cuda(vec, mat, mul, scales, zeros);
+}
+
+
+void vecquant8matmul_batched_faster_old_cuda(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_faster_old(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+ vecquant8matmul_batched_faster_old_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant8matmul_batched_column_compression_faster_cuda(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_column_compression_faster(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+ vecquant8matmul_batched_column_compression_faster_cuda(vec, mat, mul, scales, zeros);
+}
+
+
+void vecquant8matmul_batched_column_compression_faster_old_cuda(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant8matmul_batched_column_compression_faster_old(
+ torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+ torch::Tensor scales, torch::Tensor zeros
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+ vecquant8matmul_batched_column_compression_faster_old_cuda(vec, mat, mul, scales, zeros);
+}
+
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
+ m.def("vecquant8matmul_batched", &vecquant8matmul_batched, "Vector 8-bit Batched Quantized Matrix Multiplication (CUDA) (desc_act)");
+ m.def("vecquant8matmul_batched_old", &vecquant8matmul_batched_old, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)");
+ m.def("vecquant8matmul_batched_faster", &vecquant8matmul_batched_faster, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)");
+ m.def("vecquant8matmul_batched_faster_old", &vecquant8matmul_batched_faster_old, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)");
+ m.def("vecquant4matmul_batched_old", &vecquant4matmul_batched_old, "Vector 4-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)");
+ m.def("vecquant8matmul_batched_column_compression", &vecquant8matmul_batched_column_compression, "Vector 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)");
+ m.def("vecquant8matmul_batched_column_compression_old", &vecquant8matmul_batched_column_compression_old, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)");
+ m.def("vecquant8matmul_batched_column_compression_faster", &vecquant8matmul_batched_column_compression_faster, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)");
+ m.def("vecquant8matmul_batched_column_compression_faster_old", &vecquant8matmul_batched_column_compression_faster_old, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)");
+ m.def("vecquant4matmul_batched_column_compression_old", &vecquant4matmul_batched_column_compression_old, "Vector old 4-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)");
+ m.def("vecquant4matmul_batched", &vecquant4matmul_batched, "Vector 4-bit Batched Quantized Matrix Multiplication (CUDA) (desc_act)");
+ m.def("vecquant4matmul_batched_column_compression", &vecquant4matmul_batched_column_compression, "Vector 4-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)");
+}
diff --git a/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_kernel_256.cu b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_kernel_256.cu
new file mode 100644
index 0000000..b7932cd
--- /dev/null
+++ b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_kernel_256.cu
@@ -0,0 +1,1708 @@
+#define _CRT_SECURE_NO_WARNINGS
+#include
+#include
+#include
+#include
+#include
+#include
+
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM)
+// adapted from https://github.com/PanQiWei/AutoGPTQ/blob/main/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu
+__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
+ unsigned int *address_as_ui = reinterpret_cast(reinterpret_cast(address) - (reinterpret_cast(address) & 2));
+ unsigned int old = *address_as_ui;
+ unsigned int assumed;
+
+ do {
+ assumed = old;
+ unsigned short hsum = reinterpret_cast(address) & 2 ? (old >> 16) : (old & 0xffff);
+ hsum += val;
+ old = reinterpret_cast(address) & 2
+ ? (old & 0xffff) | (hsum << 16)
+ : (old & 0xffff0000) | hsum;
+ old = atomicCAS(address_as_ui, assumed, old);
+
+ // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+ } while (assumed != old);
+}
+__device__ __forceinline__ void atomicAdd(__half* address, c10::Half val) {
+ unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+ unsigned int old = *address_as_ui;
+ unsigned int assumed;
+
+ do {
+ assumed = old;
+ __half_raw hsum;
+ hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+ half tmpres = __hadd(hsum, val);
+ hsum = __half_raw(tmpres);
+ old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
+ old = atomicCAS(address_as_ui, assumed, old);
+ } while (assumed != old);
+}
+#endif
+
+template
+__global__ void VecQuant8MatMulKernel(
+ const scalar_t* __restrict__ vec,
+ const int* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const int* __restrict__ zeros,
+ const int* __restrict__ g_idx,
+ int batch,
+ int vec_height,
+ int height,
+ int width,
+ int zero_width
+);
+
+template
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel(
+ const scalar_t* __restrict__ vec,
+ const int* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const int* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int height,
+ int width
+);
+
+template
+__global__ void VecQuant4BatchMatMulColumnCompressionKernel(
+ const scalar_t* __restrict__ vec,
+ const int* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const int* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int height,
+ int width
+);
+
+template
+__global__ void VecQuant8BatchMatMulKernel(
+ const scalar_t* __restrict__ vec,
+ const int* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const int* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int vec_height,
+ int height,
+ int width,
+ int zero_width
+);
+
+template
+__global__ void VecQuant4BatchMatMulKernel(
+ const scalar_t* __restrict__ vec,
+ const int* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const int* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int vec_height,
+ int height,
+ int width,
+ int zero_width
+);
+
+
+
+template
+__global__ void VecQuant8BatchMatMulKernel_old(
+ const scalar_t* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const scalar_t* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int vec_height,
+ int height,
+ int width,
+ int zero_width
+);
+
+__global__ void VecQuant8BatchMatMulKernel_faster(
+ const half* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ half* __restrict__ mul,
+ const half* __restrict__ scales,
+ const half* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int vec_height,
+ int height,
+ int width,
+ int zero_width
+);
+
+
+
+__global__ void VecQuant8BatchMatMulKernel_faster_old(
+ const half* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ half* __restrict__ mul,
+ const half* __restrict__ scales,
+ const half* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int vec_height,
+ int height,
+ int width
+);
+
+
+template
+__global__ void VecQuant4BatchMatMulKernel_old(
+ const scalar_t* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const scalar_t* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int vec_height,
+ int height,
+ int width,
+ int zero_width
+);
+
+
+template
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_old(
+ const scalar_t* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const scalar_t* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int height,
+ int width
+);
+
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster(
+ const half* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ half* __restrict__ mul,
+ const half* __restrict__ scales,
+ const half* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int height,
+ int width
+);
+
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster_old(
+ const half* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ half* __restrict__ mul,
+ const half* __restrict__ scales,
+ const half* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int height,
+ int width
+);
+
+
+template
+__global__ void VecQuant4BatchMatMulColumnCompressionKernel_old(
+ const scalar_t* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const scalar_t* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int height,
+ int width
+);
+
+
+__global__ void VecQuant8BatchMatMulKernel_faster(
+ const half* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ half* __restrict__ mul,
+ const half* __restrict__ scales,
+ const half* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int vec_height,
+ int height,
+ int width
+);
+
+
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster(
+ const half* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ half* __restrict__ mul,
+ const half* __restrict__ scales,
+ const half* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int height,
+ int width
+);
+
+const int BLOCKWIDTH = 128;
+const int BLOCKHEIGHT8 = 32;
+const int BLOCKHEIGHT4 = 16;
+const int BLOCKHEIGHT_OLD4 = 128;
+//const int BLOCKHEIGHT_OLD8 = 128;
+
+__device__ inline unsigned int as_unsigned(int i) {
+ return *reinterpret_cast(&i);
+}
+
+__device__ inline int as_int(int i) {
+ return *reinterpret_cast(&i);
+}
+
+void vecquant8matmul_batched_column_compression_cuda(
+ torch::Tensor vec,
+ torch::Tensor mat,
+ torch::Tensor mul,
+ torch::Tensor scales,
+ torch::Tensor zeros
+) {
+ int batch = vec.size(0);
+ int heads = vec.size(1);
+ int vec_row = vec.size(2);
+ int height = vec.size(3);
+ int width = mat.size(3) * 4;
+
+ dim3 blocks(
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+ );
+ dim3 threads(BLOCKWIDTH);
+
+ AT_DISPATCH_FLOATING_TYPES(
+ vec.type(), "vecquant8matmul_batched_cuda", ([&] {
+ VecQuant8BatchMatMulColumnCompressionKernel<<>>(
+ vec.data(), mat.data(), mul.data(),
+ scales.data(), zeros.data(),
+ batch, heads, vec_row, height, width
+ );
+ })
+ );
+
+}
+
+template
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel(
+ const scalar_t* __restrict__ vec,
+ const int* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const int* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int height,
+ int width
+) {
+ int weight_total = batch * heads * height * width / 4;
+ int input_total = batch * heads * vec_row * height;
+ int out_total = batch * heads * vec_row * width;
+ int tid = threadIdx.x;
+ // h is index of height with step being BLOCKWIDTH
+ int h = BLOCKWIDTH * blockIdx.x;
+ // w is index of width with step being 1
+ int w = BLOCKWIDTH * blockIdx.y + tid;
+ if (w >= width && tid >= height) {
+ return;
+ }
+
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
+ int k;
+ scalar_t w_tmp;
+
+ float weight[BLOCKWIDTH];
+
+ for (int b = 0; b < batch; ++b){
+ for (int head = 0; head < heads; ++head){
+ int batch_shift = b * heads + head;
+ for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
+ int i_w = (w / 4);
+ int w_bit = (w % 4) * 8;
+
+ int w_index = (batch_shift * height + h + k) * width / 4 + i_w;
+ if (w_index >= weight_total || w >= width) {
+ weight[k] = 0;
+ } else {
+ scalar_t scale = scales[batch_shift * height + h + k];
+ scalar_t zero = zeros[batch_shift * height + h + k];
+ w_tmp = ((as_unsigned(mat[w_index]) >> w_bit) & 0xFF);
+ weight[k] = scale * (w_tmp - zero);
+ }
+ }
+
+ scalar_t res;
+ for (int vr = 0; vr < vec_row; ++vr){
+ res = 0;
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+ if (vec_index < input_total) {
+ blockvec[tid] = vec[vec_index];
+ } else {
+ blockvec[tid] = 0;
+ }
+
+ __syncthreads();
+ for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
+ // res is the dot product of BLOCKWIDTH elements (part of width)
+ res += weight[k] * blockvec[k];
+ }
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
+ int out_index = (batch_shift * vec_row + vr) * width + w;
+ if (out_index < out_total) {
+ atomicAdd(&mul[out_index], res);
+ }
+ __syncthreads();
+ }
+ }
+ }
+}
+
+void vecquant8matmul_batched_cuda(
+ torch::Tensor vec,
+ torch::Tensor mat,
+ torch::Tensor mul,
+ torch::Tensor scales,
+ torch::Tensor zeros
+) {
+ int batch = vec.size(0);
+ int heads = vec.size(1);
+ int vec_row = vec.size(2);
+ int vec_height = vec.size(3);
+ int height = mat.size(2);
+ int width = mat.size(3);
+ int zero_width = zeros.size(2);
+
+ dim3 blocks(
+ (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+ );
+ dim3 threads(BLOCKWIDTH);
+
+ AT_DISPATCH_FLOATING_TYPES(
+ vec.type(), "vecquant8matmul_batched_cuda", ([&] {
+ VecQuant8BatchMatMulKernel<<>>(
+ vec.data(), mat.data(), mul.data(),
+ scales.data(), zeros.data(),
+ batch, heads, vec_row, vec_height, height, width, zero_width
+ );
+ })
+ );
+
+}
+
+template
+__global__ void VecQuant8BatchMatMulKernel(
+ const scalar_t* __restrict__ vec,
+ const int* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const int* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int vec_height,
+ int height,
+ int width,
+ int zero_width
+) {
+ int weight_total = batch * heads * height * width;
+ int input_total = batch * heads * vec_row * vec_height;
+ int out_total = batch * heads * vec_row * width;
+ int tid = threadIdx.x;
+ // h is index of height with step being BLOCKHEIGHT8
+ int h = BLOCKHEIGHT8 * blockIdx.x;
+ // w is index of width with step being 1
+ int w = BLOCKWIDTH * blockIdx.y + tid;
+ if (w >= width && tid >= vec_height) {
+ return;
+ }
+
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
+ // i is index of mat of block first row
+ int i = width * h + w;
+ // if (i >= width * height) {
+ // return;
+ // }
+ int k;
+ scalar_t w_tmp;
+
+ int z_w = w / 4;
+ int z_mod = (w % 4) * 8;
+
+ float weight[BLOCKWIDTH];
+
+ for (int b = 0; b < batch; ++b){
+ for (int head = 0; head < heads; ++head){
+ int batch_shift = b * heads + head;
+ for (k = 0; k < BLOCKWIDTH && h * 4 + k < vec_height; ++k){
+ int k_w = (k / 4);
+ int k_bit = (k % 4) * 8;
+
+ int w_index = batch_shift * height * width + i + (k_w * width);
+ if (w_index >= weight_total || w >= width) {
+ weight[k] = 0;
+ } else {
+ scalar_t scale = scales[batch_shift * width + w];
+ scalar_t zero;
+ if (zero_width == width) {
+ zero = zeros[batch_shift * width + w];
+ } else {
+ zero = scalar_t(((as_unsigned(zeros[batch_shift * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
+ }
+ w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xFF);
+ weight[k] = scale * (w_tmp - zero);
+ }
+ }
+
+ scalar_t res;
+ for (int vr = 0; vr < vec_row; ++vr){
+ res = 0;
+ int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
+ if (vec_index < input_total) {
+ blockvec[tid] = vec[vec_index];
+ } else {
+ blockvec[tid] = 0;
+ }
+
+ __syncthreads();
+ for (k = 0; k < BLOCKWIDTH && h * 4 + k < vec_height; ++k){
+ // res is the dot product of BLOCKWIDTH elements (part of width)
+ res += weight[k] * blockvec[k];
+ }
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
+ int out_index = (batch_shift * vec_row + vr) * width + w;
+ if (out_index < out_total) {
+ atomicAdd(&mul[out_index], res);
+ }
+ __syncthreads();
+ }
+ }
+ }
+}
+
+
+void vecquant8matmul_cuda(
+ torch::Tensor vec,
+ torch::Tensor mat,
+ torch::Tensor mul,
+ torch::Tensor scales,
+ torch::Tensor zeros,
+ torch::Tensor g_idx
+) {
+ int batch = vec.size(0);
+ int vec_height = vec.size(1);
+ int height = mat.size(0);
+ int width = mat.size(1);
+ int zero_width = zeros.size(1);
+
+ dim3 blocks(
+ (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+ );
+ dim3 threads(BLOCKWIDTH);
+
+ AT_DISPATCH_FLOATING_TYPES(
+ vec.type(), "vecquant8matmul_cuda", ([&] {
+ VecQuant8MatMulKernel<<>>(
+ vec.data(), mat.data(), mul.data(),
+ scales.data(), zeros.data(), g_idx.data(),
+ batch, vec_height, height, width, zero_width
+ );
+ })
+ );
+}
+
+template
+__global__ void VecQuant8MatMulKernel(
+ const scalar_t* __restrict__ vec,
+ const int* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const int* __restrict__ zeros,
+ const int* __restrict__ g_idx,
+ int batch,
+ int vec_height,
+ int height,
+ int width,
+ int zero_width
+) {
+ int h = BLOCKHEIGHT8 * blockIdx.x;
+ int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
+ int i = width * h + w;
+ int g_h = h * 4;
+ int k;
+ unsigned int g;
+ scalar_t w_tmp;
+
+ int z_w = w / 4;
+ int z_mod = (w % 4) * 8;
+
+ float weight[BLOCKWIDTH];
+
+ for (k = 0; k < BLOCKWIDTH; ++k){
+ int k_w = (k / 4);
+ int k_bit = (k % 4) * 8;
+
+ g = as_int(g_idx[g_h + k]);
+ scalar_t scale = scales[g * width + w];
+ scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
+
+ w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
+
+ weight[k] = scale * (w_tmp - zero);
+ }
+
+
+ scalar_t res;
+ for (int b = 0; b < batch; ++b){
+ res = 0;
+ blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
+ __syncthreads();
+ for (k = 0; k < BLOCKWIDTH; ++k){
+ res += weight[k] * blockvec[k];
+ }
+ atomicAdd(&mul[b * width + w], res);
+ __syncthreads();
+ }
+}
+
+
+
+void vecquant4matmul_batched_cuda(
+ torch::Tensor vec,
+ torch::Tensor mat,
+ torch::Tensor mul,
+ torch::Tensor scales,
+ torch::Tensor zeros
+) {
+ int batch = vec.size(0);
+ int heads = vec.size(1);
+ int vec_row = vec.size(2);
+ int vec_height = vec.size(3);
+ int height = mat.size(2);
+ int width = mat.size(3);
+ int zero_width = zeros.size(2);
+
+ dim3 blocks(
+ (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+ );
+ dim3 threads(BLOCKWIDTH);
+
+ AT_DISPATCH_FLOATING_TYPES(
+ vec.type(), "vecquant4matmul_batched_cuda", ([&] {
+ VecQuant4BatchMatMulKernel<<>>(
+ vec.data(), mat.data(), mul.data(),
+ scales.data(), zeros.data(),
+ batch, heads, vec_row, vec_height, height, width, zero_width
+ );
+ })
+ );
+
+}
+
+template
+__global__ void VecQuant4BatchMatMulKernel(
+ const scalar_t* __restrict__ vec,
+ const int* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const int* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int vec_height,
+ int height,
+ int width,
+ int zero_width
+) {
+ int weight_total = batch * heads * height * width;
+ int input_total = batch * heads * vec_row * vec_height;
+ int out_total = batch * heads * vec_row * width;
+ int tid = threadIdx.x;
+ // h is index of height with step being BLOCKHEIGHT4
+ int h = BLOCKHEIGHT4 * blockIdx.x;
+ // w is index of width with step being 1
+ int w = BLOCKWIDTH * blockIdx.y + tid;
+ if (w >= width && tid >= vec_height) {
+ return;
+ }
+
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
+ // i is index of mat of block first row
+ int i = width * h + w;
+ int k;
+ scalar_t w_tmp;
+
+ int z_w = w / 8;
+ int z_mod = (w % 8) * 4;
+
+ float weight[BLOCKWIDTH];
+
+ for (int b = 0; b < batch; ++b){
+ for (int head = 0; head < heads; ++head){
+ int batch_shift = b * heads + head;
+ for (k = 0; k < BLOCKWIDTH && h * 8 + k < vec_height; ++k){
+ int k_w = (k / 8);
+ int k_bit = (k % 8) * 4;
+
+ int w_index = batch_shift * height * width + i + (k_w * width);
+ if (w_index >= weight_total || w >= width) {
+ weight[k] = 0;
+ } else {
+ scalar_t scale = scales[batch_shift * width + w];
+ scalar_t zero;
+ if (zero_width == width) {
+ zero = zeros[batch_shift * width + w];
+ } else {
+ zero = scalar_t(((as_unsigned(zeros[batch_shift * zero_width + z_w]) >> z_mod) & 0xF));
+ }
+ w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF);
+ weight[k] = scale * (w_tmp - zero);
+ }
+ }
+
+ scalar_t res;
+ for (int vr = 0; vr < vec_row; ++vr){
+ res = 0;
+ int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
+ if (vec_index < input_total) {
+ blockvec[tid] = vec[vec_index];
+ } else {
+ blockvec[tid] = 0;
+ }
+
+ __syncthreads();
+ for (k = 0; k < BLOCKWIDTH && h * 8 + k < vec_height; ++k){
+ // res is the dot product of BLOCKWIDTH elements (part of width)
+ res += weight[k] * blockvec[k];
+ }
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
+ int out_index = (batch_shift * vec_row + vr) * width + w;
+ if (out_index < out_total) {
+ atomicAdd(&mul[out_index], res);
+ }
+ __syncthreads();
+ }
+ }
+ }
+}
+
+
+
+void vecquant4matmul_batched_column_compression_cuda(
+ torch::Tensor vec,
+ torch::Tensor mat,
+ torch::Tensor mul,
+ torch::Tensor scales,
+ torch::Tensor zeros
+) {
+ int batch = vec.size(0);
+ int heads = vec.size(1);
+ int vec_row = vec.size(2);
+ int height = vec.size(3);
+ int width = mat.size(3) * 8;
+
+ dim3 blocks(
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+ );
+ dim3 threads(BLOCKWIDTH);
+
+ AT_DISPATCH_FLOATING_TYPES(
+ vec.type(), "vecquant4matmul_batched_cuda", ([&] {
+ VecQuant4BatchMatMulColumnCompressionKernel<<>>(
+ vec.data(), mat.data(), mul.data(),
+ scales.data(), zeros.data(),
+ batch, heads, vec_row, height, width
+ );
+ })
+ );
+
+}
+
+template
+__global__ void VecQuant4BatchMatMulColumnCompressionKernel(
+ const scalar_t* __restrict__ vec,
+ const int* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const int* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int height,
+ int width
+) {
+ int weight_total = batch * heads * height * width / 8;
+ int input_total = batch * heads * vec_row * height;
+ int out_total = batch * heads * vec_row * width;
+ int tid = threadIdx.x;
+ // h is index of height with step being BLOCKWIDTH
+ int h = BLOCKWIDTH * blockIdx.x;
+ // w is index of width with step being 1
+ int w = BLOCKWIDTH * blockIdx.y + tid;
+ if (w >= width && tid >= height) {
+ return;
+ }
+
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
+ int k;
+ scalar_t w_tmp;
+
+ float weight[BLOCKWIDTH];
+
+ for (int b = 0; b < batch; ++b){
+ for (int head = 0; head < heads; ++head){
+ int batch_shift = b * heads + head;
+ for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
+ int i_w = (w / 8);
+ int w_bit = (w % 8) * 4;
+
+ int w_index = (batch_shift * height + h + k) * width / 8 + i_w;
+ if (w_index >= weight_total || w >= width) {
+ weight[k] = 0;
+ } else {
+ scalar_t scale = scales[batch_shift * height + h + k];
+ scalar_t zero = zeros[batch_shift * height + h + k];
+ w_tmp = ((as_unsigned(mat[w_index]) >> w_bit) & 0xF);
+ weight[k] = scale * (w_tmp - zero);
+ }
+ }
+
+ scalar_t res;
+ for (int vr = 0; vr < vec_row; ++vr){
+ res = 0;
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+ if (vec_index < input_total) {
+ blockvec[tid] = vec[vec_index];
+ } else {
+ blockvec[tid] = 0;
+ }
+
+ __syncthreads();
+ for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
+ // res is the dot product of BLOCKWIDTH elements (part of width)
+ res += weight[k] * blockvec[k];
+ }
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
+ int out_index = (batch_shift * vec_row + vr) * width + w;
+ if (out_index < out_total) {
+ atomicAdd(&mul[out_index], res);
+ }
+ __syncthreads();
+ }
+ }
+ }
+}
+
+
+void vecquant8matmul_batched_old_cuda(
+ torch::Tensor vec,
+ torch::Tensor mat,
+ torch::Tensor mul,
+ torch::Tensor scales,
+ torch::Tensor zeros
+) {
+ int batch = vec.size(0);
+ int heads = vec.size(1);
+ int vec_row = vec.size(2);
+ int vec_height = vec.size(3);
+ int height = mat.size(2);
+ int width = mat.size(3);
+ int zero_width = zeros.size(2);
+
+ dim3 blocks(
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+ );
+ dim3 threads(BLOCKWIDTH);
+
+ AT_DISPATCH_FLOATING_TYPES(
+ vec.type(), "vecquant8matmul_batched_old_cuda", ([&] {
+ VecQuant8BatchMatMulKernel_old<<>>(
+ vec.data(), mat.data(), mul.data(),
+ scales.data(), zeros.data(),
+ batch, heads, vec_row, vec_height, height, width, zero_width
+ );
+ })
+ );
+}
+
+
+template
+__global__ void VecQuant8BatchMatMulKernel_old(
+ const scalar_t* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const scalar_t* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int vec_height,
+ int height,
+ int width,
+ int zero_width
+) {
+ int weight_total = batch * heads * height * width;
+ int input_total = batch * heads * vec_row * vec_height;
+ int out_total = batch * heads * vec_row * width;
+ int tid = threadIdx.x;
+ // h is index of height with step being BLOCKHEIGHT8
+ int h = BLOCKWIDTH * blockIdx.x;
+ // w is index of width with step being 1
+ int w = BLOCKWIDTH * blockIdx.y + tid;
+ if (w >= width && tid >= vec_height) {
+ return;
+ }
+
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
+ // i is index of mat of block first row
+ int i = width * h + w;
+ int k;
+ scalar_t w_tmp;
+
+ float weight[BLOCKWIDTH];
+ for (int b = 0; b < batch; ++b){
+ for (int head = 0; head < heads; ++head){
+ int batch_shift = b * heads + head;
+ for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){
+ int k_w = k;
+ int w_index = batch_shift * height * width + i + (k_w * width);
+ if (w_index >= weight_total || w >= width) {
+ weight[k] = 0;
+ } else {
+ scalar_t scale = scales[batch_shift * width + w];
+ scalar_t zero = zeros[batch_shift * width + w];
+ w_tmp = as_unsigned(mat[w_index]);
+ weight[k] = scale * (w_tmp - zero);
+ }
+ }
+
+ scalar_t res;
+ for (int vr = 0; vr < vec_row; ++vr){
+ res = 0;
+ int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
+ if (vec_index < input_total) {
+ blockvec[tid] = vec[vec_index];
+ } else {
+ blockvec[tid] = 0;
+ }
+
+ __syncthreads();
+ for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){
+ // res is the dot product of BLOCKWIDTH elements (part of width)
+ res += weight[k] * blockvec[k];
+ }
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
+ int out_index = (batch_shift * vec_row + vr) * width + w;
+ if (out_index < out_total) {
+ atomicAdd(&mul[out_index], res);
+ }
+ __syncthreads();
+ }
+ }
+ }
+}
+
+
+
+void vecquant8matmul_batched_faster_cuda(
+ torch::Tensor vec,
+ torch::Tensor mat,
+ torch::Tensor mul,
+ torch::Tensor scales,
+ torch::Tensor zeros
+) {
+ int batch = vec.size(0);
+ int heads = vec.size(1);
+ int vec_row = vec.size(2);
+ int vec_height = vec.size(3);
+ int height = mat.size(2);
+ int width = mat.size(3);
+ int zero_width = zeros.size(2);
+
+ dim3 blocks(
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+ );
+ dim3 threads(BLOCKWIDTH);
+
+ VecQuant8BatchMatMulKernel_faster<<>>(
+ (half*) vec.data_ptr(),
+ (uint8_t*) mat.data_ptr(),
+ (half*) mul.data_ptr(),
+ (half*) scales.data_ptr(),
+ (half*) zeros.data_ptr(),
+ batch, heads, vec_row, vec_height, height, width, zero_width
+ );
+}
+
+
+
+__global__ void VecQuant8BatchMatMulKernel_faster(
+ const half* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ half* __restrict__ mul,
+ const half* __restrict__ scales,
+ const half* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int vec_height,
+ int height,
+ int width,
+ int zero_width
+) {
+ //int weight_total = batch * heads * height * width;
+ int input_total = batch * heads * vec_row * vec_height;
+ int out_total = batch * heads * vec_row * width;
+ int tid = threadIdx.x;
+ int h = BLOCKWIDTH * blockIdx.x;
+ int w = BLOCKWIDTH * blockIdx.y + tid;
+ if (w >= width && tid >= height) {
+ return;
+ }
+
+ __shared__ float blockvec[BLOCKWIDTH];
+ int i = width * h + w;
+ int k;
+ float w_tmp;
+
+ float weight[BLOCKWIDTH];
+ for (int b = 0; b < batch; ++b){
+ for (int head = 0; head < heads; ++head){
+ int batch_shift = b * heads + head;
+ for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){
+ int k_w = k;
+ int w_index = batch_shift * height * width + i + (k_w * width);
+ float scale = __half2float(scales[batch_shift * width + w]);
+ float zero = __half2float(zeros[batch_shift * width + w]);
+ w_tmp = as_unsigned(mat[w_index]);
+ weight[k] = scale *(w_tmp-zero);
+ }
+
+ float res;
+ for (int vr = 0; vr < vec_row; ++vr){
+ res = 0;
+ int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
+ if (vec_index < input_total) {
+ blockvec[tid] = __half2float(vec[vec_index]);
+ } else {
+ blockvec[tid] = 0;
+ }
+ __syncthreads();
+ for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){
+ float temp_res = weight[k]*blockvec[k];
+ res += temp_res;
+ }
+ int out_index = (batch_shift * vec_row + vr) * width + w;
+ if (out_index < out_total) {
+ atomicAdd(&mul[out_index], __float2half(res));
+ }
+ __syncthreads();
+ }
+ }
+ }
+}
+
+
+
+
+void vecquant8matmul_batched_column_compression_faster_cuda(
+ torch::Tensor vec,
+ torch::Tensor mat,
+ torch::Tensor mul,
+ torch::Tensor scales,
+ torch::Tensor zeros
+) {
+ int batch = vec.size(0);
+ int heads = vec.size(1);
+ int vec_row = vec.size(2);
+ int height = vec.size(3);
+ int width = mat.size(3);
+
+ dim3 blocks(
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+ );
+ dim3 threads(BLOCKWIDTH);
+
+ VecQuant8BatchMatMulColumnCompressionKernel_faster<<>>(
+ (half*) vec.data_ptr(),
+ (uint8_t*) mat.data_ptr(),
+ (half*) mul.data_ptr(),
+ (half*) scales.data_ptr(),
+ (half*) zeros.data_ptr(),
+ batch, heads, vec_row, height, width
+ );
+
+}
+
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster(
+ const half* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ half* __restrict__ mul,
+ const half* __restrict__ scales,
+ const half* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int height,
+ int width
+) {
+ //int weight_total = batch * heads * height * width;
+ int input_total = batch * heads * vec_row * height;
+ int out_total = batch * heads * vec_row * width;
+ int tid = threadIdx.x;
+ int h = BLOCKWIDTH * blockIdx.x;
+ int w = BLOCKWIDTH * blockIdx.y + tid;
+ if (w >= width && tid >= height) {
+ return;
+ }
+
+ __shared__ float blockvec[BLOCKWIDTH];
+ int k;
+ float w_tmp;
+ float weight[BLOCKWIDTH];
+
+ for (int b = 0; b < batch; ++b){
+ for (int head = 0; head < heads; ++head){
+ int batch_shift = b * heads + head;
+ for (k = 0; k < BLOCKWIDTH; ++k){
+ int w_index = (batch_shift * height + h + k) * width + w;
+ float scale = __half2float(scales[batch_shift * height + h + k]);
+ float zero = __half2float(zeros[batch_shift * height + h + k]);
+ w_tmp = mat[w_index];
+ weight[k] = scale * (w_tmp-zero);
+ }
+
+ float res;
+ for (int vr = 0; vr < vec_row; ++vr){
+ res = 0;
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+ if (vec_index < input_total) {
+ blockvec[tid] = __half2float(vec[vec_index]);
+ } else {
+ blockvec[tid] = 0;
+ }
+ __syncthreads();
+ for (k = 0; k < BLOCKWIDTH; ++k){
+ res += weight[k]*blockvec[k];
+ }
+ int out_index = (batch_shift * vec_row + vr) * width + w;
+ if (out_index < out_total) {
+ atomicAdd(&mul[out_index], __float2half(res));
+ }
+ __syncthreads();
+ }
+ }
+ }
+}
+
+
+
+void vecquant8matmul_batched_column_compression_old_cuda(
+ torch::Tensor vec,
+ torch::Tensor mat,
+ torch::Tensor mul,
+ torch::Tensor scales,
+ torch::Tensor zeros
+) {
+ int batch = vec.size(0);
+ int heads = vec.size(1);
+ int vec_row = vec.size(2);
+ int height = vec.size(3);
+ int width = mat.size(3);
+
+ dim3 blocks(
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+ );
+ dim3 threads(BLOCKWIDTH);
+
+ AT_DISPATCH_FLOATING_TYPES(
+ vec.type(), "vecquant8matmul_batched_column_compression_old_cuda", ([&] {
+ VecQuant8BatchMatMulColumnCompressionKernel_old<<>>(
+ vec.data(), mat.data(), mul.data(),
+ scales.data(), zeros.data(),
+ batch, heads, vec_row, height, width
+ );
+ })
+ );
+
+}
+
+template
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_old(
+ const scalar_t* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const scalar_t* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int height,
+ int width
+) {
+ int weight_total = batch * heads * height * width;
+ int input_total = batch * heads * vec_row * height;
+ int out_total = batch * heads * vec_row * width;
+ int tid = threadIdx.x;
+ // h is index of height with step being BLOCKWIDTH
+ int h = BLOCKWIDTH * blockIdx.x;
+ // w is index of width with step being 1
+ int w = BLOCKWIDTH * blockIdx.y + tid;
+ if (w >= width && tid >= height) {
+ return;
+ }
+
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
+ int k;
+ scalar_t w_tmp;
+
+ float weight[BLOCKWIDTH];
+
+ for (int b = 0; b < batch; ++b){
+ for (int head = 0; head < heads; ++head){
+ int batch_shift = b * heads + head;
+ for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
+ int w_index = (batch_shift * height + h + k) * width + w;
+ if (w_index >= weight_total || w >= width) {
+ weight[k] = 0;
+ } else {
+ scalar_t scale = scales[batch_shift * height + h + k];
+ scalar_t zero = zeros[batch_shift * height + h + k];
+ w_tmp = mat[w_index];
+ weight[k] = scale * (w_tmp - zero);
+ }
+ }
+
+ scalar_t res;
+ for (int vr = 0; vr < vec_row; ++vr){
+ res = 0;
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+ if (vec_index < input_total) {
+ blockvec[tid] = vec[vec_index];
+ } else {
+ blockvec[tid] = 0;
+ }
+
+ __syncthreads();
+ for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
+ // res is the dot product of BLOCKWIDTH elements (part of width)
+ res += weight[k] * blockvec[k];
+ }
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
+ int out_index = (batch_shift * vec_row + vr) * width + w;
+ if (out_index < out_total) {
+ atomicAdd(&mul[out_index], res);
+ }
+ __syncthreads();
+ }
+ }
+ }
+}
+
+
+void vecquant4matmul_batched_old_cuda(
+ torch::Tensor vec,
+ torch::Tensor mat,
+ torch::Tensor mul,
+ torch::Tensor scales,
+ torch::Tensor zeros
+) {
+ int batch = vec.size(0);
+ int heads = vec.size(1);
+ int vec_row = vec.size(2);
+ int vec_height = vec.size(3);
+ int height = mat.size(2);
+ int width = mat.size(3);
+ int zero_width = zeros.size(2);
+
+ dim3 blocks(
+ (height + BLOCKHEIGHT_OLD4 - 1) / BLOCKHEIGHT_OLD4,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+ );
+ dim3 threads(BLOCKWIDTH);
+
+ AT_DISPATCH_FLOATING_TYPES(
+ vec.type(), "vecquant4matmul_batched_old_cuda", ([&] {
+ VecQuant4BatchMatMulKernel_old<<>>(
+ vec.data(), mat.data(), mul.data(),
+ scales.data(), zeros.data(),
+ batch, heads, vec_row, vec_height, height, width, zero_width
+ );
+ })
+ );
+
+}
+
+template
+__global__ void VecQuant4BatchMatMulKernel_old(
+ const scalar_t* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const scalar_t* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int vec_height,
+ int height,
+ int width,
+ int zero_width
+) {
+ int weight_total = batch * heads * height * width;
+ int input_total = batch * heads * vec_row * vec_height;
+ int out_total = batch * heads * vec_row * width;
+ int tid = threadIdx.x;
+ // h is index of height with step being BLOCKHEIGHT_OLD4
+ int h = BLOCKHEIGHT_OLD4 * blockIdx.x;
+ // w is index of width with step being 1
+ int w = BLOCKWIDTH * blockIdx.y + tid;
+ if (w >= width && tid >= vec_height) {
+ return;
+ }
+
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
+ // i is index of mat of block first row
+ int i = width * h + w;
+ int k;
+ scalar_t w_tmp;
+
+ float weight[BLOCKWIDTH];
+ for (int b = 0; b < batch; ++b){
+ for (int head = 0; head < heads; ++head){
+ int batch_shift = b * heads + head;
+ for (k = 0; k < BLOCKWIDTH && h*2 + k < vec_height; ++k){
+ int k_w = (k / 2);
+ int k_bit = (k % 2) * 4;
+ int w_index = batch_shift * height * width + i + (k_w * width);
+ if (w_index >= weight_total || w >= width) {
+ weight[k] = 0;
+ } else {
+ scalar_t scale = scales[batch_shift * width + w];
+ scalar_t zero = zeros[batch_shift * width + w];
+ w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF);
+ weight[k] = scale * (w_tmp - zero);
+ }
+ }
+
+ scalar_t res;
+ for (int vr = 0; vr < vec_row; ++vr){
+ res = 0;
+ int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
+ if (vec_index < input_total) {
+ blockvec[tid] = vec[vec_index];
+ } else {
+ blockvec[tid] = 0;
+ }
+
+ __syncthreads();
+ for (k = 0; k < BLOCKWIDTH && h*2 + k < vec_height; ++k){
+ // res is the dot product of BLOCKWIDTH elements (part of width)
+ res += weight[k] * blockvec[k];
+ }
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
+ int out_index = (batch_shift * vec_row + vr) * width + w;
+ if (out_index < out_total) {
+ atomicAdd(&mul[out_index], res);
+ }
+ __syncthreads();
+ }
+ }
+ }
+}
+
+
+
+
+
+void vecquant4matmul_batched_column_compression_old_cuda(
+ torch::Tensor vec,
+ torch::Tensor mat,
+ torch::Tensor mul,
+ torch::Tensor scales,
+ torch::Tensor zeros
+) {
+ int batch = vec.size(0);
+ int heads = vec.size(1);
+ int vec_row = vec.size(2);
+ int height = vec.size(3);
+ int width = mat.size(3);
+
+ dim3 blocks(
+ (height + BLOCKHEIGHT_OLD4 - 1) / BLOCKHEIGHT_OLD4,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+ );
+ dim3 threads(BLOCKWIDTH);
+
+ AT_DISPATCH_FLOATING_TYPES(
+ vec.type(), "vecquant4matmul_batched_column_compression_old_cuda", ([&] {
+ VecQuant4BatchMatMulColumnCompressionKernel_old<<>>(
+ vec.data(), mat.data(), mul.data(),
+ scales.data(), zeros.data(),
+ batch, heads, vec_row, height, width
+ );
+ })
+ );
+
+}
+
+template
+__global__ void VecQuant4BatchMatMulColumnCompressionKernel_old(
+ const scalar_t* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ scalar_t* __restrict__ mul,
+ const scalar_t* __restrict__ scales,
+ const scalar_t* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int height,
+ int width
+) {
+ int weight_total = batch * heads * height * width;
+ int input_total = batch * heads * vec_row * height;
+ int out_total = batch * heads * vec_row * width;
+ int tid = threadIdx.x;
+ // h is index of height with step being BLOCKWIDTH
+ int h = BLOCKHEIGHT_OLD4 * blockIdx.x;
+ // w is index of width with step being 1
+ int w = BLOCKWIDTH * blockIdx.y + tid;
+ if (w >= width && tid >= height) {
+ return;
+ }
+
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
+ int k;
+ scalar_t w_tmp;
+
+ float weight[BLOCKWIDTH];
+
+ for (int b = 0; b < batch; ++b){
+ for (int head = 0; head < heads; ++head){
+ int batch_shift = b * heads + head;
+ for (k = 0; k < BLOCKWIDTH && h*2 + k < height; ++k){
+ int k_w = (k / 2);
+ int k_bit = (k % 2) * 4;
+ int w_index = (batch_shift * height + h + k) * width + k_w;
+ if (w_index >= weight_total || w >= width) {
+ weight[k] = 0;
+ } else {
+ scalar_t scale = scales[batch_shift * height + h + k];
+ scalar_t zero = zeros[batch_shift * height + h + k];
+ w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF);
+ weight[k] = scale * (w_tmp - zero);
+ }
+ }
+
+ scalar_t res;
+ for (int vr = 0; vr < vec_row; ++vr){
+ res = 0;
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+ if (vec_index < input_total) {
+ blockvec[tid] = vec[vec_index];
+ } else {
+ blockvec[tid] = 0;
+ }
+
+ __syncthreads();
+ for (k = 0; k < BLOCKWIDTH && h*2 + k < height; ++k){
+ // res is the dot product of BLOCKWIDTH elements (part of width)
+ res += weight[k] * blockvec[k];
+ }
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
+ int out_index = (batch_shift * vec_row + vr) * width + w;
+ if (out_index < out_total) {
+ atomicAdd(&mul[out_index], res);
+ }
+ __syncthreads();
+ }
+ }
+ }
+}
+
+
+
+
+
+void vecquant8matmul_batched_faster_old_cuda(
+ torch::Tensor vec,
+ torch::Tensor mat,
+ torch::Tensor mul,
+ torch::Tensor scales,
+ torch::Tensor zeros
+) {
+ int batch = vec.size(0);
+ int heads = vec.size(1);
+ int vec_row = vec.size(2);
+ int vec_height = vec.size(3);
+ int height = mat.size(2);
+ int width = mat.size(3);
+
+ dim3 blocks(
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+ );
+ dim3 threads(BLOCKWIDTH);
+
+ VecQuant8BatchMatMulKernel_faster_old<<>>(
+ (half*) vec.data_ptr(),
+ (uint8_t*) mat.data_ptr(),
+ (half*) mul.data_ptr(),
+ (half*) scales.data_ptr(),
+ (half*) zeros.data_ptr(),
+ batch, heads, vec_row, vec_height, height, width
+ );
+}
+
+
+__global__ void VecQuant8BatchMatMulKernel_faster_old(
+ const half* __restrict__ vec,
+ const uint8_t* __restrict__ mat,
+ half* __restrict__ mul,
+ const half* __restrict__ scales,
+ const half* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row,
+ int vec_height,
+ int height,
+ int width
+) {
+ int weight_total = batch * heads * height * width;
+ int input_total = batch * heads * vec_row * vec_height;
+ int out_total = batch * heads * vec_row * width;
+ int tid = threadIdx.x;
+ const int BLOCKWIDTH_half = BLOCKWIDTH/2;
+
+ int h = BLOCKWIDTH * blockIdx.x; //head_dim, dim=-1
+ int w = BLOCKWIDTH * blockIdx.y + tid; //seq-len, +0-256 ,dim=-2
+ /*
+ if (w >= width && tid >= vec_height) {
+ return;
+ }
+ */
+ __shared__ half blockvec[BLOCKWIDTH]; //256
+ int i = width * h + w;
+ int k;
+
+ half w_tmp1 = __float2half(0);
+ half w_tmp2 = __float2half(0);
+
+ half2 weight[BLOCKWIDTH_half];
+ for (int b = 0; b < batch; ++b){
+ for (int head = 0; head < heads; ++head){
+ int batch_shift = b * heads + head;
+ //int zero_index = batch_shift;
+ for (k = 0; k < BLOCKWIDTH_half; ++k){
+ int w_index1 = batch_shift * height * width + i + (2 * k * width); // [batch,head,h+k, w]
+ int w_index2 = batch_shift * height * width + i + ((2 * k + 1) * width);
+ int zero_index = batch_shift * width + w; // [batch,head, w]
+ if (w_index1 >= weight_total || w >= width || (2 * k + h) >= height) {
+ weight[k] = __float2half2_rn(0);
+ } else {
+ float zero_f=__half2float(zeros[zero_index]);
+ float scale_f= __half2float(scales[zero_index]);
+ if (w_index2 >= weight_total){
+ w_tmp1 = __float2half((as_unsigned(mat[w_index1]) -zero_f)*scale_f);
+ w_tmp2 = __float2half(0);
+ weight[k] = __halves2half2(w_tmp1,w_tmp2);
+ //printf("zero_index is %d w is %d height is %d width is %d w_index1 is %d w_tmp1 is %f w_tmp2 is %f zero is %f scale is %f low is %f high is %f \n ",zero_index,w,height, width,w_index1,__half2float(w_tmp1),__half2float(w_tmp2),zero_f,scale_f,__low2float(weight[k]),__high2float(weight[k]));
+ }else{
+ w_tmp1 = __int2half_rn(as_unsigned(mat[w_index1]));
+ w_tmp2 = __int2half_rn(as_unsigned(mat[w_index2]));
+
+ //weight[k] = __hmul2(__hsub2(__halves2half2(w_tmp1,w_tmp2), __halves2half2(zero,zero)),__halves2half2(scale,scale));
+ weight[k] = __hfma2(__halves2half2(w_tmp1,w_tmp2), __float2half2_rn(scale_f), __float2half2_rn(-(scale_f * zero_f)));
+ //printf("zero_index1 is %d zero_index2 is %d k is %d head is %d w is %d h is %d height is %d width is %d w_index1 is %d w_index2 is %d zero is %f scale is %f low is %f high is %f \n ",zero_index1,zero_index2,k,head,w,h,height, width,w_index1,w_index2,__half2float(zero1),__half2float(scale1),__low2float(weight[k]),__high2float(weight[k]));
+ }
+ }
+ }
+
+
+ for (int vr = 0; vr < vec_row; ++vr){
+ float res=0;
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+ int out_index = (batch_shift * vec_row + vr) * width + w;
+ if (vec_index < input_total) {
+ //blockvec[tid] = __half2float(vec[vec_index]);// [batch, head, vr, tid(seq_len dim+)]
+ blockvec[tid] = vec[vec_index];
+ //printf("width is %d height is %d h is %d w is %d vec_index is %d out_index is %d vec_row is %d vec_height is %d,vr is %d tid is %d blockvec is %f\n",width,height, h,w,vec_index,out_index,vec_row,vec_height,vr,tid,blockvec[tid]);
+ } else {
+ blockvec[tid] = __float2half(0);
+ }
+ __syncthreads();
+ if (out_index < out_total) {
+ for (k = 0; k < BLOCKWIDTH_half; ++k){
+ half2 res2 = __hmul2(weight[k],__halves2half2(blockvec[2*k],blockvec[2*k+1]));
+ res += __low2float(res2) + __high2float(res2);
+ }
+ atomicAdd(&mul[out_index], __float2half(res));
+ }
+ __syncthreads();
+ }
+ }
+ }
+}
+
+
+void vecquant8matmul_batched_column_compression_faster_old_cuda(
+ torch::Tensor vec, // [batch,heads, seq_q, seq_v]
+ torch::Tensor mat, // [batch,heads, seq_v, head_dim]
+ torch::Tensor mul, // [batch,heads, seq_q,head_dim]
+ torch::Tensor scales, // [batch,heads, head_dim]
+ torch::Tensor zeros
+) {
+ int batch = vec.size(0);
+ int heads = vec.size(1);
+ int vec_row = vec.size(2); //ql
+ int height = mat.size(2); //vl
+ int width = mat.size(3); //head_dim
+
+ dim3 blocks(
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+ );
+ dim3 threads(BLOCKWIDTH);
+
+ VecQuant8BatchMatMulColumnCompressionKernel_faster_old<<>>(
+ (half*) vec.data_ptr(),
+ (uint8_t*) mat.data_ptr(),
+ (half*) mul.data_ptr(),
+ (half*) scales.data_ptr(),
+ (half*) zeros.data_ptr(),
+ batch, heads, vec_row, height, width
+ );
+
+}
+
+
+__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster_old(
+ const half* __restrict__ vec, // [batch,heads, seq_q, seq_v]
+ const uint8_t* __restrict__ mat, // [batch,heads, seq_v, head_dim]
+ half* __restrict__ mul, // [batch,heads, seq_q,head_dim]
+ const half* __restrict__ scales, // [batch,heads, seq_v]
+ const half* __restrict__ zeros,
+ int batch,
+ int heads,
+ int vec_row, //seq_q
+ int height, //seq_v
+ int width //head_dim
+) {
+ int weight_total = batch * heads * height * width;
+ int input_total = batch * heads * vec_row * height;
+ int out_total = batch * heads * vec_row * width;
+ int tid = threadIdx.x;
+ int h = BLOCKWIDTH * blockIdx.x; // vl
+ int w = BLOCKWIDTH * blockIdx.y + tid; //head_dim + block
+ if (w >= width && tid >= height) {
+ return;
+ }
+ __shared__ half blockvec[BLOCKWIDTH];
+ int k;
+ half w_tmp1 = __float2half(0);
+ half w_tmp2 = __float2half(0);
+ int i = width * h + w;
+ const int BLOCKWIDTH_half = BLOCKWIDTH/2;
+ half2 weight[BLOCKWIDTH_half];
+
+ for (int b = 0; b < batch; ++b){
+ for (int head = 0; head < heads; ++head){
+ int batch_shift = b * heads + head;
+ //int zero_index = batch_shift;
+ for (k = 0; k < BLOCKWIDTH_half; ++k){
+ int w_index1 = batch_shift * height * width + i + (2 * k) * width; // [batch,head, h+k, w]
+ int w_index2 = batch_shift * height * width + i + ((2 * k + 1) * width);
+ int zero_index1 = batch_shift * height + h + 2*k; // [batch,head, w]
+ int zero_index2 = batch_shift * height + h + 2*k+1; // [batch,head, w]
+
+ if (w_index1 >= weight_total || (2 * k + h)>=height) {
+ weight[k]=__float2half2_rn(0);
+ } else{
+ //int zero_index = batch_shift + h; // [batch,head, w]
+ //float scale_f1 = __half2float(scales[zero_index1]);
+ //float zero_f1 = __half2float(zeros[zero_index1]);
+ if (w_index2>=weight_total){
+ w_tmp1 = __float2half((as_unsigned(mat[w_index1]) - __half2float(zeros[zero_index1]))* __half2float(scales[zero_index1]));
+ w_tmp2 = __float2half(0);
+ weight[k] = __halves2half2(w_tmp1,w_tmp2);
+ //printf("zero_index is %d k is %d w is %d head is %d height is %d width is %d w_index1 is %d w_tmp1 is %f w_tmp2 is %f zero is %f scale is %f low is %f high is %f \n ",zero_index,k,w,head,height, width,w_index1,__half2float(w_tmp1),__half2float(w_tmp2),zero_f,scale_f,__low2float(weight[k]),__high2float(weight[k]));
+ }else{
+ w_tmp1 = __int2half_rn(as_unsigned(mat[w_index1]));
+ w_tmp2 = __int2half_rn(as_unsigned(mat[w_index2]));
+ half zero1=zeros[zero_index1];
+ half zero2=zeros[zero_index2];
+ half scale1=scales[zero_index1];
+ half scale2=scales[zero_index2];
+ weight[k] = __hmul2(__hsub2(__halves2half2(w_tmp1,w_tmp2), __halves2half2(zero1,zero2)),__halves2half2(scale1,scale2));
+ //weight[k] = __hfma2(__halves2half2(w_tmp1,w_tmp2), __float2half2_rn(scale_f), __float2half2_rn(-(scale_f * zero_f)));
+ //printf("zero_index1 is %d zero_index2 is %d k is %d head is %d w is %d h is %d height is %d width is %d w_index1 is %d w_index2 is %d zero is %f scale is %f low is %f high is %f \n ",zero_index1,zero_index2,k,head,w,h,height, width,w_index1,w_index2,__half2float(zero1),__half2float(scale1),__low2float(weight[k]),__high2float(weight[k]));
+ }
+ }
+ }
+
+
+ for (int vr = 0; vr < vec_row; ++vr){
+ float res=0;
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
+ int out_index = (batch_shift * vec_row + vr) * width + w;
+
+ if (vec_index < input_total) {
+ //blockvec[tid] = __half2float(vec[vec_index]);
+ blockvec[tid] = vec[vec_index];
+ //printf("vec_index is %d out_index is %d vec_row is %d ,vr is %d tid is %d blockvec is %f\n",vec_index,out_index,vec_row,vr,tid,blockvec[tid]);
+ } else {
+ blockvec[tid] = __float2half(0);
+ //blockvec[tid] = 0;
+ }
+ __syncthreads();
+ if (out_index < out_total) {
+ for (k = 0; k < BLOCKWIDTH_half; ++k){
+ half2 res2 = __hmul2(weight[k],__halves2half2(blockvec[2*k],blockvec[2*k+1]));
+ res += __low2float(res2) + __high2float(res2);
+ }
+ atomicAdd(&mul[out_index], __float2half(res));
+ }
+ __syncthreads();
+ }
+ }
+ }
+}
diff --git a/mftcoder_accelerate/src/model/qwen/configuration_qwen.py b/mftcoder_accelerate/src/model/qwen/configuration_qwen.py
index 2ccfc92..f8fe2cb 100644
--- a/mftcoder_accelerate/src/model/qwen/configuration_qwen.py
+++ b/mftcoder_accelerate/src/model/qwen/configuration_qwen.py
@@ -35,6 +35,9 @@ def __init__(
intermediate_size=22016,
no_bias=True,
tie_word_embeddings=False,
+ use_cache_quantization=False,
+ use_cache_kernel=False,
+ softmax_in_fp32=False,
**kwargs,
):
self.vocab_size = vocab_size
@@ -59,6 +62,9 @@ def __init__(
self.use_logn_attn = use_logn_attn
self.use_flash_attn = use_flash_attn
self.no_bias = no_bias
+ self.use_cache_quantization = use_cache_quantization
+ self.use_cache_kernel = use_cache_kernel
+ self.softmax_in_fp32 = softmax_in_fp32
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs
diff --git a/mftcoder_accelerate/src/model/qwen/cpp_kernels.py b/mftcoder_accelerate/src/model/qwen/cpp_kernels.py
new file mode 100644
index 0000000..d9cee70
--- /dev/null
+++ b/mftcoder_accelerate/src/model/qwen/cpp_kernels.py
@@ -0,0 +1,55 @@
+from torch.utils import cpp_extension
+import pathlib
+import os
+import subprocess
+
+def _get_cuda_bare_metal_version(cuda_dir):
+ raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+ universal_newlines=True)
+ output = raw_output.split()
+ release_idx = output.index("release") + 1
+ release = output[release_idx].split(".")
+ bare_metal_major = release[0]
+ bare_metal_minor = release[1][0]
+
+ return raw_output, bare_metal_major, bare_metal_minor
+
+def _create_build_dir(buildpath):
+ try:
+ os.mkdir(buildpath)
+ except OSError:
+ if not os.path.isdir(buildpath):
+ print(f"Creation of the build directory {buildpath} failed")
+
+# Check if cuda 11 is installed for compute capability 8.0
+cc_flag = []
+_, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+if int(bare_metal_major) >= 11:
+ cc_flag.append('-gencode')
+ cc_flag.append('arch=compute_80,code=sm_80')
+ if int(bare_metal_minor) >= 7:
+ cc_flag.append('-gencode')
+ cc_flag.append('arch=compute_90,code=sm_90')
+
+# Build path
+srcpath = pathlib.Path(__file__).parent.absolute()
+buildpath = srcpath / 'build'
+_create_build_dir(buildpath)
+
+def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
+ return cpp_extension.load(
+ name=name,
+ sources=sources,
+ build_directory=buildpath,
+ extra_cflags=['-O3', ],
+ extra_cuda_cflags=['-O3',
+ '-gencode', 'arch=compute_70,code=sm_70',
+ '--use_fast_math'] + extra_cuda_flags + cc_flag,
+ verbose=1
+ )
+
+extra_flags = []
+
+cache_autogptq_cuda_256_sources = ["./cache_autogptq_cuda_256.cpp",
+ "./cache_autogptq_cuda_kernel_256.cu"]
+cache_autogptq_cuda_256 = _cpp_extention_load_helper("cache_autogptq_cuda_256", cache_autogptq_cuda_256_sources, extra_flags)
diff --git a/mftcoder_accelerate/src/model/qwen/modeling_qwen.py b/mftcoder_accelerate/src/model/qwen/modeling_qwen.py
index 05264b4..45c0d16 100644
--- a/mftcoder_accelerate/src/model/qwen/modeling_qwen.py
+++ b/mftcoder_accelerate/src/model/qwen/modeling_qwen.py
@@ -3,14 +3,16 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+import copy
import importlib
import math
+import pathlib
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
-from torch.cuda.amp import autocast
+import warnings
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
@@ -35,6 +37,8 @@
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
+SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
+
from .configuration_qwen import QWenConfig
from .qwen_generation_utils import (
@@ -74,10 +78,10 @@
apply_rotary_emb_func = None
rms_norm = None
flash_attn_unpadded_func = None
-
+flash_attn_func = None
def _import_flash_attn():
- global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
+ global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func, flash_attn_func
try:
from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
apply_rotary_emb_func = __apply_rotary_emb_func
@@ -98,20 +102,49 @@ def _import_flash_attn():
try:
import flash_attn
+ _flash_attn_func = None
if not hasattr(flash_attn, '__version__'):
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
else:
if int(flash_attn.__version__.split(".")[0]) >= 2:
+ if int(flash_attn.__version__.split(".")[1]) >= 1:
+ from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
else:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
flash_attn_unpadded_func = __flash_attn_unpadded_func
+ flash_attn_func = _flash_attn_func
except ImportError:
logger.warn(
"Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
"https://github.com/Dao-AILab/flash-attention"
)
+def quantize_cache_v(fdata, bits, qmax, qmin):
+ # b, s, head, h-dim->b, head, s, h-dim
+ qtype = torch.uint8
+ device = fdata.device
+ shape = fdata.shape
+
+ fdata_cal = torch.flatten(fdata, 2)
+ fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
+ fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
+ # Compute params
+ if qmax.device != fmax.device:
+ qmax = qmax.to(device)
+ qmin = qmin.to(device)
+ scale = (fmax - fmin) / (qmax - qmin)
+ zero = qmin - fmin / scale
+ scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
+ zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
+ # Quantize
+ res_data = fdata / scale + zero
+ qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
+ return qdata.contiguous(), scale, zero
+
+def dequantize_cache_torch(qdata, scale, zero):
+ data = scale * (qdata - zero)
+ return data
class FlashSelfAttention(torch.nn.Module):
def __init__(
@@ -151,6 +184,12 @@ def forward(self, q, k, v, attention_mask=None):
assert all((i.is_cuda for i in (q, k, v)))
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = k.shape[1]
+ seqlen_out = seqlen_q
+
+ if flash_attn_func is not None and batch_size == 1:
+ dropout_p = self.dropout_p if self.training else 0
+ output = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.softmax_scale, causal=self.causal)
+ return output
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
cu_seqlens_q = torch.arange(
@@ -161,12 +200,13 @@ def forward(self, q, k, v, attention_mask=None):
device=q.device,
)
- if attention_mask is not None:
+ if batch_size > 1 and attention_mask is not None:
k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
- v = v[indices_k]
- if seqlen_q == seqlen_k:
+ if q.size(0) == v.size(0):
q = q[indices_k]
cu_seqlens_q = cu_seqlens_k
+ seqlen_q = seqlen_k
+ v = v[indices_k]
else:
cu_seqlens_k = torch.arange(
0,
@@ -196,8 +236,8 @@ def forward(self, q, k, v, attention_mask=None):
softmax_scale=self.softmax_scale,
causal=is_causal,
)
- if attention_mask is not None and seqlen_q == seqlen_k:
- output = self.pad_input(output, indices_k, batch_size, seqlen_q)
+ if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k:
+ output = self.pad_input(output, indices_k, batch_size, seqlen_out)
else:
new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
output = output.view(new_shape)
@@ -254,99 +294,100 @@ def __init__(self, config):
self.register_buffer("logn_tensor", logn_tensor, persistent=False)
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
-
- def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
- attn_weights = torch.matmul(query, key.transpose(-1, -2))
+ self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False
+ self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
+ self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False
+ cache_dtype = torch.float
+ if self.bf16:
+ cache_dtype=torch.bfloat16
+ elif config.fp16:
+ cache_dtype = torch.float16
+ self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
+ self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
+
+ if config.use_cache_quantization and config.use_cache_kernel:
+ # pre check if the support files existing
+ module_root = pathlib.Path(__file__).parent
+ src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
+ if any(not (module_root/src).is_file() for src in src_files):
+ warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
+ self.cache_kernels = None
+ else:
+ try:
+ from .cpp_kernels import cache_autogptq_cuda_256
+ self.cache_kernels = cache_autogptq_cuda_256
+ except ImportError:
+ warnings.warn("Failed to import KV cache kernels.")
+ self.cache_kernels = None
+
+ def _attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
+ device = query.device
+ if self.use_cache_quantization:
+ qk, qk_scale, qk_zero = key
+ if self.use_cache_kernel and self.cache_kernels is not None:
+ shape = query.shape[:-1] + (qk.shape[-2],)
+ attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
+ self.cache_kernels.vecquant8matmul_batched_faster_old(
+ query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
+ qk.transpose(-1, -2).contiguous(),
+ attn_weights,
+ qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(),
+ qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous())
+ # attn_weights = attn_weights.to(query.dtype).contiguous()
+ else:
+ key = dequantize_cache_torch(qk, qk_scale, qk_zero)
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
+ else:
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights:
- attn_weights = attn_weights / torch.full(
- [],
- value.size(-1) ** 0.5,
- dtype=attn_weights.dtype,
- device=attn_weights.device,
- )
+ if self.use_cache_quantization:
+ size_temp = value[0].size(-1)
+ else:
+ size_temp = value.size(-1)
+ attn_weights = attn_weights / (size_temp ** 0.5)
- query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = registered_causal_mask[
- :, :, key_length - query_length : key_length, :key_length
- ]
mask_value = torch.finfo(attn_weights.dtype).min
- mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
- attn_weights.device
- )
- attn_weights = torch.where(
- causal_mask, attn_weights.to(attn_weights.dtype), mask_value
- )
-
- if attention_mask is not None:
- attn_weights = attn_weights + attention_mask
-
- attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
-
- attn_weights = attn_weights.type(value.dtype)
- attn_weights = self.attn_dropout(attn_weights)
-
- if head_mask is not None:
- attn_weights = attn_weights * head_mask
-
- attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2)
-
- return attn_output, attn_weights
-
- def _upcast_and_reordered_attn(
- self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None
- ):
- bsz, num_heads, q_seq_len, dk = query.size()
- _, _, k_seq_len, _ = key.size()
-
- attn_weights = torch.empty(
- bsz * num_heads,
- q_seq_len,
- k_seq_len,
- dtype=torch.float32,
- device=query.device,
- )
-
- scale_factor = 1.0
- if self.scale_attn_weights:
- scale_factor /= float(value.size(-1)) ** 0.5
-
- with autocast(enabled=False):
- q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
- -1, dk, k_seq_len
- )
- attn_weights = torch.baddbmm(
- attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
+ if causal_mask is not None:
+ attn_weights = torch.where(
+ causal_mask, attn_weights.to(attn_weights.dtype), mask_value
)
- attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
-
- query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = registered_causal_mask[
- :, :, key_length - query_length : key_length, :key_length
- ]
- mask_value = torch.finfo(attn_weights.dtype).min
- mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
- attn_weights.device
- )
- attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+ if self.softmax_in_fp32:
+ attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
+ else:
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
- if attn_weights.dtype != torch.float32:
- raise RuntimeError(
- "Error with upcasting, attn_weights does not have dtype torch.float32"
- )
- attn_weights = attn_weights.type(value.dtype)
+ attn_weights = attn_weights.type(query.dtype)
attn_weights = self.attn_dropout(attn_weights)
if head_mask is not None:
attn_weights = attn_weights * head_mask
- attn_output = torch.matmul(attn_weights, value)
+ if self.use_cache_quantization:
+ qv, qv_scale, qv_zero = value
+ if self.use_cache_kernel and self.cache_kernels is not None:
+ shape = attn_weights.shape[:-1] + (query.shape[-1],)
+ attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
+ self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
+ attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
+ qv.contiguous(), # dtype: int32
+ attn_output,
+ qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(),
+ qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous())
+ if attn_output.dtype != query.dtype:
+ attn_output = attn_output.to(query.dtype)
+ attn_weights = attn_weights.to(query.dtype)
+ else:
+ value = dequantize_cache_torch(qv, qv_scale, qv_zero)
+ attn_output = torch.matmul(attn_weights, value)
+ else:
+ attn_output = torch.matmul(attn_weights, value)
+
+ attn_output = attn_output.transpose(1, 2)
return attn_output, attn_weights
@@ -363,8 +404,7 @@ def _merge_heads(self, tensor, num_heads, attn_head_size):
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
- rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
- registered_causal_mask: Optional[torch.Tensor] = None,
+ rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
@@ -373,7 +413,6 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
):
-
mixed_x_layer = self.c_attn(hidden_states)
query, key, value = mixed_x_layer.split(self.split_size, dim=2)
@@ -405,20 +444,49 @@ def forward(
query = torch.cat(query_list, dim=0)
key = torch.cat(key_list, dim=0)
+ if self.use_cache_quantization:
+ key = quantize_cache_v(key.permute(0, 2, 1, 3),
+ bits=8,
+ qmin=self.cache_qmin,
+ qmax=self.cache_qmax)
+ value = quantize_cache_v(value.permute(0, 2, 1, 3),
+ bits=8,
+ qmin=self.cache_qmin,
+ qmax=self.cache_qmax)
+
+
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
- key = torch.cat((past_key, key), dim=1)
- value = torch.cat((past_value, value), dim=1)
+ if self.use_cache_quantization:
+ # use_cache_quantization:
+ # present=((q_key,key_scale,key_zero_point),
+ # (q_value,value_scale,value_zero_point))
+ key = (torch.cat((past_key[0], key[0]), dim=2),
+ torch.cat((past_key[1], key[1]), dim=2),
+ torch.cat((past_key[2], key[2]), dim=2))
+ value = (torch.cat((past_value[0], value[0]), dim=2),
+ torch.cat((past_value[1], value[1]), dim=2),
+ torch.cat((past_value[2], value[2]), dim=2))
+ else:
+ # not use_cache_quantization:
+ # present=(key,value)
+ key = torch.cat((past_key, key), dim=1)
+ value = torch.cat((past_value, value), dim=1)
if use_cache:
present = (key, value)
else:
present = None
- if self.use_logn_attn and not self.training:
- seq_start = key.size(1) - query.size(1)
- seq_end = key.size(1)
- logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
+ key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
+ if key_size > self.seq_length and self.use_logn_attn and not self.training:
+ if self.use_cache_quantization:
+ seq_start = key[0].size(2) - query.size(1)
+ seq_end = key[0].size(2)
+ else:
+ seq_start = key.size(1) - query.size(1)
+ seq_end = key.size(1)
+ logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
query = query * logn_tensor.expand_as(query)
if (
@@ -428,29 +496,46 @@ def forward(
and query.is_cuda
):
q, k, v = query, key, value
- context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
-
- # b s h d -> b s (h d)
- context_layer = context_layer.flatten(2,3).contiguous()
-
+ attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
else:
+ key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
+ if query.size(1) == key_size:
+ causal_mask = torch.tril(
+ torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
+ ).view(1, 1, key_size, key_size)
+ else:
+ causal_mask = None
query = query.permute(0, 2, 1, 3)
- key = key.permute(0, 2, 1, 3)
- value = value.permute(0, 2, 1, 3)
+ if not self.use_cache_quantization:
+ key = key.permute(0, 2, 1, 3)
+ value = value.permute(0, 2, 1, 3)
if (
- registered_causal_mask is None
+ causal_mask is None
and self.use_flash_attn
and flash_attn_unpadded_func is not None
and not self.is_fp32
and not query.is_cuda
):
raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
- attn_output, attn_weight = self._attn(
- query, key, value, registered_causal_mask, attention_mask, head_mask
- )
- context_layer = self._merge_heads(
- attn_output, self.num_heads, self.head_dim
- )
+
+ if not self.use_cache_quantization and SUPPORT_TORCH2:
+ if attention_mask is not None:
+ attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
+ if causal_mask is not None:
+ attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
+ else:
+ attention_mask = causal_mask
+ attn_output = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask
+ ).transpose(1, 2)
+ attn_weight = None
+ else:
+ attn_output, attn_weight = self._attn(
+ query, key, value, causal_mask, attention_mask, head_mask
+ )
+ context_layer = self._merge_heads(
+ attn_output, self.num_heads, self.head_dim
+ )
attn_output = self.c_proj(context_layer)
@@ -462,6 +547,8 @@ def forward(
and not self.is_fp32
):
raise ValueError("Cannot output attentions while using flash-attn")
+ elif not self.use_cache_quantization and SUPPORT_TORCH2:
+ raise ValueError("Cannot output attentions while using scaled_dot_product_attention")
else:
outputs += (attn_weight,)
@@ -487,6 +574,7 @@ def forward(self, hidden_states):
output = self.c_proj(intermediate_parallel)
return output
+
class QWenBlock(nn.Module):
def __init__(self, config):
super().__init__()
@@ -508,8 +596,7 @@ def __init__(self, config):
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
- rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
- registered_causal_mask: Optional[torch.Tensor] = None,
+ rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
@@ -523,7 +610,6 @@ def forward(
attn_outputs = self.attn(
layernorm_output,
rotary_pos_emb_list,
- registered_causal_mask=registered_causal_mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
@@ -557,6 +643,7 @@ class QWenPreTrainedModel(PreTrainedModel):
is_parallelizable = False
supports_gradient_checkpointing = True
_no_split_modules = ["QWenBlock"]
+ _skip_keys_device_placement = "past_key_values"
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -597,6 +684,7 @@ def __init__(self, config):
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
self.embed_dim = config.hidden_size
+ self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False
self.gradient_checkpointing = False
self.use_dynamic_ntk = config.use_dynamic_ntk
@@ -622,21 +710,6 @@ def __init__(self, config):
self.use_flash_attn = config.use_flash_attn
self.is_fp32 = not (config.bf16 or config.fp16)
- if (
- self.use_flash_attn
- and flash_attn_unpadded_func is not None
- and not self.is_fp32
- ):
- self.registered_causal_mask = None
- else:
- max_positions = config.max_position_embeddings
- self.register_buffer(
- "registered_causal_mask",
- torch.tril(
- torch.ones((max_positions, max_positions), dtype=torch.bool)
- ).view(1, 1, max_positions, max_positions),
- persistent=False,
- )
self.h = nn.ModuleList(
[
@@ -721,8 +794,10 @@ def forward(
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
- past_length = past_key_values[0][0].size(-2)
-
+ if self.use_cache_quantization:
+ past_length = past_key_values[0][0][0].size(2)
+ else:
+ past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
@@ -750,7 +825,10 @@ def forward(
kv_seq_len = hidden_states.size()[1]
if past_key_values[0] is not None:
# past key values[0][0] shape: bs * seq_len * head_num * dim
- kv_seq_len += past_key_values[0][0].shape[1]
+ if self.use_cache_quantization:
+ kv_seq_len += past_key_values[0][0][0].shape[2]
+ else:
+ kv_seq_len += past_key_values[0][0].shape[1]
if self.training or not self.use_dynamic_ntk:
ntk_alpha_list = [1.0]
@@ -768,11 +846,9 @@ def forward(
ntk_alpha = self.get_ntk_alpha(kv_seq_len)
ntk_alpha_list.append(ntk_alpha)
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
-
- rotary_pos_emb_list = []
- for ntk_alpha in ntk_alpha_list:
- rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
- rotary_pos_emb_list.append(rotary_pos_emb)
+ rotary_pos_emb_list = [
+ self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
+ ]
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
@@ -805,7 +881,6 @@ def custom_forward(*inputs):
create_custom_forward(block),
hidden_states,
rotary_pos_emb_list,
- self.registered_causal_mask,
None,
attention_mask,
head_mask[i],
@@ -817,7 +892,6 @@ def custom_forward(*inputs):
hidden_states,
layer_past=layer_past,
rotary_pos_emb_list=rotary_pos_emb_list,
- registered_causal_mask=self.registered_causal_mask,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
@@ -861,11 +935,6 @@ def __init__(self, config):
assert (
config.bf16 + config.fp16 + config.fp32 <= 1
), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
- logger.warn(
- "Warning: please make sure that you are using the latest codes and checkpoints, "
- "especially if you used Qwen-7B before 09.25.2023."
- "请使用最新模型和代码,尤其如果你在9月25日前已经开始使用Qwen-7B,千万注意不要使用错误代码和模型。"
- )
autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
@@ -927,22 +996,13 @@ def set_output_embeddings(self, new_embeddings):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
- token_type_ids = kwargs.get("token_type_ids", None)
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
- if token_type_ids is not None:
- token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
- attention_mask = kwargs.get("attention_mask", None)
- position_ids = kwargs.get("position_ids", None)
-
- if attention_mask is not None and position_ids is None:
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ if input_ids.size(0) == 1:
+ attention_mask = None
else:
- position_ids = None
+ attention_mask = kwargs.get("attention_mask", None)
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
@@ -953,9 +1013,7 @@ def prepare_inputs_for_generation(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
- "position_ids": position_ids,
"attention_mask": attention_mask,
- "token_type_ids": token_type_ids,
}
)
return model_inputs
@@ -1042,7 +1100,6 @@ def chat(
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
- append_history: bool = True,
stream: Optional[bool] = _SENTINEL,
stop_words_ids: Optional[List[List[int]]] = None,
generation_config: Optional[GenerationConfig] = None,
@@ -1054,6 +1111,10 @@ def chat(
assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
if history is None:
history = []
+ else:
+ # make a copy of the user's input such that is is left untouched
+ history = copy.deepcopy(history)
+
if stop_words_ids is None:
stop_words_ids = []
@@ -1091,8 +1152,11 @@ def chat(
errors='replace'
)
- if append_history:
- history.append((query, response))
+ # as history is a copy of the user inputs,
+ # we can always return the new turn to the user.
+ # separating input history and output history also enables the user
+ # to implement more complex history management
+ history.append((query, response))
return response, history
@@ -1220,8 +1284,7 @@ def __init__(self, dim, base=10000):
self._ntk_alpha_cached = 1.0
self._ntk_alpha_cached_list = [1.0]
- def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
- seqlen = max_seq_len + offset
+ def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0):
if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
self.inv_freq = 1.0 / (
@@ -1244,10 +1307,10 @@ def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
cos, sin = emb.cos(), emb.sin()
self._rotary_pos_emb_cache = [cos, sin]
- def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
- self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
+ def forward(self, max_seq_len, ntk_alpha=1.0):
+ self.update_rotary_pos_emb_cache(max_seq_len, ntk_alpha)
cos, sin = self._rotary_pos_emb_cache
- return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]]
+ return [cos[:, :max_seq_len], sin[:, :max_seq_len]]
def _rotate_half(x):
@@ -1259,21 +1322,28 @@ def _rotate_half(x):
def apply_rotary_pos_emb(t, freqs):
+ """ Apply rotary embedding to the first rotary_dim of the iput
+
+ Arguments:
+ t (tensor(batch_size, seq_len, n_head, head_dim)):
+ the input embedding/hidden states
+ freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
+ the cached cos/sin position embeddings
+ """
+ rot_dim = freqs[0].shape[-1]
cos, sin = freqs
+ t_float = t.float()
if apply_rotary_emb_func is not None and t.is_cuda:
- t_ = t.float()
- cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
- sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
- output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
- return output
+ # apply_rotary_emb in flash_attn requires cos/sin to be of
+ # shape (seqlen, rotary_dim / 2) and apply rotary embedding
+ # to the first rotary_dim of the input
+ cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2]
+ sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]
+ return apply_rotary_emb_func(t_float, cos, sin).type_as(t)
else:
- rot_dim = freqs[0].shape[-1]
- cos, sin = freqs
- t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
- t_ = t_.float()
- t_pass_ = t_pass_.float()
- t_ = (t_ * cos) + (_rotate_half(t_) * sin)
- return torch.cat((t_, t_pass_), dim=-1).type_as(t)
+ t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
+ t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
+ return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
class RMSNorm(torch.nn.Module):
diff --git a/mftcoder_accelerate/src/model/qwen/tokenization_qwen.py b/mftcoder_accelerate/src/model/qwen/tokenization_qwen.py
index 4a47c7a..2a526d6 100644
--- a/mftcoder_accelerate/src/model/qwen/tokenization_qwen.py
+++ b/mftcoder_accelerate/src/model/qwen/tokenization_qwen.py
@@ -27,11 +27,22 @@
# regular texts, the surface forms of special tokens need to be
# as different as possible to minimize the impact
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
-SPECIAL_TOKENS = (
- ENDOFTEXT,
- IMSTART,
- IMEND,
-) + EXTRAS
+# changed to use actual index to avoid misconfiguration with vocabulary expansion
+SPECIAL_START_ID = 151643
+SPECIAL_TOKENS = tuple(
+ enumerate(
+ (
+ (
+ ENDOFTEXT,
+ IMSTART,
+ IMEND,
+ )
+ + EXTRAS
+ ),
+ start=SPECIAL_START_ID,
+ )
+)
+SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
@@ -42,6 +53,7 @@ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
for token, rank in (line.split() for line in contents.splitlines() if line)
}
+
class QWenTokenizer(PreTrainedTokenizer):
"""QWen tokenizer."""
@@ -51,20 +63,35 @@ def __init__(
self,
vocab_file,
errors="replace",
+ extra_vocab_file=None,
**kwargs,
):
super().__init__(**kwargs)
- self.errors = errors # how to handle errors in decoding
+ # how to handle errors in decoding UTF-8 byte sequences
+ # use ignore if you are in streaming inference
+ self.errors = errors
- self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int]
self.special_tokens = {
token: index
- for index, token in enumerate(
- SPECIAL_TOKENS, start=len(self.mergeable_ranks)
- )
+ for index, token in SPECIAL_TOKENS
}
+ # try load extra vocab from file
+ if extra_vocab_file is not None:
+ used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
+ extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
+ for token, index in extra_mergeable_ranks.items():
+ if token in self.mergeable_ranks:
+ logger.info(f"extra token {token} exists, skipping")
+ continue
+ if index in used_ids:
+ logger.info(f'the index {index} for extra token {token} exists, skipping')
+ continue
+ self.mergeable_ranks[token] = index
+ # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this
+
enc = tiktoken.Encoding(
"Qwen",
pat_str=PAT_STR,
@@ -89,7 +116,7 @@ def __init__(
def __getstate__(self):
# for pickle lovers
state = self.__dict__.copy()
- del state['tokenizer']
+ del state["tokenizer"]
return state
def __setstate__(self, state):
@@ -103,7 +130,6 @@ def __setstate__(self, state):
)
self.tokenizer = enc
-
def __len__(self) -> int:
return self.tokenizer.n_vocab
@@ -126,13 +152,17 @@ def convert_tokens_to_ids(
ids.append(self.mergeable_ranks.get(token))
return ids
- def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
+ def _add_tokens(
+ self,
+ new_tokens: Union[List[str], List[AddedToken]],
+ special_tokens: bool = False,
+ ) -> int:
if not special_tokens and new_tokens:
- raise ValueError('Adding regular tokens is not supported')
+ raise ValueError("Adding regular tokens is not supported")
for token in new_tokens:
surface_form = token.content if isinstance(token, AddedToken) else token
- if surface_form not in SPECIAL_TOKENS:
- raise ValueError('Adding unknown special tokens is not supported')
+ if surface_form not in SPECIAL_TOKENS_SET:
+ raise ValueError("Adding unknown special tokens is not supported")
return 0
def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
diff --git a/mftcoder_accelerate/src/model/qwen/tokenizer_config.json b/mftcoder_accelerate/src/model/qwen/tokenizer_config.json
new file mode 100644
index 0000000..9c37cac
--- /dev/null
+++ b/mftcoder_accelerate/src/model/qwen/tokenizer_config.json
@@ -0,0 +1,10 @@
+{
+ "model_max_length": 8192,
+ "tokenizer_class": "QWenTokenizer",
+ "auto_map": {
+ "AutoTokenizer": [
+ "tokenization_qwen.QWenTokenizer",
+ null
+ ]
+ }
+}
diff --git a/mftcoder_accelerate/src/pefts/arguments.py b/mftcoder_accelerate/src/pefts/arguments.py
index 5317e5b..1403c4f 100644
--- a/mftcoder_accelerate/src/pefts/arguments.py
+++ b/mftcoder_accelerate/src/pefts/arguments.py
@@ -51,10 +51,10 @@ class TrainArgs:
weighted_loss_mode: str = "case3"
# lora or qlora or None(for full-parameter training)
- peft_type: str = "qlora"
+ peft_type: Union[None, str] = "qlora"
- # if qlora, 4bit or 8bit, 4bit is suggested
- quantization: str = "4bit"
+ # if qlora, 4bit will be set, else None
+ quantization: Union[None, str] = "4bit"
# lora rank, the bigger, the more trainalbe parameters
lora_rank: int = 96
@@ -66,7 +66,7 @@ class TrainArgs:
lora_dropout: float = 0.05
# lora targeting modules
- target_modules: Union[None, List[str]] = None
+ target_modules: Union[None, str, List[str]] = None
# mircro train batch size
per_device_train_batch_size: int = 8
@@ -84,7 +84,7 @@ class TrainArgs:
min_lr: float = 5e-6
# weight decay
- weight_decay: float = 0.1
+ weight_decay: float = 0.01
# gradient_accumulation_steps
gradient_accumulation_steps: int = 1
@@ -107,6 +107,9 @@ class TrainArgs:
# path of adaptor which is resumed from, None for not resuming training
resume_from_checkpoint: Union[None, str] = None
+ # auto resume from latest ckpt if job restarted
+ auto_resume: bool = True
+
# num of steps for logging training loss
log_interval: int = 10
@@ -128,14 +131,14 @@ class TrainArgs:
# DDP random sampler
use_random_sampler: bool = True
- # if early stop when eval loss is not converging in the past early_stopping_stall_num evaluation point
+ # if early stop when eval loss is not converging in the past early_stopping_stall_num evaluation point
early_stopping: bool = True
early_stopping_stall_num: int = 5
# limit num for saving ckpts, None for no limits. Used for full-parameter training to avoid exceeding disk quota.
saving_limit: Union[None, int] = None
- # if dynamic padding
+ # if dynamic padding
use_dynamic_padding: bool = True
# interval of update per task train weight in selfpaced
@@ -154,7 +157,7 @@ class TrainArgs:
# role_markers: {"system": "### System:\n", "user": "### Instruction:\n", "assistant": "### Response:\n"}
role_markers: Union[None, dict] = None
- distributed_type: Union[None, str] = "deepspeed"
+ distributed_type: Union[None, str] = None
# legacy, leave them
use_xformers: bool = True
trust_remote_code: bool = True
diff --git a/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py b/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py
index bf8434f..1f2ff59 100644
--- a/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py
+++ b/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py
@@ -16,13 +16,14 @@
from peft import LoraConfig, get_peft_model
from peft import PeftModel
-# insert src as import path
+# insert src as import path
current_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(os.path.dirname(current_path))
sys.path.insert(0, parent_dir)
print("In merge_base_and_lora_to_hf.py, sys path:", sys.path)
from pefts.model_mapping import MODEL_SPECIAL_TOKENS
+from tokenizer.chat_template import MFTCoder_template
def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str):
@@ -60,6 +61,7 @@ def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str):
t0 = time.time()
config = {"model_type": model_type}
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ tokenizer.chat_template = MFTCoder_template
base_model = AutoModelForCausalLM.from_pretrained(
model_path,
diff --git a/mftcoder_accelerate/src/pefts/mft_accelerate.py b/mftcoder_accelerate/src/pefts/mft_accelerate.py
index 8d3eb18..acfb422 100644
--- a/mftcoder_accelerate/src/pefts/mft_accelerate.py
+++ b/mftcoder_accelerate/src/pefts/mft_accelerate.py
@@ -1,15 +1,15 @@
"""
# @author Chaoyu Chen
-# @date 2023/12/11
+# @date 2024/5/20
# @module mft_accelerate.py
-Accelerate + DeepSpeed/FSDP
-QLoRA/LoRA/Full + MFT/MPT, accurate and efficient training
+Accelerate + DeepSpeed zero2/zero3/FSDP + Data Parallelism
+QLoRA/LoRA/Full + MFT/MPT, resource and parameters efficient training
Entry
"""
-import gc
+
import os
import sys
import argparse
@@ -17,6 +17,7 @@
import logging
import json
import time
+from tqdm.auto import tqdm
import transformers
import numpy as np
import torch
@@ -26,7 +27,7 @@
import datasets
from torch.utils.data import DataLoader
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
-from tqdm.auto import tqdm
+
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@@ -40,10 +41,10 @@
LoraConfig,
TaskType,
get_peft_model,
- prepare_model_for_kbit_training,
+ # prepare_model_for_kbit_training,
PeftModel,
)
-from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin
+from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration
from accelerate.logging import get_logger
# insert src as import path
@@ -56,13 +57,28 @@
from data.multi_task_dataset import load_dataset_from_jsonl, compile_helper
from data.data_utils import load_dataset_from_bin
from utils.common_utils import print_rank_0, generate_task_id, TASK2ID, ID2TASK
-from pefts.train_utils import accelerate_train
+
+from pefts.trainer import MftTrainer
from pefts.arguments import TrainArgs
-from pefts.model_mapping import MODEL_TYPES, FULL_LORA_TARGETING_MODULES, MODEL_SPECIAL_TOKENS
+from pefts.model_mapping import MODEL_TYPES, FULL_LORA_TARGETING_MODULES, MODEL_SPECIAL_TOKENS, CUSTOMIZE
+
logger = get_logger(__name__)
-SUPPORT_FA2_IN_TRANSFORMERS = ["code_llama", "llama", "deepseek", "mistral", "mixtral", "gpt_neox", "phi", "starcoder"]
+SUPPORT_FA2_IN_TRANSFORMERS = [
+ "code_llama",
+ "llama",
+ "deepseek",
+ "mistral",
+ "mixtral",
+ "gpt_neox",
+ "phi",
+ "starcoder",
+ "qwen2",
+ "qwen2_moe",
+ "gemma",
+ "starcoder2"
+]
def get_task_mask(args, task_id):
@@ -93,58 +109,112 @@ class DataCollatorForMFTDataset(object):
args: None
def __call__(self, instances):
- input_ids, loss_mask, weights, task_id = tuple(
- [instance[key] if key in instance else None for instance in instances] for key in
- ("input_ids", "loss_mask", "weight", "task_id"))
+ (input_ids, loss_mask, weights, task_id) = tuple(
+ [instance.get(key, None) for instance in instances]
+ for key in ("input_ids", "loss_mask", "weight", "task_id")
+ )
result_batch = {}
- '''
+ """
outputs = model(
- input_ids=batch['input_ids'],
- attention_mask=batch['attention_mask'],
- # labels=(batch['labels'], batch['loss_mask'], batch['task_mask']),
- # labels=(batch['labels'], batch['loss_mask']),
- position_ids=batch['position_ids'],
- )
- '''
+ input_ids=batch['input_ids'],
+ attention_mask=batch['attention_mask'],
+ # labels=(batch['labels'], batch['loss_mask'], batch['task_mask']),
+ # labels=(batch['labels'], batch['loss_mask']),
+ position_ids=batch['position_ids'])
+ """
# if loss_mask is not None:
loss_mask = torch.tensor(np.array(loss_mask)).long()
+ last_one_pos = (loss_mask == 1).long().cumsum(dim=1).argmax(dim=1)
if self.args.use_dynamic_padding:
- last_one_pos = (loss_mask == 1).long().cumsum(dim=1).argmax(dim=1)
# get last non-padding position
max_pos = last_one_pos.max().item() + 1
else:
max_pos = loss_mask.shape[-1]
- result_batch['loss_mask'] = loss_mask.float()[:, 1:max_pos].contiguous()
- input_ids = torch.tensor(np.array(input_ids)).long()
- # print(f"shape of input_ids: {input_ids.shape}")
- result_batch['input_ids'] = input_ids[:, :max_pos - 1].contiguous()
- result_batch['labels'] = input_ids[:, 1:max_pos].contiguous()
+ if self.args.tokenize_mode == "sst" and self.args.padding_mode == "pack":
+ # sst + pack tokenization, remove last dirty data
+ result_batch["loss_mask"] = loss_mask.float()[:, 1 : max_pos - 1].contiguous()
+ input_ids = torch.tensor(np.array(input_ids)).long()
+ result_batch["input_ids"] = input_ids[:, : max_pos - 2].contiguous()
+ result_batch["labels"] = input_ids[:, 1 : max_pos - 1].contiguous()
+ else:
+ result_batch["loss_mask"] = loss_mask.float()[:, 1:max_pos].contiguous()
+ input_ids = torch.tensor(np.array(input_ids)).long()
+ # print(f"shape of input_ids: {input_ids.shape}")
+ result_batch["input_ids"] = input_ids[:, : max_pos - 1].contiguous()
+ result_batch["labels"] = input_ids[:, 1:max_pos].contiguous()
# Get the masks and position ids.
- # For decoder-only models, attention_mask and position_ids should be None and transformers will create them.
- result_batch['attention_mask'], result_batch['position_ids'] = None, None
-
- # if you want to be compatible with non-gpt(non-causal)models, something you can do here
- # result_batch['attention_mask'], result_batch['position_ids'] = get_attention_mask_and_position_ids(data=result_batch['input_ids'])
+ if self.args.model_type in ["mixtral", "qwen2_moe"]:
+ batch_size, seq_length = result_batch["input_ids"].shape
+ # bsz * seq_length
+ range_tensor = torch.arange(seq_length).unsqueeze(0).repeat(batch_size, 1)
+ # attention_mask for padding tokens
+ attention_mask = (range_tensor <= last_one_pos.reshape(batch_size, 1)).long()
+ result_batch["attention_mask"], result_batch["position_ids"] = attention_mask, None
+ else:
+ # For decoder-only models, transformers will create them.
+ result_batch["attention_mask"], result_batch["position_ids"] = None, None
if task_id is not None:
task_id = torch.tensor(np.array(task_id))
- result_batch['task_mask'] = get_task_mask(self.args, task_id) # bsz * task_num
- result_batch['task_id'] = task_id
+ result_batch["task_mask"] = get_task_mask(self.args, task_id) # bsz * task_num
+ result_batch["task_id"] = task_id
return result_batch
+def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
+ """
+ This method wraps the entire protocol for preparing a model before running a training.
+ This includes:
+ 1- Cast the layernorm in fp32
+ 2- making output embedding layer require grads
+ 3- Add the upcasting of the lm head to fp32
+
+ Args:
+ model, (`transformers.PreTrainedModel`):
+ The loaded model from `transformers`
+ """
+ loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
+
+ is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
+ for name, param in model.named_parameters():
+ # freeze base model's layers
+ param.requires_grad = False
+
+ if not is_gptq_quantized:
+ # cast all non INT8 parameters to fp32
+ for param in model.parameters():
+ if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+ param.data = param.data.to(torch.float32)
+
+ if (loaded_in_kbit or is_gptq_quantized) and use_gradient_checkpointing:
+ # For backward compatibility
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ # enable gradient checkpointing for memory efficiency
+ model.gradient_checkpointing_enable()
+
+ return model
+
+
def pprint_args(args, accelerator):
# 计算所有键的最大字符串长度
max_key_length = max(len(str(key)) for key in vars(args).keys())
message = ""
message += "====" * 60 + "\n"
- message += '\n'.join([f'{k:<{max_key_length}} : {v}' for k, v in vars(args).items()]) + "\n"
+ message += "\n".join([f"{k:<{max_key_length}} : {v}" for k, v in vars(args).items()]) + "\n"
message += "====" * 60 + "\n"
accelerator.print(message)
accelerator.print("GPU: {}".format(torch.cuda.current_device()))
@@ -164,7 +234,7 @@ def prepare_args():
parsed = parser.parse_args()
# get json configs
- with open(parsed.train_config, 'r') as f:
+ with open(parsed.train_config, "r") as f:
train_config = json.load(f)
# parse args from cofig.json
@@ -190,26 +260,27 @@ def prepare_args():
args.distributed_type = parsed.distributed_type
# refactor args
- args.eos_token = MODEL_SPECIAL_TOKENS[args.model_type]['eos_token']
- args.pad_token = MODEL_SPECIAL_TOKENS[args.model_type]['pad_token']
+ args.eos_token = MODEL_SPECIAL_TOKENS[args.model_type]["eos_token"]
+ args.pad_token = MODEL_SPECIAL_TOKENS[args.model_type]["pad_token"]
- if args.peft_type == 'qlora' and args.quantization != '4bit' and args.quantization != '8bit':
- print(f"[WARNING]peft_type is qlora but quantization is not 4bit or 8bit, setting it to 4bit")
- args.quantization = '4bit'
+ if args.peft_type == "qlora":
+ print_rank_0(f"[INFO] args.peft_type is set 'qlora', setting quantization to '4bit'")
+ args.quantization = "4bit"
+ else:
+ args.quantization = None
args.vocab_file = args.pretrained_model_path
- args.data_weights = "[" + ",".join(["1."] * len(args.data_paths[1:-1].split(','))) + "]"
+ args.data_weights = "[" + ",".join(["1."] * len(args.data_paths[1:-1].split(","))) + "]"
# generate TASK2ID, ID2TASK
generate_task_id(args.data_paths)
- if args.weighted_loss_mode == 'selfpaced':
+ if args.weighted_loss_mode == "selfpaced":
args.task_weights = [1.0] * len(ID2TASK)
elif args.task_weights is not None:
args.task_weights = [float(wt) for wt in args.task_weights[1:-1].split(",")]
- assert len(args.task_weights) == len(
- ID2TASK), f"length of task_weights, is not equal to the length of data_paths"
+ assert len(args.task_weights) == len(ID2TASK), f"length of task_weights must equal to length of data_paths"
else:
args.task_weights = [1.0] * len(ID2TASK)
@@ -219,11 +290,14 @@ def prepare_args():
def main():
t0 = time.time()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
- print(f"transformers.__version__: {transformers.__version__}")
-
+ os.environ["HF_HUB_OFFLINE"] = "false"
# get input args, set TASK2ID, ID2TASK, refactor args
args = prepare_args()
+ # fix randomness
+ if args.seed is not None:
+ set_seed(args.seed)
+
# define accelerator
if args.distributed_type and args.distributed_type.lower() == "fsdp":
fsdp_plugin = FullyShardedDataParallelPlugin(
@@ -231,19 +305,26 @@ def main():
# optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
limit_all_gathers=True,
sync_module_states=True,
- cpu_offload=False
+ use_orig_params=True,
+ cpu_offload=False,
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps, fsdp_plugin=fsdp_plugin,
+ dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True),
)
- accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, fsdp_plugin=fsdp_plugin)
else:
- accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True),
+ )
+
+ # print key infos
+ accelerator.print("In mft_accelerate.py, sys path:", sys.path)
+ accelerator.print(f"transformers.__version__: {transformers.__version__}")
# get world_size
args.world_size = accelerator.num_processes
- # fix randomness
- if args.seed is not None:
- set_seed(args.seed)
-
# backup args
pprint_args(args, accelerator)
if accelerator.is_main_process:
@@ -252,9 +333,31 @@ def main():
with open(os.path.join(args.output_dir, "args.json"), "w") as f:
json.dump(args.dict(), f, indent=2)
+ # deal with autoresume, args.resume_from_checkpoint prior to auto_resume from latest
+ latest = None
+ if os.path.exists(os.path.join(args.output_dir, "latest")):
+ with open(os.path.join(args.output_dir, "latest"), "r") as fl:
+ latest = json.load(fl)
+ accelerator.print(f"[INFO] Existing latest: {latest}")
+
+ if args.auto_resume and args.resume_from_checkpoint is None and latest:
+ if args.peft_type:
+ args.resume_from_checkpoint = latest["latest_ckpt"]
+ else:
+ args.resume_from_checkpoint = latest["latest_ckpt"]
+ args.pretrained_model_path = args.resume_from_checkpoint
+ args.learning_rate = latest["lr"]
+ elif args.resume_from_checkpoint and (not args.peft_type):
+ args.pretrained_model_path = args.resume_from_checkpoint
+
+ # if latest:
+ # scheduler_last_ep = latest["scheduler_last_ep"]
+ # else:
+ # scheduler_last_ep = -1
+
# logger
logging.basicConfig(
- format="[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
+ format="[%(asctime)s][%(levelname)s][%(name)s]%(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
@@ -272,35 +375,37 @@ def main():
# get global_rank and local rank for current process
global_rank = accelerator.process_index
local_rank = accelerator.local_process_index
- print(f'world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}')
+ print(f"world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}")
# TASK2ID, ID2TASK
# generate_task_id(args.data_paths)
# multi task blendable dataset(sharded)
if args.load_raw_dataset:
- print_rank_0('> load raw jsonl dataset')
+ print_rank_0("> load raw jsonl dataset")
train_dataset, valid_dataset = load_dataset_from_jsonl(
- args=args,
- shard_data=True,
- world_size=args.world_size,
- global_rank=global_rank,
- local_rank=local_rank
+ args=args, shard_data=True, world_size=args.world_size, global_rank=global_rank, local_rank=local_rank
)
else:
- print_rank_0('> load tokenized bin dataset, refer to gpt_neox indexed dataset')
+ print_rank_0("> load tokenized bin dataset, refer to gpt_neox indexed dataset")
train_dataset, valid_dataset, _ = load_dataset_from_bin(args=args)
t1 = time.time()
logger.info(f"dataset loading time: {t1 - t0:.4f}")
# cuda memory
- free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024 ** 3)
+ free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
max_memory = f"{free_in_GB - 2}GB"
n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}
accelerator.print("max memory: ", max_memory, n_gpus)
+ # target_modules
+ if args.target_modules:
+ target_modules = args.target_modules
+ else:
+ target_modules = FULL_LORA_TARGETING_MODULES[args.model_type]
+
# peft config
if args.peft_type:
peft_config = LoraConfig(
@@ -309,80 +414,82 @@ def main():
r=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
- target_modules=args.target_modules if args.target_modules else FULL_LORA_TARGETING_MODULES[args.model_type]
+ target_modules=target_modules,
+ bias="lora_only",
)
- # # 是否要加入新的special tokens
+ # new special tokens
# num_added_toks = tokenizer.tokenizer.add_special_tokens(["", ""])
# accelerator.print("We have added", num_added_toks, "tokens")
# accelerator.print(f"role marker tokens {tokenizer.convert_tokens_to_ids('')} {tokenizer.convert_tokens_to_ids('')}, resized tokenizer_size: {len(tokenizer)}")
# creating base model
ModelClass = MODEL_TYPES[args.model_type]
- if args.model_type in SUPPORT_FA2_IN_TRANSFORMERS:
- accelerator.print(f"[INFO] Model Type {args.model_type} is supported FA2 by Transformers and we use it")
+ if args.model_type in SUPPORT_FA2_IN_TRANSFORMERS and not CUSTOMIZE:
+ accelerator.print(f"[INFO] Model Type {args.model_type} " f"is supported FA2 by Transformers and we use it")
model = ModelClass.from_pretrained(
args.pretrained_model_path,
attn_implementation=args.attn_implementation,
- # trust_remote_code=True,
- load_in_8bit=(args.quantization == '8bit'),
- load_in_4bit=(args.quantization == '4bit'),
torch_dtype=torch.bfloat16,
- # low_cpu_mem_usage=args.low_cpu_mem_usage, # not for zero3
- # use_safetensors=False,
quantization_config=BitsAndBytesConfig(
- load_in_4bit=(args.quantization == '4bit'),
+ load_in_4bit=(args.quantization == "4bit"),
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
- ) if args.quantization == '4bit' else None,
+ bnb_4bit_quant_storage=torch.bfloat16,
+ )
+ if args.quantization == "4bit"
+ else None,
)
else:
- accelerator.print(f"[INFO] Model Type {args.model_type} is NOT supported officially by Transformers "
- f"and we use published modeling_xxx.py(may be modified by us)")
+ accelerator.print(
+ f"[INFO] Model Type {args.model_type} "
+ f"is NOT supported officially by Transformers "
+ f"and we use published modeling_xxx.py(may be modified by us)"
+ )
model = ModelClass.from_pretrained(
args.pretrained_model_path,
- load_in_8bit=(args.quantization == '8bit'),
- load_in_4bit=(args.quantization == '4bit'),
torch_dtype=torch.bfloat16,
quantization_config=BitsAndBytesConfig(
- load_in_4bit=(args.quantization == '4bit'),
+ load_in_4bit=(args.quantization == "4bit"),
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
- ) if args.quantization == '4bit' else None,
+ bnb_4bit_quant_storage=torch.bfloat16,
+ )
+ if args.quantization == "4bit"
+ else None,
)
# build a tokenizer for possible resizing or saving
tokenizer = build_tokenizer(args)
# Note: resize_token_embeddings expects to receive the full size of the new vocabulary,
# i.e. the length of the tokenizer.
- # 如果新增special tokens, 需要resize input embedding 和output embedding
# model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
- accelerator.print("load in 8bit: ", args.quantization == '8bit')
- accelerator.print("load in 4bit: ", args.quantization == '4bit')
- if args.peft_type:
- if args.peft_type == 'lora':
- model.gradient_checkpointing_enable()
- # args.saving_limit = None
-
- elif args.peft_type == 'qlora':
- # prepare base model for 8bit or 4bit model(cast non-8bit or non-4bit layers to fp32)
- model = prepare_model_for_kbit_training(model)
- logging.info(f"device map: {model.hf_device_map}")
- # args.saving_limit = None
+ accelerator.print("Model load_in_4bit: ", args.quantization == "4bit")
+
+ if args.peft_type == "lora":
+ model.gradient_checkpointing_enable()
+ elif args.peft_type == "qlora":
+ # prepare base model for 4bit model(cast non-4bit layers to fp32)
+ model = prepare_model_for_kbit_training(model)
+ # logging.info(f"device map: {model.hf_device_map}")
else:
model.gradient_checkpointing_enable()
- assert (args.saving_limit is not None and isinstance(args.saving_limit, int)), "saving_limit must be a integer in Full Training"
+ if args.saving_limit is None or not isinstance(args.saving_limit, int) or args.saving_limit < 1:
+ # saving_limit is set automatically if needed
+ args.saving_limit = 2
+ accelerator.print(
+ "[WARNING]saving_limit must be a integer greater than 1 in Full-Parameters Training, we set it to 2"
+ )
- # Potentially load in the lora from a previous save
+ # Load PeftModel from a previous save or create a new PeftModel
if args.peft_type:
if not args.resume_from_checkpoint:
model = get_peft_model(model, peft_config)
else:
-
- accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
+ accelerator.print(f"[INFO] Resumed from checkpoint: {args.resume_from_checkpoint}")
# accelerator.load_state(args.resume_from_checkpoint)
model = PeftModel.from_pretrained(model, args.resume_from_checkpoint, is_trainable=True)
@@ -394,38 +501,52 @@ def main():
model.config.use_cache = False # silence the warnings. Please re-enable for inference!
model.config.use_logn_attn = False # special for qwen model
+ # load balance for moe training
+ if hasattr(model.config, "output_router_logits"):
+ model.config.output_router_logits = True
+ model_config = model.config
accelerator.print(model.config)
# dataloader
train_dataloader = DataLoader(
- train_dataset, shuffle=True, collate_fn=DataCollatorForMFTDataset(args),
- batch_size=args.per_device_train_batch_size, pin_memory=True, drop_last=True
+ train_dataset,
+ shuffle=True,
+ collate_fn=DataCollatorForMFTDataset(args),
+ batch_size=args.per_device_train_batch_size,
+ pin_memory=True,
+ drop_last=True,
)
valid_dataloader = DataLoader(
- valid_dataset, collate_fn=DataCollatorForMFTDataset(args), batch_size=args.per_device_eval_batch_size,
- pin_memory=True, drop_last=True
+ valid_dataset,
+ collate_fn=DataCollatorForMFTDataset(args),
+ batch_size=args.per_device_eval_batch_size,
+ pin_memory=True,
+ drop_last=True,
)
if accelerator.distributed_type == DistributedType.DEEPSPEED:
accelerator.print("DISTRIBUTED TRAINING USING DEEPSPEED")
- from deepspeed.ops.adam import FusedAdam as Adam
- adam_optimizer = Adam
+ # from deepspeed.ops.adam import FusedAdam as Adam
+ # adam_optimizer = Adam
+ adam_optimizer = torch.optim.AdamW
elif accelerator.distributed_type == DistributedType.FSDP:
accelerator.print("DISTRIBUTED TRAINING USING FSDP")
if args.peft_type and getattr(accelerator.state, "fsdp_plugin", None) is not None:
from peft.utils.other import fsdp_auto_wrap_policy
+
accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)
model = accelerator.prepare(model)
adam_optimizer = torch.optim.AdamW
else:
- accelerator.print(f"DISTRIBUTED TRAINING USING {accelerator.distributed_type}")
- adam_optimizer = torch.optim.AdamW
+ raise ValueError("Only support DeepSpeed and FSDP")
optimizer = adam_optimizer(
model.parameters(),
weight_decay=args.weight_decay,
lr=args.learning_rate,
- betas=(0.9, 0.95),
+ betas=(0.9, 0.999),
)
+ # for group in optimizer.param_groups:
+ # group.setdefault("initial_lr", group["lr"])
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -437,27 +558,25 @@ def main():
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
- num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
+ num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ # scheduler_specific_kwargs={"last_epoch": scheduler_last_ep}
)
+ # prepare all
if accelerator.distributed_type == DistributedType.DEEPSPEED:
- model, train_dataloader, valid_dataloader, optimizer, lr_scheduler = accelerator.prepare(
+ (model, train_dataloader, valid_dataloader, optimizer, lr_scheduler) = accelerator.prepare(
model, train_dataloader, valid_dataloader, optimizer, lr_scheduler
)
+ # prepare all except model, which is prepared before
elif accelerator.distributed_type == DistributedType.FSDP:
- optimizer, train_dataloader, valid_dataloader, lr_scheduler = accelerator.prepare(
+ (optimizer, train_dataloader, valid_dataloader, lr_scheduler) = accelerator.prepare(
optimizer, train_dataloader, valid_dataloader, lr_scheduler
)
- else:
- # may be not suitable for all DistributedType, expected to be ok with simple multi-gpu
- model, train_dataloader, valid_dataloader, optimizer, lr_scheduler = accelerator.prepare(
- model, train_dataloader, valid_dataloader, optimizer, lr_scheduler
- )
print(model.device)
accelerator.print(model)
# accelerator.print(model.config)
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ # Recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
@@ -471,18 +590,21 @@ def main():
accelerator.print(f"DEEPSPEED plugin: {accelerator.state.deepspeed_plugin}")
elif getattr(accelerator.state, "fsdp_plugin", None):
accelerator.print(f"FSDP plugin: {accelerator.state.fsdp_plugin}")
-
- # Train!
- accelerate_train(accelerator,
- model,
- train_dataloader,
- valid_dataloader,
- optimizer,
- lr_scheduler,
- tokenizer,
- num_update_steps_per_epoch,
- len(train_dataset),
- args)
+
+ trainer = MftTrainer(
+ accelerator=accelerator,
+ model=model,
+ model_config=model_config,
+ train_dataloader=train_dataloader,
+ valid_dataloader=valid_dataloader,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ tokenizer=tokenizer,
+ num_update_steps_per_epoch=num_update_steps_per_epoch,
+ total_train_dataset_size=len(train_dataset),
+ args=args,
+ )
+ trainer.accelerate_train()
if __name__ == "__main__":
diff --git a/mftcoder_accelerate/src/pefts/model_mapping.py b/mftcoder_accelerate/src/pefts/model_mapping.py
index 0474a6d..7824e8f 100644
--- a/mftcoder_accelerate/src/pefts/model_mapping.py
+++ b/mftcoder_accelerate/src/pefts/model_mapping.py
@@ -1,47 +1,63 @@
"""
# @author Chaoyu Chen
- # @date 2023/12/11
+ # @date 2024/5/20
Manage supported models and their special token used in training.
Default targeting modules for LoRA/QLora
- 4.36 is stable now
+ 4.40 is stable now
"""
-# Models that Transformers support FA2
+
+# Models that have both cutomized modeling and Transformers modeling
+
+CUSTOMIZE = False
+if CUSTOMIZE:
+ from model.code_llama.modeling_llama import LlamaForCausalLM
+ from model.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
+ from model.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM
+else:
+ from transformers import (
+ GPTNeoXForCausalLM,
+ GPTBigCodeForCausalLM,
+ LlamaForCausalLM,
+ )
+
+# Models that Transformers support Code and FA2 when flash_attn>=2.1.0
from transformers import (
- AutoConfig,
- AutoTokenizer,
- AutoModelForCausalLM,
- GPTNeoXForCausalLM,
- GPTBigCodeForCausalLM,
- LlamaForCausalLM,
MistralForCausalLM,
MixtralForCausalLM,
PhiForCausalLM,
+ GemmaForCausalLM,
+ Qwen2ForCausalLM,
+ Qwen2MoeForCausalLM,
+ Starcoder2ForCausalLM,
)
-
-# Models that Transformers not support FA2, supported by publisher or ourself
+# Models that Code from "remote_code"
from model.aquila2.modeling_aquila import AquilaForCausalLM
from model.baichuan2.modeling_baichuan import BaichuanForCausalLM
from model.qwen.modeling_qwen import QWenLMHeadModel
from model.chatglm2.modeling_chatglm import ChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration2
from model.chatglm3.modeling_chatglm import ChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration3
-
# from model.phi.modeling_mixformer_sequential import MixFormerSequentialForCausalLM
+
MODEL_TYPES = {
"aquila2": AquilaForCausalLM,
"baichuan": BaichuanForCausalLM,
- 'chatglm2': ChatGLMForConditionalGeneration2,
- 'chatglm3': ChatGLMForConditionalGeneration3,
+ "chatglm2": ChatGLMForConditionalGeneration2,
+ "chatglm3": ChatGLMForConditionalGeneration3,
"code_llama": LlamaForCausalLM,
"deepseek": LlamaForCausalLM,
"gpt_neox": GPTNeoXForCausalLM,
"llama": LlamaForCausalLM,
"mistral": MistralForCausalLM,
"mixtral": MixtralForCausalLM,
- 'phi': PhiForCausalLM,
- 'qwen': QWenLMHeadModel,
+ "phi": PhiForCausalLM,
+ "qwen": QWenLMHeadModel,
"starcoder": GPTBigCodeForCausalLM,
+ "qwen2": Qwen2ForCausalLM,
+ "gemma": GemmaForCausalLM,
+ "qwen2_moe": Qwen2MoeForCausalLM,
+ "starcoder2": Starcoder2ForCausalLM,
}
FULL_LORA_TARGETING_MODULES = {
@@ -51,91 +67,87 @@
"chatglm3": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
"deepseek": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
"code_llama": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
- "gpt_neox": ["query_key_value", 'dense', 'dense_h_to_4h', 'dense_4h_to_h'],
+ "gpt_neox": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
"llama": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
"mistral": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
- "mixtral": ["q_proj", "k_proj", "v_proj", "o_proj"],
- "phi": ["query_key_value", 'dense', 'fc1', 'fc2'],
+ "mixtral": ["q_proj", "k_proj", "v_proj", "o_proj", "gate", "w1", "w2", "w3"],
+ "phi": ["query_key_value", "dense", "fc1", "fc2"],
"qwen": ["c_proj", "c_attn", "w1", "w2"],
"starcoder": ["c_proj", "c_attn", "q_attn", "c_fc"],
+ "gemma": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
+ "qwen2": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
+ "qwen2_moe": "all-linear",
+ "starcoder2": "all-linear",
}
+
MODEL_SPECIAL_TOKENS = {
"gpt_neox": {
-
"eos_token": "<|endoftext|>",
"pad_token": "<|pad|>",
-
},
"llama": {
-
"eos_token": "",
"pad_token": "",
-
},
"code_llama": {
-
"eos_token": "",
"pad_token": "",
-
},
"baichuan": {
-
"eos_token": "",
"pad_token": "",
-
},
"starcoder": {
-
"eos_token": "<|endoftext|>",
"pad_token": "",
-
},
"qwen": {
-
"eos_token": "<|endoftext|>",
"pad_token": "<|extra_1|>",
-
},
"chatglm2": {
-
"eos_token": "",
"pad_token": "",
-
},
"chatglm3": {
-
"eos_token": "",
"pad_token": "",
-
},
"phi": {
-
"eos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",
},
"aquila": {
-
"eos_token": "",
"pad_token": "<|endoftext|>",
-
},
"deepseek": {
-
"eos_token": "<|end▁of▁sentence|>",
"pad_token": "<|end▁of▁sentence|>",
-
},
"mixtral": {
-
"eos_token": "",
"pad_token": "",
-
},
"mistral": {
-
"eos_token": "",
"pad_token": "",
-
+ },
+ "qwen2": {
+ "eos_token": "<|endoftext|>",
+ "pad_token": "<|endoftext|>",
+ },
+ "gemma": {
+ "eos_token": "",
+ "pad_token": "",
+ },
+ "qwen2_moe": {
+ "eos_token": "<|endoftext|>",
+ "pad_token": "<|endoftext|>",
+ },
+ "starcoder2": {
+ "eos_token": "<|endoftext|>",
+ "pad_token": "<|endoftext|>",
},
}
diff --git a/mftcoder_accelerate/src/pefts/train_utils.py b/mftcoder_accelerate/src/pefts/train_utils.py
deleted file mode 100644
index fb0ca1c..0000000
--- a/mftcoder_accelerate/src/pefts/train_utils.py
+++ /dev/null
@@ -1,416 +0,0 @@
-"""
-# @author Chaoyu Chen
-# @date 2023/10/19
-# @module train_utils.py
-
-Accelerate + DeepSpeed zero stage2 + DistributedDataParallel
-QLoRA/LoRA/Full + MFT/MPT, resource and parameters efficient training
-
-training functions
-"""
-
-import gc
-import os
-import sys
-import threading
-import argparse
-import math
-import logging
-import json
-import time
-import transformers
-import numpy as np
-import psutil
-import shutil
-import torch
-from torch import nn
-from tqdm.auto import tqdm
-
-sys.path.append("..")
-from utils.common_utils import generate_task_id, TASK2ID, ID2TASK
-from utils.auto_accelerate_utils import loss_func_mft, SelfpacedStatus
-from torch.utils.tensorboard import SummaryWriter
-from accelerate.logging import get_logger
-
-logger = get_logger(__name__)
-
-
-def check_existing_ckpts(output_dir):
- prefix = "step_"
-
- if not os.path.exists(output_dir):
- return []
- # 列出目录中的所有文件和文件夹
- contents = os.listdir(output_dir)
-
- # 使用列表推导式筛选以"step_"开头的文件夹
- matching_folders = [folder for folder in contents if
- os.path.isdir(os.path.join(output_dir, folder)) and folder.startswith(prefix)]
-
- return matching_folders
-
-
-def extract_epochs_and_steps(path, num_update_steps_per_epoch, gradient_accumulation_steps):
- """
- extract starting_epoch, completed_steps, resume_step of train_dataloader for resumed training
- """
- # Extract `epoch_{i}` or `step_{i}`
- training_difference = os.path.splitext(path)[0]
-
- if "epoch" in training_difference:
- starting_epoch = int(training_difference.replace("epoch_", "")) + 1
- resume_step = None
- completed_steps = starting_epoch * num_update_steps_per_epoch
- print(f"resume from epoch {starting_epoch} and completed_steps {completed_steps}")
- else:
- # need to multiply `gradient_accumulation_steps` to reflect real steps
- completed_steps = int(training_difference.replace("step_", ""))
- starting_epoch = completed_steps // num_update_steps_per_epoch
- resume_step = (completed_steps % num_update_steps_per_epoch) * gradient_accumulation_steps
- print(f"resume from epoch {starting_epoch} resusme step {resume_step} and completed_steps {completed_steps}")
-
- return starting_epoch, completed_steps, resume_step
-
-
-def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps):
- for key, value in log_dict.items():
- summary_writer.add_scalar(f'{key}', value, completed_steps)
-
-
-def accelerate_saving_checkpoint(accelerator, model, tokenizer, output_dir: str, completed_steps: int, args):
- """
- Saving lora adaptor or full checkpoint using accelerator
- """
- accelerator.wait_for_everyone()
-
- logger.info(
- f"[CHECKPOINT] Saving checkpoint",
- main_process_only=True
- )
- unwrapped_model = accelerator.unwrap_model(model)
- unwrapped_model.save_pretrained(
- output_dir,
- is_main_process=accelerator.is_main_process,
- save_function=accelerator.save,
- state_dict=accelerator.get_state_dict(model)
- )
- # for full-parameter training, save whole ckpt and tokenizer together because it does not need a merge.
- if not args.peft_type and accelerator.is_main_process:
- tokenizer.save_pretrained(output_dir)
-
- logger.info(
- f"[CHECKPOINT][complete_steps={completed_steps}], checkpoint {output_dir} saved",
- main_process_only=True
- )
- accelerator.wait_for_everyone()
-
-
-def accelerate_monitor(accelerator, model, reduce_loss, reduce_task_loss, reduce_task_exist, args, completed_steps,
- lr_scheduler, optimizer, summary_writer, selfpaced_status=None):
- """
- gather reduce_loss and reduce_task_loss from all N devices.
- train logging and tensorboarding.
- """
- # gather reduce_loss and reduce_task_loss from all N devices
- reduce_losses = accelerator.gather(reduce_loss).detach().float()
- reduce_task_losses = accelerator.gather(reduce_task_loss).reshape(-1, len(ID2TASK))
- reduce_task_exists = accelerator.gather(reduce_task_exist).reshape(-1, len(ID2TASK))
- # get train loss and per-task train loss
- train_loss = torch.mean(reduce_losses) / (args.log_interval * args.gradient_accumulation_steps)
- # train_task_loss = torch.mean(reduce_task_losses, dim=0) / (args.log_interval * args.gradient_accumulation_steps)
- train_task_loss = torch.sum(reduce_task_losses, dim=0) / torch.sum(reduce_task_exists, dim=0)
-
- # logging and tensorboard
- logger.info(
- f"[TRAIN][complete_steps={completed_steps}][train_loss={train_loss:.6f}][train_task_loss={train_task_loss}]"
- f"[gather shape={reduce_losses.shape}][lr={lr_scheduler.get_lr()[0]:.4e}, {optimizer.param_groups[0]['lr']:.4e}]",
- main_process_only=True)
- if selfpaced_status is not None:
- if completed_steps > selfpaced_status.selfpaced_history_length:
- selfpaced_status.log_per_task_weight = selfpaced_status.log_per_task_weight / torch.sum(selfpaced_status.log_per_task_weight)
- else:
- selfpaced_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK)
- logger.info(f"[TRAIN][per_task_train_weight={selfpaced_status.log_per_task_weight}]", main_process_only=True)
- train_log_dict = {"training_loss": train_loss}
- for i in range(len(ID2TASK)):
- train_log_dict[f"{ID2TASK[i]}_train_loss"] = train_task_loss[i]
- if selfpaced_status is not None:
- train_log_dict[f"{ID2TASK[i]}_train_selfpaced_weight"] = selfpaced_status.log_per_task_weight[i].item()
-
- if accelerator.is_main_process:
- write_tensorboard(summary_writer, train_log_dict, completed_steps)
-
- if selfpaced_status is not None:
- selfpaced_status.log_per_task_weight = torch.zeros(len(ID2TASK))
-
-
-def accelerate_evaluate(accelerator, model, valid_dataloader, args, completed_steps, step, min_eval_loss, stall_num,
- best_step, summary_writer):
- """
- evaluate the model at current completed_steps on valid_dataloader and gather eval_loss on all devices.
- eval logging and tensorboarding.
- """
- losses = []
- accumulated_task_loss = torch.zeros(len(ID2TASK)).to(model.device)
- accumulated_task_exist = torch.zeros(len(ID2TASK)).to(model.device)
- for valid_step, valid_batch in enumerate(valid_dataloader):
- with torch.no_grad():
- outputs = model(
- input_ids=valid_batch['input_ids'],
- attention_mask=valid_batch['attention_mask'],
- position_ids=valid_batch['position_ids'],
- return_dict=True,
- )
-
- loss, task_loss, _ = loss_func_mft(
- outputs=outputs,
- labels=valid_batch['labels'],
- task_mask=valid_batch['task_mask'],
- task_id=valid_batch['task_id'],
- weighted_loss_mode=args.weighted_loss_mode,
- loss_mask=valid_batch['loss_mask'],
- task_weights=args.task_weights
- )
-
- losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
- accumulated_task_loss += task_loss.detach().float()
- accumulated_task_exist += (task_loss != 0.0).detach().float()
-
- accelerator.wait_for_everyone()
- valid_batch_num = len(losses)
- gathered_size = losses[0].shape
- losses = torch.cat(losses)
- # task_losses = torch.cat(task_losses).reshape(-1, len(ID2TASK))
- task_losses = accelerator.gather(accumulated_task_loss).reshape(-1, len(ID2TASK))
- task_exists = accelerator.gather(accumulated_task_exist).reshape(-1, len(ID2TASK))
-
- try:
- eval_loss = torch.mean(losses)
- # eval_task_loss = torch.mean(task_losses, dim=0) / valid_batch_num
- eval_task_loss = torch.sum(task_losses, dim=0) / torch.sum(task_exists, dim=0)
- if eval_loss <= min_eval_loss:
- min_eval_loss = eval_loss
- stall_num = 0
- best_step = completed_steps
- else:
- stall_num += 1
- perplexity = math.exp(eval_loss)
- except OverflowError:
- perplexity = float("inf")
-
- logger.info(f"[EVAL][global_steps={step + 1}][completed_steps={completed_steps}]"
- f"[valid_batch_num={valid_batch_num}], [gather_size={gathered_size}]"
- f"[perplexity={perplexity:.4f}][eval_loss={eval_loss:.6f}]"
- f"[eval_task_loss={eval_task_loss}]",
- main_process_only=True)
- eval_log_dict = {"valid_loss": eval_loss, "perplexity": perplexity}
- for i in range(len(ID2TASK)):
- eval_log_dict[f"{ID2TASK[i]}_valid_loss"] = eval_task_loss[i]
-
- if accelerator.is_main_process:
- write_tensorboard(summary_writer, eval_log_dict, completed_steps)
-
- return eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step
-
-
-def delete_ckpts_over_limits(output_dir, saving_limit, best_step):
- """delete ckpts more than saving_limits except for the best_step ckpt"""
- existing_ckpts = check_existing_ckpts(output_dir)
- logger.info(f"Existing step ckpts folders: {existing_ckpts}, best step ckpt: step_{best_step}")
- # sorted only step num ascendingly
- ckpt_steps = sorted([int(ckpt.replace("step_", "")) for ckpt in existing_ckpts])
- # delete the oldest steps except for the best step at present
- if len(ckpt_steps) > saving_limit:
- deletable_steps = [ckpt_step for ckpt_step in ckpt_steps if ckpt_step != best_step]
- # print(deletable_steps[:len(ckpt_steps) - saving_limit])
- for del_step in deletable_steps[:len(ckpt_steps) - saving_limit]:
- shutil.rmtree(os.path.join(output_dir, f"step_{del_step}"))
- logger.info(f"Removed ckpt step_{del_step}")
-
-
-def touch_print(accelerator, batch, num_tokens=10):
- """touch first and last tokens and labels for debugging usage"""
- accelerator.print(f"step 1 batch shape: {batch['input_ids'].shape},\n"
- f"last {num_tokens} labels: {batch['labels'][:, -num_tokens:]}"
- f"last {num_tokens} loss mask: {batch['loss_mask'][:, -num_tokens:]}")
- accelerator.print(f"first {num_tokens} input_ids and loss_mask")
- for pt in range(1):
- accelerator.print(f"{batch['input_ids'][:, num_tokens * pt: num_tokens * pt + num_tokens]}")
- accelerator.print(f"{batch['loss_mask'][:, num_tokens * pt: num_tokens * pt + num_tokens]}")
-
-
-def accelerate_train(accelerator, model, train_dataloader, valid_dataloader, optimizer, lr_scheduler, tokenizer,
- num_update_steps_per_epoch, total_train_dataset_size, args):
- # tensorboard writer
- summary_writer = SummaryWriter(log_dir=args.tb_dir)
- # Train!
- total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
- logger.info("**************************************** Running training ****************************************")
- logger.info(f" Num examples = {total_train_dataset_size}")
- logger.info(f" Num Epochs = {args.num_train_epochs}")
- logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
- logger.info(f" Total global train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
- logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
- logger.info(f" Total optimization(update/completed) steps = {args.max_train_steps}")
- logger.info(f" Complete/Optimization steps per Epoch = {args.max_train_steps // args.num_train_epochs}")
- logger.info("***************************************************************************************************")
-
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
-
- # set starting_epoch, completed_steps and resume_step of train_dataloader
- completed_steps = 0
- starting_epoch = 0
- resume_step = None
-
- if args.resume_from_checkpoint:
- path = os.path.basename(args.resume_from_checkpoint)
- starting_epoch, completed_steps, resume_step = extract_epochs_and_steps(
- path, num_update_steps_per_epoch, args.gradient_accumulation_steps
- )
-
- # update the progress_bar if load from checkpoint
- progress_bar.update(completed_steps)
-
- # monitor minimum eval_loss, stalling num, and best_step
- min_eval_loss = float('inf')
- stall_num = 0
- best_step = None
-
- # monitor train loss
- reduce_loss = 0
- reduce_task_loss = torch.zeros(len(ID2TASK)).to(model.device)
- reduce_task_exist = torch.zeros(len(ID2TASK)).to(model.device)
- per_task_weight = args.task_weights
-
- if args.weighted_loss_mode == "selfpaced":
- selfpaced_status = SelfpacedStatus(args.selfpaced_scale_factor, args.selfpaced_interval, args.selfpaced_history_length, args.selfpaced_sample_valid_num, valid_dataloader)
- selfpaced_status.sample_valid_batch(model, completed_steps)
- selfpaced_status.valid_iterator = iter(selfpaced_status.valid_dataloader)
- else:
- selfpaced_status = None
-
- # Training Loop!
- for epoch in range(starting_epoch, args.num_train_epochs):
- if args.early_stopping and stall_num == args.early_stopping_stall_num:
- break
-
- if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
- # We skip the first `n` batches in the dataloader when resuming from a checkpoint
- active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
- else:
- active_dataloader = train_dataloader
- tail_num = len(active_dataloader) - len(active_dataloader) % args.gradient_accumulation_steps
- print(f"length of dataloader: {len(active_dataloader)}")
-
- model.train()
- # Inner Loop!
- for step, batch in enumerate(active_dataloader):
- if step == tail_num:
- break
- with accelerator.accumulate(model):
- if step == 0:
- touch_print(accelerator, batch, num_tokens=10)
- # forward
- outputs = model(
- input_ids=batch['input_ids'],
- attention_mask=batch['attention_mask'],
- position_ids=batch['position_ids'],
- return_dict=True
- )
-
- if args.weighted_loss_mode == 'selfpaced' and step % args.gradient_accumulation_steps == 0 and completed_steps % args.selfpaced_interval == 0 and completed_steps >= args.selfpaced_history_length:
- per_task_weight = selfpaced_status.compute_per_task_weight(completed_steps=completed_steps)
- selfpaced_status.log_per_task_weight += per_task_weight
-
- # loss
- loss, task_loss, _ = loss_func_mft(
- outputs=outputs,
- labels=batch['labels'],
- task_mask=batch['task_mask'],
- task_id=batch['task_id'],
- weighted_loss_mode=args.weighted_loss_mode,
- loss_mask=batch['loss_mask'],
- task_weights=per_task_weight
- )
-
- # backward
- accelerator.backward(loss)
-
- # update(sync_gradients)
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad()
- # support args.min_lr
- if optimizer.param_groups[0]['lr'] <= args.min_lr:
- optimizer.param_groups[0]['lr'] = args.min_lr
-
- # accumulate resuce_loss and reduce_task_loss in a log_interval
- if not torch.isnan(loss):
- reduce_loss += loss.detach().float()
- # accelerator.print("task loss devices: ", reduce_task_loss.device, task_loss.device)
- reduce_task_loss += task_loss.detach().float()
- reduce_task_exist += (task_loss != 0).detach().float()
-
- # If the accelerator has performed an optimization step behind the scenes, thus a completed_step done.
- if accelerator.sync_gradients:
- if args.weighted_loss_mode == 'selfpaced' and completed_steps % args.selfpaced_interval == 0 and completed_steps >= 1:
- selfpaced_status.sample_valid_batch(model, completed_steps)
-
- # progress_bar.update(1)
- completed_steps += 1
- # monitoring training process and logging and tensorboarding
- if completed_steps % args.log_interval == 0:
- progress_bar.update(args.log_interval)
- accelerate_monitor(
- accelerator, model, reduce_loss, reduce_task_loss, reduce_task_exist, args, completed_steps,
- lr_scheduler, optimizer, summary_writer, selfpaced_status
- )
- # reset reduce_loss
- reduce_loss = 0
- reduce_task_loss = torch.zeros(len(ID2TASK)).to(model.device)
- reduce_task_exist = torch.zeros(len(ID2TASK)).to(model.device)
-
- # steps checkpointing
- if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0:
- output_dir = f"step_{completed_steps}"
- if args.output_dir is not None:
- output_dir = os.path.join(args.output_dir, output_dir)
- accelerate_saving_checkpoint(accelerator, model, tokenizer, output_dir, completed_steps, args)
-
- # steps evaluation
- if completed_steps % args.evaluation_steps == 0:
- model.eval()
- eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step = accelerate_evaluate(
- accelerator, model, valid_dataloader, args, completed_steps, step,
- min_eval_loss, stall_num, best_step, summary_writer
- )
- model.train()
-
- # delete ckpts over args.saving_limit
- if accelerator.is_main_process and args.saving_limit:
- delete_ckpts_over_limits(args.output_dir, args.saving_limit, best_step)
-
- # early stoppin when stalling more than args.early_stopping_stall_num
- if args.early_stopping and stall_num == args.early_stopping_stall_num:
- accelerator.print(f"[WARNING] Early stopping at {completed_steps}")
- break
-
- if completed_steps >= args.max_train_steps:
- break
- accelerator.wait_for_everyone()
-
- # epoch checkpointing
- if args.epoch_checkpointing:
- output_dir = f"epoch_{epoch}"
- if args.output_dir is not None:
- output_dir = os.path.join(args.output_dir, output_dir)
- accelerate_saving_checkpoint(accelerator, model, tokenizer, output_dir, completed_steps, args)
-
- summary_writer.close()
-
- # final save
- output_dir = f"final_step_{completed_steps}"
- if args.output_dir is not None:
- output_dir = os.path.join(args.output_dir, output_dir)
- accelerate_saving_checkpoint(accelerator, model, tokenizer, output_dir, completed_steps, args)
diff --git a/mftcoder_accelerate/src/pefts/trainer.py b/mftcoder_accelerate/src/pefts/trainer.py
new file mode 100644
index 0000000..7e004d9
--- /dev/null
+++ b/mftcoder_accelerate/src/pefts/trainer.py
@@ -0,0 +1,598 @@
+"""
+# @author Chaoyu Chen
+# @date 2024/4/12
+# @module trainer.py
+
+Accelerate + DeepSpeed/FSDP
+QLoRA/LoRA/Full + SFT/MFT/MPT
+
+Trainer
+"""
+
+import gc
+import os
+import sys
+import threading
+import argparse
+import math
+import logging
+import json
+import time
+import transformers
+import numpy as np
+import psutil
+import shutil
+import torch
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+from typing import List, Optional, Tuple, Union
+from tqdm.auto import tqdm
+from accelerate.logging import get_logger
+from accelerate import Accelerator
+from transformers import set_seed
+
+# sys.path.append("..")
+from utils.common_utils import generate_task_id, TASK2ID, ID2TASK
+from utils.loss_utils import loss_func_mft, SelfpacedStatus, load_balancing_loss_func
+
+logger = get_logger(__name__)
+
+
+def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str):
+ # create path if not exist
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+ # copy each file in files_list to save_path
+ for filename in files_list:
+ src_file = os.path.join(mode_path, filename)
+
+ # copy only if src exists
+ if os.path.exists(src_file):
+ dest_file = os.path.join(save_path, filename)
+
+ # copy
+ shutil.copy(src_file, dest_file)
+ print(f"Copied {filename} to {save_path}")
+ else:
+ print(f"File {filename} does not exist in {mode_path}")
+
+
+def check_existing_ckpts(output_dir):
+ prefix = "step_"
+
+ if not os.path.exists(output_dir):
+ return []
+ # list all files and dirs
+ contents = os.listdir(output_dir)
+
+ # find dirs starts with "step_"
+ matching_folders = [
+ folder for folder in contents if os.path.isdir(os.path.join(output_dir, folder)) and folder.startswith(prefix)
+ ]
+
+ return matching_folders
+
+
+def extract_epochs_and_steps(path, num_update_steps_per_epoch, gradient_accumulation_steps):
+ """
+ extract starting_epoch, completed_steps, resume_step of train_dataloader for resumed training
+ """
+ # Extract `epoch_{i}` or `step_{i}`
+ training_difference = os.path.splitext(path)[0]
+
+ if "epoch" in training_difference:
+ starting_epoch = int(training_difference.replace("epoch_", ""))
+ resume_step = None
+ completed_steps = starting_epoch * num_update_steps_per_epoch
+ logger.info(f"Resume from exact Epoch {starting_epoch}: completed_steps {completed_steps}")
+ else:
+ # need to multiply `gradient_accumulation_steps` to reflect real steps
+ completed_steps = int(training_difference.replace("step_", ""))
+ starting_epoch = completed_steps // num_update_steps_per_epoch
+ resume_step = (completed_steps % num_update_steps_per_epoch) * gradient_accumulation_steps
+ logger.info(f"Resume from Epoch {starting_epoch} + step {resume_step}: completed_steps {completed_steps}")
+
+ return starting_epoch, completed_steps, resume_step
+
+
+def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps):
+ for key, value in log_dict.items():
+ summary_writer.add_scalar(f"{key}", value, completed_steps)
+
+
+def delete_ckpts_over_limits(output_dir, saving_limit, best_step):
+ """delete ckpts more than saving_limits except for the best_step ckpt"""
+ existing_ckpts = check_existing_ckpts(output_dir)
+ logger.info(f"Existing step ckpts folders: {existing_ckpts}, best step ckpt: step_{best_step}")
+ # sorted only step num ascendingly
+ ckpt_steps = sorted([int(ckpt.replace("step_", "")) for ckpt in existing_ckpts])
+ # delete the oldest steps except for the best step at present
+ if len(ckpt_steps) > saving_limit:
+ deletable_steps = [ckpt_step for ckpt_step in ckpt_steps if ckpt_step != best_step]
+ # print(deletable_steps[:len(ckpt_steps) - saving_limit])
+ for del_step in deletable_steps[: len(ckpt_steps) - saving_limit]:
+ shutil.rmtree(os.path.join(output_dir, f"step_{del_step}"))
+ logger.info(f"Removed ckpt step_{del_step}")
+
+
+class MftTrainer:
+ """
+ Multitask FineTuing Trainer, supporting MFT/SFT/ContinueTrain with Lora/Qlora/Full-parameters.
+ """
+
+ def __init__(
+ self,
+ accelerator: Accelerator,
+ model,
+ model_config,
+ train_dataloader,
+ valid_dataloader,
+ optimizer,
+ lr_scheduler,
+ tokenizer,
+ num_update_steps_per_epoch,
+ total_train_dataset_size,
+ args,
+ ):
+ self.accelerator = accelerator
+ self.model = model
+ # hf model config
+ self.model_config = model_config
+ self.train_dataloader = train_dataloader
+ self.valid_dataloader = valid_dataloader
+ self.optimizer = optimizer
+ self.lr_scheduler = lr_scheduler
+ self.tokenizer = tokenizer
+ self.num_update_steps_per_epoch = num_update_steps_per_epoch
+ self.total_train_dataset_size = total_train_dataset_size
+ # training arguments
+ self.args = args
+ # tensorboard writer
+ self.summary_writer = SummaryWriter(log_dir=args.tb_dir)
+ self.default_writer = SummaryWriter(log_dir="/home/admin/logs/tfevent")
+
+ def print(self, msg: str):
+ """
+ accelerator print, default on main process
+ Args:
+ msg:
+
+ Returns:
+
+ """
+ self.accelerator.print(msg)
+
+ def touch(self, batch, num_tokens=10):
+ """touch first and last tokens and labels for debugging usage"""
+ self.print(
+ f"step 1 batch shape: {batch['input_ids'].shape},\n"
+ f"last {num_tokens} labels: {batch['labels'][:, -num_tokens:]}"
+ f"last {num_tokens} loss mask: {batch['loss_mask'][:, -num_tokens:]}"
+ )
+ self.print(f"first {num_tokens} input_ids and loss_mask")
+ for pt in range(1):
+ self.print(f"{batch['input_ids'][:, num_tokens * pt: num_tokens * pt + num_tokens]}")
+ self.print(f"{batch['loss_mask'][:, num_tokens * pt: num_tokens * pt + num_tokens]}")
+
+ @staticmethod
+ def format_tensor(tensor, n):
+ return list(map(lambda x: round(x, n), tensor.tolist()))
+
+ def accelerate_saving_checkpoint(self, output_dir: str, completed_steps: int):
+ """
+ Saving lora adaptor or full checkpoint using accelerator
+ Args:
+ output_dir: exact dir for saving ckpt
+ completed_steps:
+
+ Returns:
+
+ """
+ self.accelerator.wait_for_everyone()
+
+ logger.info(f"[CHECKPOINT] Saving checkpoint", main_process_only=True)
+ unwrapped_model = self.accelerator.unwrap_model(self.model)
+ # self.print(f"unwrapped model type {type(unwrapped_model)}")
+ unwrapped_model.save_pretrained(
+ output_dir,
+ is_main_process=self.accelerator.is_main_process,
+ save_function=self.accelerator.save,
+ state_dict=self.accelerator.get_state_dict(self.model),
+ )
+ self.accelerator.wait_for_everyone()
+ # for full-parameter training, save whole ckpt and tokenizer together because it does not need a merge.
+ if not self.args.peft_type and self.accelerator.is_main_process:
+ if self.args.model_type.lower() == "deepseek":
+ copy_tokenizer_files(
+ self.args.pretrained_model_path, ["tokenizer.json", "tokenizer_config.json"], output_dir
+ )
+ else:
+ self.tokenizer.save_pretrained(output_dir)
+
+ sf = os.path.join(output_dir, "model.safetensors")
+ index_file = os.path.join(output_dir, "model.safetensors.index.json")
+ if os.path.isfile(sf) and os.path.isfile(index_file):
+ self.print(f"Remove bug dummy ckpt {sf}")
+ os.remove(sf)
+
+ if self.accelerator.is_main_process:
+ latest = {
+ "latest_ckpt": output_dir,
+ "lr": self.optimizer.param_groups[0]["lr"],
+ # 1 step back because ckping is after schuduler.step()
+ # "scheduler_last_ep": self.lr_scheduler.state_dict().get("last_epoch", 0) - 1,
+ }
+ with open(os.path.join(self.args.output_dir, "latest"), "w") as f:
+ json.dump(latest, f, indent=2)
+
+ logger.info(
+ f"[CHECKPOINT][complete_steps={completed_steps}], checkpoint {output_dir} saved, latest: {latest}",
+ main_process_only=True,
+ )
+ self.accelerator.wait_for_everyone()
+
+ def accelerate_monitor(
+ self,
+ reduce_loss,
+ reduce_task_loss,
+ reduce_task_exist,
+ completed_steps,
+ selfpaced_status=None,
+ ):
+ """
+ gather reduce_loss and reduce_task_loss from all N devices.
+ train logging and tensorboarding.
+ """
+ # gather reduce_loss and reduce_task_loss from all N devices
+ reduce_losses = self.accelerator.gather(reduce_loss).detach().float()
+ reduce_task_losses = self.accelerator.gather(reduce_task_loss).reshape(-1, len(ID2TASK))
+ reduce_task_exists = self.accelerator.gather(reduce_task_exist).reshape(-1, len(ID2TASK))
+
+ # get train loss and per-task train loss
+ train_loss = torch.mean(reduce_losses) / (self.args.log_interval * self.args.gradient_accumulation_steps)
+ # train_task_loss = torch.mean(reduce_task_losses, dim=0) / (self.args.log_interval * self.args.gradient_accumulation_steps)
+ train_task_loss = torch.sum(reduce_task_losses, dim=0) / torch.sum(reduce_task_exists, dim=0)
+
+ # logging and writing tensorboard
+ logger.info(
+ f"[TRAIN][complete_steps={completed_steps}][train_loss={train_loss:.6f}]"
+ f"[train_task_loss={self.format_tensor(train_task_loss, 4)}]"
+ f"[gather shape={list(reduce_losses.shape)}]"
+ f"[lr={self.lr_scheduler.get_lr()[0]:.4e}, {self.optimizer.param_groups[0]['lr']:.4e}]",
+ main_process_only=True,
+ )
+ if selfpaced_status is not None:
+ if completed_steps > selfpaced_status.selfpaced_history_length:
+ selfpaced_status.log_per_task_weight = selfpaced_status.log_per_task_weight / torch.sum(
+ selfpaced_status.log_per_task_weight
+ )
+ else:
+ selfpaced_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK)
+ logger.info(
+ f"[TRAIN][per_task_train_weight={selfpaced_status.log_per_task_weight}]", main_process_only=True
+ )
+ train_log_dict = {"Loss/train": train_loss}
+ for i in range(len(ID2TASK)):
+ train_log_dict[f"{ID2TASK[i]}_loss/train"] = train_task_loss[i]
+ if selfpaced_status is not None:
+ train_log_dict[f"{ID2TASK[i]}_selfpaced_weight/train"] = selfpaced_status.log_per_task_weight[i].item()
+
+ if self.accelerator.is_main_process:
+ write_tensorboard(self.summary_writer, train_log_dict, completed_steps)
+ write_tensorboard(self.default_writer, train_log_dict, completed_steps)
+
+ if selfpaced_status is not None:
+ selfpaced_status.log_per_task_weight = torch.zeros(len(ID2TASK))
+
+ def accelerate_evaluate(
+ self,
+ completed_steps,
+ step,
+ min_eval_loss,
+ stall_num,
+ best_step,
+ ):
+ """
+ evaluate the model at current completed_steps on valid_dataloader and gather eval_loss on all devices.
+ eval logging and tensorboarding.
+ """
+ losses = []
+ accumulated_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device)
+ accumulated_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device)
+ for valid_step, valid_batch in enumerate(self.valid_dataloader):
+ with torch.no_grad():
+ outputs = self.model(
+ input_ids=valid_batch["input_ids"],
+ attention_mask=valid_batch["attention_mask"],
+ position_ids=valid_batch["position_ids"],
+ return_dict=True,
+ )
+
+ loss, task_loss, _ = loss_func_mft(
+ outputs=outputs,
+ labels=valid_batch["labels"],
+ task_mask=valid_batch["task_mask"],
+ task_id=valid_batch["task_id"],
+ weighted_loss_mode=self.args.weighted_loss_mode,
+ loss_mask=valid_batch["loss_mask"],
+ task_weights=self.args.task_weights,
+ )
+
+ losses.append(self.accelerator.gather(loss.repeat(self.args.per_device_eval_batch_size)))
+ accumulated_task_loss += task_loss.detach().float()
+ accumulated_task_exist += (task_loss != 0.0).detach().float()
+
+ self.accelerator.wait_for_everyone()
+ valid_batch_num = len(losses)
+ gathered_size = losses[0].shape
+ losses = torch.cat(losses)
+ # task_losses = torch.cat(task_losses).reshape(-1, len(ID2TASK))
+ task_losses = self.accelerator.gather(accumulated_task_loss).reshape(-1, len(ID2TASK))
+ task_exists = self.accelerator.gather(accumulated_task_exist).reshape(-1, len(ID2TASK))
+
+ try:
+ eval_loss = torch.mean(losses)
+ # eval_task_loss = torch.mean(task_losses, dim=0) / valid_batch_num
+ eval_task_loss = torch.sum(task_losses, dim=0) / torch.sum(task_exists, dim=0)
+ if eval_loss <= min_eval_loss:
+ min_eval_loss = eval_loss
+ stall_num = 0
+ best_step = completed_steps
+ else:
+ stall_num += 1
+ perplexity = math.exp(eval_loss)
+ except OverflowError:
+ perplexity = float("inf")
+
+ logger.info(
+ f"[EVAL][completed_steps={completed_steps}]"
+ f"[eval_loss={eval_loss:.6f}][eval_task_loss={self.format_tensor(eval_task_loss, 4)}]"
+ f"[perplexity={perplexity:.4f}][valid_batch_num={valid_batch_num}]"
+ f"[gather_size={list(gathered_size)}]",
+ main_process_only=True,
+ )
+ eval_log_dict = {
+ "Loss/valid": eval_loss,
+ "Perplexity/valid": perplexity,
+ "Epochs": round(completed_steps / self.num_update_steps_per_epoch, 2),
+ }
+ for i in range(len(ID2TASK)):
+ eval_log_dict[f"{ID2TASK[i]}_loss/valid"] = eval_task_loss[i]
+
+ if self.accelerator.is_main_process:
+ write_tensorboard(self.summary_writer, eval_log_dict, completed_steps)
+ write_tensorboard(self.default_writer, eval_log_dict, completed_steps)
+
+ return eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step
+
+ def accelerate_train(self):
+ # Train!
+ if self.args.seed is not None:
+ set_seed(self.args.seed)
+
+ global_batch_size = (
+ self.args.per_device_train_batch_size
+ * self.accelerator.num_processes
+ * self.args.gradient_accumulation_steps
+ )
+ logger.info("************************************** Running training ****************************************")
+ logger.info(f" Num examples = {self.total_train_dataset_size}")
+ logger.info(f" Num Epochs = {self.args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
+ logger.info(f" Total global train batch size (w. parallel, distributed & accumulation) = {global_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization(update/completed) steps = {self.args.max_train_steps}")
+ logger.info(f" Complete/optimize steps per Epoch = {self.args.max_train_steps // self.args.num_train_epochs}")
+ logger.info("************************************************************************************************")
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(self.args.max_train_steps), disable=not self.accelerator.is_local_main_process)
+
+ # set starting_epoch, completed_steps and resume_step of train_dataloader
+ completed_steps = 0
+ starting_epoch = 0
+ resume_step = None
+
+ if self.args.resume_from_checkpoint:
+ path = os.path.basename(self.args.resume_from_checkpoint)
+ starting_epoch, completed_steps, resume_step = extract_epochs_and_steps(
+ path, self.num_update_steps_per_epoch, self.args.gradient_accumulation_steps
+ )
+
+ # update the progress_bar if load from checkpoint
+ progress_bar.update(completed_steps)
+
+ # monitor minimum eval_loss, stalling num, and best_step
+ min_eval_loss = float("inf")
+ stall_num = 0
+ best_step = None
+
+ # monitor train loss
+ reduce_loss = torch.tensor(0.0).to(self.model.device)
+ reduce_aux_loss = torch.tensor(0.0).to(self.model.device)
+ reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device)
+ reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device)
+ per_task_weight = self.args.task_weights
+
+ if self.args.weighted_loss_mode == "selfpaced":
+ selfpaced_status = SelfpacedStatus(
+ self.args.selfpaced_scale_factor,
+ self.args.selfpaced_interval,
+ self.args.selfpaced_history_length,
+ self.args.selfpaced_sample_valid_num,
+ self.valid_dataloader,
+ )
+ selfpaced_status.sample_valid_batch(self.model, completed_steps)
+ selfpaced_status.valid_iterator = iter(selfpaced_status.valid_dataloader)
+ else:
+ selfpaced_status = None
+
+ # Training Loop!
+ for epoch in range(starting_epoch, self.args.num_train_epochs):
+ # set_epoch
+ # self.train_dataloader.set_epoch(epoch)
+
+ # if we early stop by some ckpts not converging
+ if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num:
+ break
+
+ if self.args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
+ # We skip the first `n` batches in the dataloader when resuming from a checkpoint
+ active_dataloader = self.accelerator.skip_first_batches(self.train_dataloader, resume_step)
+ else:
+ active_dataloader = self.train_dataloader
+ tail_num = len(active_dataloader) - len(active_dataloader) % self.args.gradient_accumulation_steps
+ print(f"length of dataloader: {len(active_dataloader)}")
+
+ self.model.train()
+ # Inner Loop!
+ for step, batch in enumerate(active_dataloader):
+ if step == tail_num:
+ break
+ with self.accelerator.accumulate(self.model):
+ if step == 0:
+ self.touch(batch, num_tokens=10)
+ # forward
+ outputs = self.model(
+ input_ids=batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ position_ids=batch["position_ids"],
+ return_dict=True,
+ )
+
+ if (
+ self.args.weighted_loss_mode == "selfpaced"
+ and step % self.args.gradient_accumulation_steps == 0
+ and completed_steps % self.args.selfpaced_interval == 0
+ and completed_steps >= self.args.selfpaced_history_length
+ ):
+ per_task_weight = selfpaced_status.compute_per_task_weight(completed_steps=completed_steps)
+ selfpaced_status.log_per_task_weight += per_task_weight
+
+ # loss
+ loss, task_loss, _ = loss_func_mft(
+ outputs=outputs,
+ labels=batch["labels"],
+ task_mask=batch["task_mask"],
+ task_id=batch["task_id"],
+ weighted_loss_mode=self.args.weighted_loss_mode,
+ loss_mask=batch["loss_mask"],
+ task_weights=per_task_weight,
+ )
+
+ # accelerator.print(len(outputs.router_logits), outputs.router_logits[0], outputs.router_logits[-1])
+ # accelerator.print(batch['attention_mask'].shape, batch['attention_mask'])
+ aux_loss = None
+ if hasattr(self.model_config, "output_router_logits") and self.model_config.output_router_logits:
+ if hasattr(self.model_config, "num_local_experts"):
+ num_experts = self.model_config.num_local_experts
+ elif hasattr(self.model_config, "num_experts"):
+ num_experts = self.model_config.num_experts
+ else:
+ raise ValueError("model has no attribute num_local_experts or num_experts")
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits,
+ num_experts,
+ self.model_config.num_experts_per_tok,
+ batch["attention_mask"],
+ )
+ aux_loss = self.model_config.router_aux_loss_coef * aux_loss.to(loss.device)
+ loss += aux_loss # make sure to reside in the same device
+
+ # backward
+ self.accelerator.backward(loss)
+ # print(self.lr_scheduler.state_dict(), self.accelerator.process_index)
+ # update(sync_gradients)
+ self.optimizer.step()
+ self.lr_scheduler.step()
+ self.optimizer.zero_grad()
+ # support args.min_lr
+ if self.optimizer.param_groups[0]["lr"] <= self.args.min_lr:
+ self.optimizer.param_groups[0]["lr"] = self.args.min_lr
+
+ # accumulate resuce_loss and reduce_task_loss in a log_interval
+ if not torch.isnan(loss):
+ reduce_loss += loss.detach().float()
+ if aux_loss and not torch.isnan(aux_loss):
+ reduce_aux_loss += aux_loss.detach().float()
+ # self.print("task loss devices: ", reduce_task_loss.device, task_loss.device)
+ reduce_task_loss += task_loss.detach().float()
+ reduce_task_exist += (task_loss != 0).detach().float()
+
+ # If the accelerator has performed an optimization step behind the scenes, thus a completed_step done.
+ if self.accelerator.sync_gradients:
+ if (
+ self.args.weighted_loss_mode == "selfpaced"
+ and completed_steps % self.args.selfpaced_interval == 0
+ and completed_steps >= 1
+ ):
+ selfpaced_status.sample_valid_batch(self.model, completed_steps)
+
+ # progress_bar.update(1)
+ completed_steps += 1
+ # monitoring training process and logging and tensorboarding
+ if completed_steps % self.args.log_interval == 0:
+ progress_bar.update(self.args.log_interval)
+ if reduce_aux_loss > 0.0:
+ self.print(f"[INFO] aux_loss: {reduce_aux_loss/self.args.log_interval}")
+ self.accelerate_monitor(
+ reduce_loss,
+ reduce_task_loss,
+ reduce_task_exist,
+ completed_steps,
+ selfpaced_status,
+ )
+ # reset reduce_loss
+ reduce_loss = torch.tensor(0.0).to(self.model.device)
+ reduce_aux_loss = torch.tensor(0.0).to(self.model.device)
+ reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device)
+ reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device)
+
+ # steps checkpointing
+ if self.args.checkpointing_steps and completed_steps % self.args.checkpointing_steps == 0:
+ output_dir = f"step_{completed_steps}"
+ if self.args.output_dir is not None:
+ output_dir = os.path.join(self.args.output_dir, output_dir)
+ self.accelerate_saving_checkpoint(output_dir, completed_steps)
+
+ # steps evaluation
+ if completed_steps % self.args.evaluation_steps == 0:
+ self.model.eval()
+ eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step = self.accelerate_evaluate(
+ completed_steps,
+ step,
+ min_eval_loss,
+ stall_num,
+ best_step,
+ )
+ self.model.train()
+
+ # delete ckpts over args.saving_limit
+ if self.accelerator.is_main_process and self.args.saving_limit:
+ delete_ckpts_over_limits(self.args.output_dir, self.args.saving_limit, best_step)
+
+ # early stoppin when stalling more than args.early_stopping_stall_num
+ if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num:
+ self.print(f"[WARNING] Early stopping at {completed_steps}")
+ break
+
+ if completed_steps >= self.args.max_train_steps:
+ break
+ self.accelerator.wait_for_everyone()
+
+ # epoch checkpointing
+ if self.args.epoch_checkpointing:
+ output_dir = f"epoch_{epoch + 1}"
+ if self.args.output_dir is not None:
+ output_dir = os.path.join(self.args.output_dir, output_dir)
+ self.accelerate_saving_checkpoint(output_dir, completed_steps)
+
+ self.summary_writer.close()
+ self.default_writer.close()
+
+ # final save
+ # output_dir = f"final_step_{completed_steps}"
+ # if self.args.output_dir is not None:
+ # output_dir = os.path.join(self.args.output_dir, output_dir)
+ # self.accelerate_saving_checkpoint(output_dir, completed_steps)
diff --git a/mftcoder_accelerate/src/tokenizer/chat_template.py b/mftcoder_accelerate/src/tokenizer/chat_template.py
index 0f2dff0..3d2ad03 100644
--- a/mftcoder_accelerate/src/tokenizer/chat_template.py
+++ b/mftcoder_accelerate/src/tokenizer/chat_template.py
@@ -4,25 +4,17 @@
# store possible chat_template for tokenizers to prepare input string
# -------------------------------------------------- Import ------------------------------------------------------------
-from transformers import (
- AutoTokenizer
-)
-
-# ----------------------------------------------- func and class -------------------------------------------------------
-instruction_template = (
- "{% for message in messages %}"
- "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
- "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
- "{% endif %}"
- "{% if message['role'] == 'user' %}"
- "{{ '[INST] ' + message['content'] + ' [/INST]' }}"
- "{% elif message['role'] == 'assistant' %}"
- "{{ message['content'] + eos_token}}"
- "{% else %}"
- "{{ raise_exception('Only user and assistant roles are supported!') }}"
- "{% endif %}"
- "{% endfor %}"
-)
+"""
+Usage:
+tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+messages = [
+ {"role": "system", "content": "Be smart"},
+ {"role": "human", "content": "Hello, how are you?"},
+ {"role": "bot", "content": "I'm doing great. How can I help you today?"},
+ {"role": "human", "content": "I'd like to show off how chat templating works!"},
+]
+prompts = tokenizer.apply_chat_template(message, chat_template=MFTCoder_template, tokenize=False, add_generation_prompt=True)
+"""
MFTCoder_template = (
"{% if messages[0]['role'] == 'system' %}"
@@ -54,5 +46,5 @@
"{% endif %}"
)
-if __name__ == '__main__':
+if __name__ == "__main__":
pass
diff --git a/mftcoder_accelerate/src/tokenizer/tokenizer.py b/mftcoder_accelerate/src/tokenizer/tokenizer.py
index 5765ffd..cacd712 100644
--- a/mftcoder_accelerate/src/tokenizer/tokenizer.py
+++ b/mftcoder_accelerate/src/tokenizer/tokenizer.py
@@ -1,8 +1,6 @@
"""
# @author Chaoyu Chen
# @date 2023/6/19
-
-Build tokenizer
"""
@@ -10,26 +8,30 @@
from typing import List, Union
from utils.common_utils import print_rank_0
from transformers import AutoTokenizer
+from tokenizer.chat_template import MFTCoder_template
def build_tokenizer(args):
"""Initialize tokenizer."""
- print_rank_0("> building {} tokenizer ...".format(args.tokenizer_type))
+ print_rank_0(f"> building {args.tokenizer_type} tokenizer ...")
# Select and instantiate the tokenizer.
if args.tokenizer_type.lower() == "AutoTokenizer".lower():
assert args.pretrained_model_path is not None
+ # tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path, trust_remote_code=True, use_fast=False, legacy=False)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path, trust_remote_code=True)
-
tokenizer.eod_id = tokenizer.convert_tokens_to_ids(args.eos_token)
tokenizer.pad_id = tokenizer.convert_tokens_to_ids(args.pad_token)
-
- print_rank_0(f"Tokenizer: {tokenizer}\nLength of tokenizer: {len(tokenizer)}")
+ try:
+ tokenizer.eos_token = args.eos_token
+ tokenizer.pad_token = args.pad_token
+ except:
+ print(f"[WARNING]Cannot set tokenizer.eos_token")
+ print_rank_0(f"Tokenizer: {type(tokenizer)}")
+ print_rank_0(f"Length of tokenizer: {len(tokenizer)}")
print_rank_0(f"build_tokenizer PAD id: {tokenizer.pad_id}, EOD id: {tokenizer.eod_id}")
print_rank_0(f"build_tokenizer PAD token : {args.pad_token}, EOD token: {args.eos_token}")
else:
- raise NotImplementedError(
- "{} tokenizer is not " "implemented.".format(args.tokenizer_type)
- )
+ raise NotImplementedError(f"{args.tokenizer_type} tokenizer is not implemented.")
# Add vocab size.
args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args)
@@ -38,7 +40,7 @@ def build_tokenizer(args):
def _vocab_size_with_padding(orig_vocab_size, args):
- """Pad vocab size so that it is divisible by model parallel size and
+ """Pad vocab size thus it is divisible by model parallel size and
still having GPU friendly size."""
after = orig_vocab_size
diff --git a/mftcoder_accelerate/src/utils/__init__.py b/mftcoder_accelerate/src/utils/__init__.py
index 0cf9434..0bd6cec 100644
--- a/mftcoder_accelerate/src/utils/__init__.py
+++ b/mftcoder_accelerate/src/utils/__init__.py
@@ -1,2 +1,2 @@
from .common_utils import *
-from .auto_accelerate_utils import *
\ No newline at end of file
+from .loss_utils import *
diff --git a/mftcoder_accelerate/src/utils/agd.py b/mftcoder_accelerate/src/utils/agd.py
index bb654a9..11929e3 100644
--- a/mftcoder_accelerate/src/utils/agd.py
+++ b/mftcoder_accelerate/src/utils/agd.py
@@ -83,22 +83,14 @@ def step(self, closure: OptLossClosure = None) -> OptFloat:
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
- state["exp_avg"] = torch.zeros_like(
- p, memory_format=torch.preserve_format
- )
+ state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
- state["exp_avg_sq"] = torch.zeros_like(
- p, memory_format=torch.preserve_format
- )
+ state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
if group["amsgrad"]:
# Maintains max of all exp. moving avg. of sq. grad. values
- state["max_exp_avg_sq"] = torch.zeros_like(
- p, memory_format=torch.preserve_format
- )
+ state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
if group["win"]:
- state["z"] = torch.zeros_like(
- p, memory_format=torch.preserve_format
- )
+ state["z"] = torch.zeros_like(p, memory_format=torch.preserve_format)
exp_avg, exp_avg_sq = (
state["exp_avg"],
@@ -116,8 +108,7 @@ def step(self, closure: OptLossClosure = None) -> OptFloat:
update = (
exp_avg * (1 / bias_correction1)
if state["step"] == 1
- else exp_avg * (1 / bias_correction1)
- - exp_avg_old * (1 / bias_correction1_old)
+ else exp_avg * (1 / bias_correction1) - exp_avg_old * (1 / bias_correction1_old)
)
exp_avg_sq.mul_(beta2).addcmul_(update, update, value=1 - beta2)
diff --git a/mftcoder_accelerate/src/utils/common_utils.py b/mftcoder_accelerate/src/utils/common_utils.py
index 48d75e1..7b6ea30 100644
--- a/mftcoder_accelerate/src/utils/common_utils.py
+++ b/mftcoder_accelerate/src/utils/common_utils.py
@@ -1,10 +1,29 @@
import os
import math
import torch
+from packaging import version
+import importlib
TASK2ID = {}
ID2TASK = {}
+
+def is_flash_attn_2_available():
+
+ # Let's add an extra check to see if cuda is available
+
+ if not torch.cuda.is_available():
+ return False
+
+ if torch.version.cuda:
+ return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
+ elif torch.version.hip:
+ # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention
+ return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4")
+ else:
+ return False
+
+
def print_rank_0(*message):
"""If distributed is initialized print only on rank 0."""
if torch.distributed.is_initialized():
@@ -92,30 +111,24 @@ def get_tflops_new(args, batch_size, seq_len, step_time):
L = args.num_hidden_layers
h = args.hidden_size
V = args.vocab_size
- flops = (96 * batch_size * sl * L * h * h * (1 + sl / (6 * h) + V / (16 * L * h)) / step_time)
+ flops = 96 * batch_size * sl * L * h * h * (1 + sl / (6 * h) + V / (16 * L * h)) / step_time
return human_readable_flops(flops)
-def get_tflops_megatron(total_model_param, hidden_size, num_hidden_layers,
- batch_size_per_device, seq_len, step_time):
+def get_tflops_megatron(total_model_param, hidden_size, num_hidden_layers, batch_size_per_device, seq_len, step_time):
ff = total_model_param * 6
attn = seq_len * hidden_size * num_hidden_layers * 60
- flops = (
- batch_size_per_device
- * seq_len
- * (ff + attn)
- / step_time
- )
+ flops = batch_size_per_device * seq_len * (ff + attn) / step_time
return human_readable_flops(flops)
def generate_task_id(data_paths):
- data_prefixes = list(data_paths[1:-1].split(','))
+ data_prefixes = list(data_paths[1:-1].split(","))
print("data paths: ")
print(data_prefixes)
for i, prefix in enumerate(data_prefixes):
- task_name = prefix.split('/')[-1]
+ task_name = prefix.split("/")[-1]
TASK2ID[task_name] = i
ID2TASK[i] = task_name
diff --git a/mftcoder_accelerate/src/utils/auto_accelerate_utils.py b/mftcoder_accelerate/src/utils/loss_utils.py
similarity index 55%
rename from mftcoder_accelerate/src/utils/auto_accelerate_utils.py
rename to mftcoder_accelerate/src/utils/loss_utils.py
index c35ca9b..deb59d0 100644
--- a/mftcoder_accelerate/src/utils/auto_accelerate_utils.py
+++ b/mftcoder_accelerate/src/utils/loss_utils.py
@@ -5,13 +5,14 @@
import torch.nn.functional as F
from dataclasses import dataclass
import numpy as np
+from typing import List, Optional, Tuple, Union
def get_task_mask(task_id):
task_num = len(TASK2ID)
task_mask = torch.zeros(task_id.shape[0], task_num)
task_mask[torch.arange(task_id.size(0)).unsqueeze(1), task_id] = 1
-
+
return task_mask
@@ -53,15 +54,17 @@ def loss_func_mft(outputs, labels, task_mask, task_id, weighted_loss_mode, loss_
# loss_mask = None
if loss_mask is None:
ineffective_tokens_per_sample = (labels == -100).sum(dim=1)
- effective_tokens_per_sample = - (ineffective_tokens_per_sample - seq_len)
+ effective_tokens_per_sample = -(ineffective_tokens_per_sample - seq_len)
effective_tokens = bsz * seq_len - ineffective_tokens_per_sample.sum()
- loss_fct = CrossEntropyLoss(reduction='none', ignore_index=-100)
+ loss_fct = CrossEntropyLoss(reduction="none", ignore_index=-100)
else:
loss_mask = loss_mask.to(device=lm_logits.device)
- loss_fct = CrossEntropyLoss(reduction='none')
+ loss_fct = CrossEntropyLoss(reduction="none")
losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)) # [B * L, 1]
losses = losses.contiguous().view(bsz, -1)
- token_losses = losses.clone().detach().float() if loss_mask is None else losses.clone().detach().float() * loss_mask # [B, L]
+ token_losses = (
+ losses.clone().detach().float() if loss_mask is None else losses.clone().detach().float() * loss_mask
+ ) # [B, L]
task_mask_trans = torch.transpose(task_mask, 0, 1)
unique_id = torch.unique(task_id)
if weighted_loss_mode == "case3" or weighted_loss_mode == "case4" or weighted_loss_mode == "selfpaced":
@@ -73,14 +76,22 @@ def loss_func_mft(outputs, labels, task_mask, task_id, weighted_loss_mode, loss_
weights_sum += task_weight
if weighted_loss_mode == "case3" or weighted_loss_mode == "selfpaced":
if loss_mask is None:
- loss += torch.sum(losses[row_idx, :]) / torch.sum(effective_tokens_per_sample[row_idx]) * task_weight
+ loss += (
+ torch.sum(losses[row_idx, :]) / torch.sum(effective_tokens_per_sample[row_idx]) * task_weight
+ )
else:
loss += torch.sum((losses * loss_mask)[row_idx, :]) / torch.sum(loss_mask[row_idx, :]) * task_weight
elif weighted_loss_mode == "case4":
if loss_mask is None:
- loss += torch.mean(torch.sum(losses, dim=1)[row_idx] / effective_tokens_per_sample[row_idx]) * task_weight
+ loss += (
+ torch.mean(torch.sum(losses, dim=1)[row_idx] / effective_tokens_per_sample[row_idx])
+ * task_weight
+ )
else:
- loss += torch.mean(torch.sum(losses * loss_mask, dim=1)[row_idx] / torch.sum(loss_mask, dim=1)[row_idx]) * task_weight
+ loss += (
+ torch.mean(torch.sum(losses * loss_mask, dim=1)[row_idx] / torch.sum(loss_mask, dim=1)[row_idx])
+ * task_weight
+ )
# loss /= len(unique_id)
loss /= weights_sum
@@ -114,19 +125,96 @@ def loss_func_mft(outputs, labels, task_mask, task_id, weighted_loss_mode, loss_
return loss, task_loss, task_num
+def load_balancing_loss_func(
+ gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
+) -> float:
+ r"""
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+ shape [batch_size X sequence_length, num_experts].
+ attention_mask (`torch.Tensor`, None):
+ The attention_mask used in forward function
+ shape [batch_size X sequence_length] if not None.
+ num_experts (`int`, *optional*):
+ Number of experts
+
+ Returns:
+ The auxiliary loss.
+ """
+ if gate_logits is None or not isinstance(gate_logits, tuple):
+ return 0
+
+ if isinstance(gate_logits, tuple):
+ compute_device = gate_logits[0].device
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+ if attention_mask is None:
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
+ else:
+ batch_size, sequence_length = attention_mask.shape
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+ expert_attention_mask = (
+ attention_mask[None, :, :, None, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+ .reshape(-1, top_k, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+ expert_attention_mask, dim=0
+ )
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+ router_per_expert_attention_mask = (
+ attention_mask[None, :, :, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+ .reshape(-1, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+ router_per_expert_attention_mask, dim=0
+ )
+
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ return overall_loss * num_experts
+
+
class MFTLossStatus:
def __init__(self):
super(MFTLossStatus, self).__init__()
class SelfpacedStatus(MFTLossStatus):
- def __init__(self,
- selfpaced_scale_factor=50,
- selfpaced_interval=1,
- selfpaced_history_length=100,
- selfpaced_sample_valid_num=1,
- valid_dataloader=None
- ):
+ def __init__(
+ self,
+ selfpaced_scale_factor=50,
+ selfpaced_interval=1,
+ selfpaced_history_length=100,
+ selfpaced_sample_valid_num=1,
+ valid_dataloader=None,
+ ):
super(SelfpacedStatus, self).__init__()
self.selfpaced_scale_factor = selfpaced_scale_factor
@@ -144,34 +232,37 @@ def selfpaced_evaluate(self, model, v_batch, per_task_weight=None, selfpaced_sta
model.eval()
with torch.no_grad():
valid_outputs = model(
- input_ids=v_batch['input_ids'],
- attention_mask=v_batch['attention_mask'],
- position_ids=v_batch['position_ids']
+ input_ids=v_batch["input_ids"],
+ attention_mask=v_batch["attention_mask"],
+ position_ids=v_batch["position_ids"],
)
_, valid_task_loss, valid_task_num = loss_func_mft(
outputs=valid_outputs,
- labels=v_batch['labels'],
- task_mask=v_batch['task_mask'],
- task_id=v_batch['task_id'],
- weighted_loss_mode='selfpaced',
- loss_mask=v_batch['loss_mask'],
- task_weights=None
+ labels=v_batch["labels"],
+ task_mask=v_batch["task_mask"],
+ task_id=v_batch["task_id"],
+ weighted_loss_mode="selfpaced",
+ loss_mask=v_batch["loss_mask"],
+ task_weights=None,
)
torch.distributed.all_reduce(valid_task_loss, op=torch.distributed.ReduceOp.SUM)
valid_task_loss /= torch.distributed.get_world_size()
model.train()
return valid_task_loss
-
+
def compute_per_task_weight(self, completed_steps=None):
task_slope_fitting = torch.ones(len(ID2TASK))
- history_steps = torch.arange(completed_steps - self.selfpaced_history_length, completed_steps, 1) # DEBUG: step < 0
+ history_steps = torch.arange(
+ completed_steps - self.selfpaced_history_length, completed_steps, 1
+ ) # DEBUG: step < 0
transpose_history_task_valid_loss = self.history_task_valid_loss.transpose(0, 1)
for i in range(len(ID2TASK)):
per_history_task_valid_loss = transpose_history_task_valid_loss[i]
- task_slope_fitting[i] = self.fit_window_point(history_steps, per_history_task_valid_loss,
- history=self.selfpaced_history_length, type='slope')
+ task_slope_fitting[i] = self.fit_window_point(
+ history_steps, per_history_task_valid_loss, history=self.selfpaced_history_length, method="slope"
+ )
slope_sum_abs = torch.sum(torch.abs(task_slope_fitting))
if slope_sum_abs == 0:
@@ -179,15 +270,15 @@ def compute_per_task_weight(self, completed_steps=None):
else:
# print_rank_0(f"[step={completed_steps}][slope sum abs={slope_sum_abs}]")
normalize_slope = len(ID2TASK) * task_slope_fitting / slope_sum_abs
- print_rank_0(f'normalize_slope: {normalize_slope}')
+ print_rank_0(f"normalize_slope: {normalize_slope}")
score = F.softmax(normalize_slope, dim=-1) * (-1 * normalize_slope)
- print_rank_0(f'score: {score}')
+ print_rank_0(f"score: {score}")
per_task_weight = F.softmax(self.selfpaced_scale_factor * score, dim=-1)
- print_rank_0(f'per_task_weight: {per_task_weight}')
-
+ print_rank_0(f"per_task_weight: {per_task_weight}")
+
return per_task_weight
-
- def fit_window_point(self, x, y, history=10, type='slope'):
+
+ def fit_window_point(self, x, y, history=10, method="slope"):
nonzero_index = torch.squeeze(torch.nonzero(y), dim=1)
y = torch.index_select(y, 0, nonzero_index)
@@ -197,11 +288,11 @@ def fit_window_point(self, x, y, history=10, type='slope'):
ws = ws.float()
if len(y) >= 2:
- if type == 'slope':
+ if method == "slope":
X = torch.stack((x, torch.ones_like(x))).T
X = X.float()
else:
- X = torch.stack((x ** 2, x, torch.ones_like(x))).T
+ X = torch.stack((x**2, x, torch.ones_like(x))).T
w = torch.linalg.solve(X.T @ (ws[:, None] * X), X.T @ (ws * y))
result = w[0]
@@ -209,16 +300,24 @@ def fit_window_point(self, x, y, history=10, type='slope'):
result = 0.0
return result
-
+
def sample_valid_batch(self, model, completed_steps):
self.valid_task_loss_accumulated = torch.zeros(len(ID2TASK))
for i in range(self.selfpaced_sample_valid_num):
- if (self.selfpaced_sample_valid_num * completed_steps // self.selfpaced_interval + i) % self.valid_dataloader_length == 0:
+ if (
+ self.selfpaced_sample_valid_num * completed_steps // self.selfpaced_interval + i
+ ) % self.valid_dataloader_length == 0:
self.valid_iterator = iter(self.valid_dataloader)
+
v_batch = next(self.valid_iterator)
valid_task_loss = self.selfpaced_evaluate(model, v_batch)
self.valid_task_loss_accumulated += valid_task_loss.detach().cpu()
+
self.valid_task_loss_accumulated /= self.selfpaced_sample_valid_num
- self.history_task_valid_loss = torch.cat((self.history_task_valid_loss, torch.unsqueeze(self.valid_task_loss_accumulated, dim=0)))
+ self.history_task_valid_loss = torch.cat(
+ (self.history_task_valid_loss, torch.unsqueeze(self.valid_task_loss_accumulated, dim=0))
+ )
if len(self.history_task_valid_loss) > self.selfpaced_history_length:
- self.history_task_valid_loss = self.history_task_valid_loss[len(self.history_task_valid_loss) - self.selfpaced_history_length:]
+ self.history_task_valid_loss = self.history_task_valid_loss[
+ len(self.history_task_valid_loss) - self.selfpaced_history_length :
+ ]
diff --git a/requirements.txt b/requirements.txt
index ee5577f..c6430fc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,13 +1,13 @@
numpy==1.23.5
pandas==1.5.3
-torch==2.0.1
+torch==2.1.0
tensorboard==2.11.0
-deepspeed==0.9.3
-transformers==4.36.0
-accelerate==0.23.0
-peft==0.7.0
-BitsAndBytes==0.40.2
-xformers==0.0.21
+deepspeed==0.14.0
+transformers==4.40.2
+accelerate==0.28.0
+peft==0.10.0
+BitsAndBytes==0.43.0
+xformers==0.0.22.post7
packaging
einops
sentencepiece