
Commit c5b1dd2

add the additional config to the doc
Signed-off-by: ganyi <[email protected]>
1 parent e7d8a01 commit c5b1dd2

File tree

5 files changed: +79 −32 lines changed

docs/source/user_guide/configuration/additional_config.md

Lines changed: 12 additions & 0 deletions
@@ -27,6 +27,7 @@ The following table lists the additional configuration options available in vLLM
 | Name | Type | Default | Description |
 |-------------------------------| ---- |------|-----------------------------------------------------------------------------------------------|
 | `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode |
+| `ascend_compilation_config` | dict | `{}` | The config options for torch.compile |
 | `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
 | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
 | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
@@ -49,6 +50,13 @@ The details of each config option are as follows:
 | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
 | `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
 
+**ascend_compilation_config**
+
+| Name | Type | Default | Description |
+| ---- | ---- | ------- | ----------- |
+| `enable_graph_rewriter` | bool | `True` | Whether to enable the graph rewriter to rewrite the fx graph generated by torch.compile |
+| `enable_quantization_fusion` | bool | `True` | Whether to enable the fusion pass on op + quantize; this should stay enabled by default so all users benefit from the performance boost |
+
 **ascend_scheduler_config**
 
 | Name | Type | Default | Description |
@@ -71,6 +79,10 @@ An example of additional configuration is as follows:
         "enable_multistream_moe": False,
         "enable_kv_nz": False
     },
+    "ascend_compilation_config": {
+        "enable_graph_rewriter": True,
+        "enable_quantization_fusion": True
+    },
     "ascend_scheduler_config": {
         "enabled": True,
         "enable_chunked_prefill": True,

tests/e2e/singlecard/test_graph_rewriter.py

Lines changed: 30 additions & 26 deletions
@@ -21,22 +21,24 @@
 import torch_npu
 import random
 import copy
+from vllm.config import VllmConfig
+from vllm_ascend.compilation.quant_fusion_pass import AscendQuantFusionPass
+from vllm_ascend.compilation.graph_rewrite_pass_manager import GraphRewritePassManager
 from vllm_ascend.quantization.w8a8 import quant_per_tensor
+
 
 class ModelWithRMSNormQuant(nn.Module):
     def __init__(self, hidden_size, eps=1e-6, quant_config=None, prefix=""):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
         self.quant_config = quant_config
         self.prefix = prefix
-        self.former_linear = nn.Linear(hidden_size, hidden_size)  # float
-        self.post_linear = nn.Linear(hidden_size, hidden_size, dtype=torch.int8)  # quantized
-        self.deq_scale = 0.7
+        self.former_linear = nn.Linear(hidden_size, hidden_size)
         self.weight = nn.Parameter(torch.Tensor(hidden_size))
         self.bias = nn.Parameter(torch.Tensor(hidden_size))
-        self.quant_scale = 0.83
-        self.quant_offset = 3
+        self.quant_scale = nn.Parameter(torch.Tensor(hidden_size))
+        self.quant_offset = nn.Parameter(torch.Tensor(hidden_size))
 
     def forward(self, x):
         hidden_states = self.former_linear(x)
@@ -45,35 +47,37 @@ def forward(self, x):
         return quantized_output, residual
 
 
-def custom_graph_rewriter_backend(gm: torch.fx.GraphModule, example_inputs):
-    from torch.fx.subgraph_rewriter import replace_pattern
-    print("before fusion graph:", gm.graph)
-    def pattern(npu_quant_matmul, output_parallel, rms_norm_weight, scale, offset):
-        output = torch.ops.npu_add_rms_norm(npu_quant_matmul, output_parallel, rms_norm_weight, 1e-6)
-        out0 = output[0]
-        out1 = output[2]
-        new_out = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
-        return new_out, out1
-
-    def replace(npu_quant_matmul, output_parallel, rms_norm_weight, scale, offset):
-        output = torch.ops.npu.npu_add_rms_norm_quantize(npu_quant_matmul, output_parallel, rms_norm_weight, scale, offset, epsilon=1e-6)
-        return output[0], output[2]
 
-    replace_pattern(gm, pattern, replace)
-    gm.recompile()
-    print("after fusion graph:", gm.graph)
-    return gm
+class CustomizeCompilationInterface:
+    def __init__(self, vllm_config):
+        self.vllm_config = vllm_config
+        self.graph_rewriter_manager = GraphRewritePassManager()
+        self.graph_rewriter_manager.configure(vllm_config)
+
+    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+        gm = self.graph_rewriter_manager(gm)
+        gm.recompile()
+        return gm
 
 
-def test_graph_rewriter():
+def test_fusion_pass(
+    num_tokens: int = 20,
+    hidden_size: int = 4096,
+):
     # Create a random input tensor
-    num_tokens = 20
-    hidden_size = 4096
     input_tensor = torch.randn(num_tokens, hidden_size)
+    vllm_config = VllmConfig()
+    # Enable the graph rewriter and the quantization fusion pass
+    vllm_config.additional_config.ascend_compilation_config.enable_graph_rewriter = True
+    vllm_config.additional_config.ascend_compilation_config.enable_quantization_fusion = True
+    compilation_interface = CustomizeCompilationInterface(vllm_config)
+    for pass_ in compilation_interface.graph_rewriter_manager.passes:
+        assert isinstance(pass_, AscendQuantFusionPass)
 
     # Initialize the model with RMSNorm quantization
     model = ModelWithRMSNormQuant(hidden_size=hidden_size)
     new_model = copy.deepcopy(model)
-    compiled_model = torch.compile(model, backend=custom_graph_rewriter_backend)
+    compiled_model = torch.compile(model, backend=CustomizeCompilationInterface(vllm_config))
     for i in range(3):
         output = compiled_model(input_tensor)
         # Check if the output is as expected
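
The test above leans on torch.compile's custom-backend hook: any callable that takes a GraphModule plus example inputs and returns a callable can serve as a backend, which is why a CustomizeCompilationInterface instance is passed directly. A self-contained sketch of that pattern with no Ascend dependencies (all names below are illustrative, not part of vllm-ascend):

import torch
import torch.fx


class LoggingBackend:
    """Toy stand-in for CustomizeCompilationInterface: receives the captured
    fx graph, could rewrite it, and returns something callable."""

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        print(f"captured {len(list(gm.graph.nodes))} fx nodes")
        # A real pass manager (e.g. GraphRewritePassManager) would mutate
        # gm.graph here and then call gm.recompile().
        return gm  # a GraphModule is itself callable


@torch.compile(backend=LoggingBackend())
def f(x):
    return torch.relu(x) + 1


print(f(torch.randn(4)))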

tests/ut/test_ascend_config.py

Lines changed: 12 additions & 0 deletions
@@ -54,6 +54,10 @@ def test_init_ascend_config_without_additional_config(self):
         self.assertTrue(torchair_graph_config.enable_view_optimize)
         self.assertFalse(torchair_graph_config.enable_kv_nz)
 
+        ascend_compilation_config = ascend_config.ascend_compilation_config
+        self.assertTrue(ascend_compilation_config.enable_graph_rewriter)
+        self.assertTrue(ascend_compilation_config.enable_quantization_fusion)
+
         ascend_scheduler_config = ascend_config.ascend_scheduler_config
         self.assertFalse(ascend_scheduler_config.enabled)
 
@@ -71,6 +75,10 @@ def test_init_ascend_config_with_additional_config(self):
                 "enable_view_optimize": True,
                 "enable_kv_nz": True
             },
+            "ascend_compilation_config": {
+                "enable_graph_rewriter": False,
+                "enable_quantization_fusion": False,
+            },
             "ascend_scheduler_config": {
                 "enabled": True
             },
@@ -89,6 +97,10 @@
         self.assertTrue(torchair_graph_config.enable_multistream_moe)
         self.assertTrue(torchair_graph_config.enable_view_optimize)
         self.assertTrue(torchair_graph_config.enable_kv_nz)
+        ascend_compilation_config = ascend_config.ascend_compilation_config
+        self.assertFalse(ascend_compilation_config.enable_graph_rewriter)
+        self.assertFalse(ascend_compilation_config.enable_quantization_fusion)
+
         ascend_scheduler_config = ascend_config.ascend_scheduler_config
         self.assertTrue(ascend_scheduler_config.enabled)

vllm_ascend/ascend_config.py

Lines changed: 22 additions & 0 deletions
@@ -39,6 +39,9 @@ def __init__(self, vllm_config):
             {})
         self.torchair_graph_config = TorchairGraphConfig(torchair_graph_config)
 
+        ascend_compilation_config = additional_config.get("ascend_compilation_config", {})
+        self.ascend_compilation_config = AscendCompilationConfig(ascend_compilation_config)
+
         ascend_scheduler_config = additional_config.get(
             "ascend_scheduler_config", {})
         self.ascend_scheduler_config = AscendSchedulerConfig(
@@ -105,6 +108,19 @@ def __init__(self, torchair_graph_config):
                 "enable_kv_nz is valid only when Torchair graph mode is enabled"
             )
 
+class AscendCompilationConfig:
+    """
+    Configuration object for ascend_compilation_config from additional_config
+    """
+
+    def __init__(self, ascend_compilation_config: dict):
+        self.enable_graph_rewriter = ascend_compilation_config.get(
+            "enable_graph_rewriter", True)
+        self.enable_quantization_fusion = ascend_compilation_config.get(
+            "enable_quantization_fusion", True)
+        # Add more compilation related configs here as needed
+
+
 class AscendSchedulerConfig:
     """
@@ -175,6 +191,12 @@ def check_ascend_config(vllm_config, enforce_eager):
                 "it has been disabled automatically.")
         # aclgraph case
         else:
+            # This graph fusion also works in eager mode.
+            if ascend_config.ascend_compilation_config.enable_graph_rewriter:
+                logger.info("Graph rewriter enabled! Automatic kernel fusion is expected.")
+
+            if ascend_config.ascend_compilation_config.enable_quantization_fusion:
+                logger.info("Quantization fusion enabled! Op fusion on quantization is expected.")
             # aclgraph doesn't work with deepseek model and only qwen model is well tested.
             if vllm_config.model_config:
                 model_type = vllm_config.model_config.hf_config.model_type
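
The new AscendCompilationConfig follows the same dict-with-defaults pattern as the other config objects: missing keys fall back to the documented defaults, so both features stay on unless explicitly disabled. A standalone illustration of that behavior (the class name below is a stand-in, not the vllm_ascend API):

class CompilationOptions:
    """Mirror of the dict-with-defaults pattern used by AscendCompilationConfig."""

    def __init__(self, cfg: dict):
        self.enable_graph_rewriter = cfg.get("enable_graph_rewriter", True)
        self.enable_quantization_fusion = cfg.get("enable_quantization_fusion", True)


# An empty additional_config leaves both features enabled...
opts = CompilationOptions({})
assert opts.enable_graph_rewriter and opts.enable_quantization_fusion

# ...while each one can be switched off independently.
opts = CompilationOptions({"enable_quantization_fusion": False})
assert opts.enable_graph_rewriter and not opts.enable_quantization_fusion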

vllm_ascend/compilation/graph_rewrite_pass_manager.py

Lines changed: 3 additions & 6 deletions
@@ -19,11 +19,8 @@
 from torch import fx as fx
 
 from vllm.config import VllmConfig
-from vllm.platforms import current_platform
-from vllm.logger import init_logger
 from vllm.compilation.vllm_inductor_pass import VllmInductorPass
-from vllm.compilation.inductor_pass import get_pass_context, InductorPass
-from quant_fusion_pass import AscendQuantFusionPass
+from vllm.compilation.inductor_pass import get_pass_context
 
 
 class GraphRewritePassManager:
@@ -51,8 +48,8 @@ def add(self, pass_: VllmInductorPass):
         self.passes.append(pass_)
 
     def configure(self, config: VllmConfig):
-        self.pass_config = config.additional_config.ascend_pass_config
-        if self.pass_config.enable_addrms_norm_quant_fusion:
+        self.ascend_compilation_config = config.additional_config.ascend_compilation_config
+        if self.ascend_compilation_config.enable_quantization_fusion:
             from .quant_fusion_pass import AscendQuantFusionPass
             self.passes.append(AscendQuantFusionPass(config))
         # Add more passes here as needed
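
The manager keeps its passes in a plain list and applies them in order to the captured graph. A minimal runnable sketch of that shape using only public torch.fx utilities (TinyPassManager and strip_dead_code are illustrative names, not the vllm-ascend API):

import torch
import torch.fx


class TinyPassManager:
    """List of passes applied sequentially to an fx GraphModule."""

    def __init__(self):
        self.passes = []

    def add(self, pass_):
        self.passes.append(pass_)

    def __call__(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for pass_ in self.passes:
            pass_(gm.graph)  # each pass mutates the graph in place
        gm.recompile()       # regenerate gm's forward from the edited graph
        return gm


def strip_dead_code(graph: torch.fx.Graph):
    graph.eliminate_dead_code()  # built-in fx utility used as a demo pass


def f(x):
    unused = x * 2  # dead value; the pass removes it
    return x + 1


manager = TinyPassManager()
manager.add(strip_dead_code)
gm = manager(torch.fx.symbolic_trace(f))
print(gm.code)  # the `x * 2` node is gone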
