
Commit 11d555f

add msprobe support
1 parent 6c130eb commit 11d555f


5 files changed: +160 -1 lines changed

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
# Configuration

Modify the dump_path, level, and other settings in the msprobe_config.json file under the ms-swift directory as needed.

For more options, see the [configuration examples](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_examples.md) and the [configuration file introduction](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_introduct.md).

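For example, a configuration that writes dumps to `./msprobe_dump` and collects only the first training step might look like the following (the `dump_path` and `step` values are illustrative changes; everything else matches the default msprobe_config.json added by this commit):

```json
{
    "task": "statistics",
    "dump_path": "./msprobe_dump",
    "rank": [],
    "step": [0],
    "level": "mix",
    "async_dump": false,
    "statistics": {
        "scope": [],
        "list": [],
        "tensor_list": [],
        "data_mode": ["all"],
        "summary_mode": "statistics"
    }
}
```
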
# Code changes

To support precision debugging with the msprobe tool, the `_patch_word_embeddings` function in `swift/megatron/model/mm_gpt_model.py` needs to be modified. The main change adjusts the function's parameters and internal logic so that the embedding layer is patched correctly.

The concrete changes are shown below.

Before:
```python
def _patch_word_embeddings(self, kwargs):
    origin_forward = VocabParallelEmbedding.forward

    def forward(_self, input_):
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(_self, input_)
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res
    VocabParallelEmbedding.forward = forward
    try:
        yield
    finally:
        VocabParallelEmbedding.forward = origin_forward

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs):
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```

After:
```python
def _patch_word_embeddings(self, kwargs, emb):  # change 1
    origin_forward = emb.word_embeddings.forward  # change 2

    def forward(input_):  # change 3
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        _self = emb.word_embeddings  # change 4
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(input_)  # change 5
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res

    emb.word_embeddings.forward = forward  # change 6
    try:
        yield
    finally:
        emb.word_embeddings.forward = origin_forward  # change 7

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs, self.language_model.embedding):  # change 8
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```

The main changes are:

1. The `_patch_word_embeddings` method gains an `emb` parameter that receives the embedding module instance.
2. `emb.word_embeddings.forward` is captured directly instead of `VocabParallelEmbedding.forward`.
3. The inner `forward` function's signature changes from `(_self, input_)` to `(input_)`.
4. Inside the function, `_self` is obtained from `emb.word_embeddings`.
5. The original forward is called with `input_` alone.
6. Patching and restoring operate on `emb.word_embeddings.forward` (changes 6 and 7).
7. The `self.language_model.embedding` instance is passed in when calling `_patch_word_embeddings`.

In short, the patch now targets a single module instance rather than the whole `VocabParallelEmbedding` class; a minimal sketch of that pattern follows.

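As a standalone illustration of instance-level forward patching (a sketch only; `patch_forward` and the toy `nn.Embedding` below are made up for illustration and are not ms-swift code):

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def patch_forward(module: nn.Module):
    """Temporarily replace `forward` on one module instance, not on its class."""
    origin_forward = module.forward  # bound method of this specific instance

    def forward(input_):
        # Example pre-processing: map negative ids to 0 before the real lookup,
        # mirroring the masked_fill in _patch_word_embeddings.
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        return origin_forward(input_)

    module.forward = forward  # other instances of the same class stay untouched
    try:
        yield
    finally:
        module.forward = origin_forward  # restore the original bound method


emb = nn.Embedding(10, 4)
with patch_forward(emb):
    out = emb(torch.tensor([1, -1, 3]))  # -1 is masked to 0 inside the patch
print(out.shape)  # torch.Size([3, 4])
```

Because only this one instance is touched, other instances of the same class (and anything else that wraps their forwards) are unaffected.
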
# Enabling

Add `--enable_msprobe True` to the launch script.

In addition, because msprobe does not support fused computation, also add `--no_bias_dropout_fusion True`, `--no_bias_swiglu_fusion True`, and `--cross_entropy_loss_fusion False`.

## Example
```shell
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
megatron sft \
    --load Qwen2.5-7B-Instruct-mcore \
    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
              'AI-ModelScope/alpaca-gpt4-data-en#500' \
              'swift/self-cognition#500' \
    --tensor_model_parallel_size 2 \
    ...
    --no_bias_dropout_fusion True \
    --no_bias_swiglu_fusion True \
    --cross_entropy_loss_fusion False \
    --enable_msprobe True
```

msprobe_config.json

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
{
    "task": "statistics",
    "dump_path": "./dump_path",
    "rank": [],
    "step": [],
    "level": "mix",
    "async_dump": false,
    "statistics": {
        "scope": [],
        "list": [],
        "tensor_list": [],
        "data_mode": ["all"],
        "summary_mode": "statistics"
    }
}

requirements/framework.txt

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ importlib_metadata
 jieba
 json_repair
 matplotlib
+mindstudio-probe
 modelscope>=1.23
 nltk
 numpy

swift/megatron/argument/megatron_args.py

Lines changed: 4 additions & 0 deletions
@@ -327,6 +327,10 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
     # qwen3_vl, qwen3_omni
     mrope_interleaved: Optional[bool] = None
 
+    # dump
+    enable_msprobe: bool = False
+    msprobe_config: str = './msprobe_config.json'
+
     @staticmethod
     def load_args_config(ckpt_dir: Optional[str]) -> Dict[str, Any]:
         res = {}
swift/megatron/trainers/base.py

Lines changed: 10 additions & 1 deletion
@@ -503,8 +503,17 @@ def _all_reduce_metric(self,
     def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, *args,
                    **kwargs):
         new_data_iterator = self._replace_data_iterator(data_iterator, model)
-        return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
+        debugger_on = self.args.enable_msprobe
+        if debugger_on:
+            from msprobe.pytorch import PrecisionDebugger
+            debugger = PrecisionDebugger(config=self.args.msprobe_config, model=model)
+            debugger.start()
+        origin_train_step_out = self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
                                        config, *args, **kwargs)
+        if debugger_on:
+            debugger.stop()
+            debugger.step()
+        return origin_train_step_out
 
     # Code borrowed from NVIDIA/Megatron-LM
     def evaluate(
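
Stripped of the Megatron-specific plumbing, the wrapping above amounts to the following pattern (a sketch only, reusing the same PrecisionDebugger calls as the diff; `wrapped_train_step` and `train_step_fn` are placeholder names, not ms-swift APIs):

```python
from msprobe.pytorch import PrecisionDebugger


def wrapped_train_step(train_step_fn, model, enable_msprobe, msprobe_config, *args, **kwargs):
    """Run one training step, optionally bracketed by msprobe precision dumping."""
    debugger = None
    if enable_msprobe:
        debugger = PrecisionDebugger(config=msprobe_config, model=model)
        debugger.start()   # begin collecting data for this step
    out = train_step_fn(*args, **kwargs)
    if debugger is not None:
        debugger.stop()    # stop collecting
        debugger.step()    # advance msprobe's step counter
    return out
```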
