Commit 80afd39 (1 parent: c5b1dd2)

add model test for the graph fusion

Signed-off-by: ganyi <[email protected]>

File tree: 2 files changed, +89 −12 lines

  tests/e2e/singlecard/test_graph_rewriter.py
  vllm_ascend/compilation/quant_fusion_pass.py


tests/e2e/singlecard/test_graph_rewriter.py

Lines changed: 87 additions & 10 deletions
@@ -16,16 +16,22 @@
 # limitations under the License.
 #

+import pytest
 import torch
 import torch.nn as nn
 import torch_npu
 import random
 import copy
 from vllm.config import VllmConfig
+from vllm import LLM, SamplingParams
 from vllm_ascend.compilation.quant_fusion_pass import AscendQuantFusionPass
 from vllm_ascend.compilation.graph_rewrite_pass_manager import GraphRewritePassManager
 from vllm_ascend.quantization.w8a8 import quant_per_tensor
+from tests.e2e.model_utils import check_outputs_equal

+NUM_TOKENS = [4, 32, 57]
+HIDDEN_SIZES = [128, 512, 1024, 2048, 4096]
+MODELS = ["Qwen/Qwen3-30B-A3B"]

 class ModelWithRMSNormQuant(nn.Module):
     def __init__(self, hidden_size, eps=1e-6, quant_config=None, prefix=""):
@@ -49,40 +55,111 @@ def forward(self, x):


 class CustomizeCompilationInterface:
-    def __init__(self, vllm_config):
+    def __init__(self, vllm_config, checking_fusion_pass: dict[str, tuple[list[str], list[str]]] = None):
         self.vllm_config = vllm_config
         self.graph_rewriter_manager = GraphRewritePassManager()
         self.graph_rewriter_manager.configure(vllm_config)
+        self.checking_string_for_fusion_pass = checking_fusion_pass or {}
+
+    def string_checking_for_op_name(self, gm: torch.fx.GraphModule, op_names: list[str]) -> bool:
+        for op_name in op_names:
+            if not any(op_name in str(node.target) for node in gm.graph.nodes):
+                return False
+        return True

     def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+        for pass_name, (op_names, _) in self.checking_string_for_fusion_pass.items():
+            assert self.string_checking_for_op_name(gm, op_names), f"Expected to find {op_names} in the graph before pass {pass_name}, but they were not found."
         gm = self.graph_rewriter_manager(gm)
         gm.recompile()
+        for pass_name, (_, replace_op_names) in self.checking_string_for_fusion_pass.items():
+            assert self.string_checking_for_op_name(gm, replace_op_names), f"Expected to find {replace_op_names} in the graph after pass {pass_name}, but they were not found."
         return gm

-
-def test_fusion_pass(
-    num_tokens: int = 20,
-    hidden_size: int = 4096,
-):
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+def test_quant_fusion_pass(
+    num_tokens: int,
+    hidden_size: int,
+) -> None:
+    checking_string_for_fusion_pass = {
+        "quant_fusion_pass": (["torch.ops.npu.npu_add_rms_norm", "torch.ops.npu_quantize"],
+                              ["torch.ops.npu.npu_add_rms_norm_quant"])
+    }
     # Create a random input tensor
     input_tensor = torch.randn(num_tokens, hidden_size)
     vllm_config = VllmConfig()
     # Open the compilation fusion config and enable the graph rewriter on quantization
     vllm_config.additional_config.ascend_compilation_config.enable_graph_rewriter = True
     vllm_config.additional_config.ascend_compilation_config.enable_quantization_fusion = True
-    compilation_interface = CustomizeCompilationInterface(vllm_config)
+
+    # 1. Check that the pass is added to the pass manager when the related config is enabled
+    compilation_interface = CustomizeCompilationInterface(vllm_config, checking_fusion_pass=checking_string_for_fusion_pass)
+    quant_fusion_pass_found = False
     for pass_ in compilation_interface.graph_rewriter_manager.passes:
-
+        if isinstance(pass_, AscendQuantFusionPass):
+            quant_fusion_pass_found = True
+            break
+    assert quant_fusion_pass_found, "AscendQuantFusionPass not found in the pass manager"

+    # 2. Check that the pass is applied correctly; the checking happens in the `__call__` method of `CustomizeCompilationInterface`
     # Initialize the model with RMSNorm quantization
     model = ModelWithRMSNormQuant(hidden_size=hidden_size)
     new_model = copy.deepcopy(model)
-    compiled_model = torch.compile(model, backend=CustomizeCompilationInterface(vllm_config))
+    compiled_model = torch.compile(model, backend=compilation_interface)
     for i in range(3):
         output = compiled_model(input_tensor)
-    # Check if the output is as expected
+
+    # 3. Check that the output is as expected, using the original model for the reference output
     reference_output = model(input_tensor)
     compiled_output = compiled_model(input_tensor)
     assert torch.allclose(reference_output[0], compiled_output[0]), "Outputs do not match"

     print("Test passed successfully!")
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [32])
+def test_whole_model_with_quant_fusion_pass(
+    model: str,
+    max_tokens: int,
+):
+    prompts = [
+        "Hello, my name is", "The president of the United States is",
+        "The capital of France is", "The future of AI is"
+    ]
+
+    sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
+    # TODO: switch to VllmRunner once the custom-op registration issue
+    # under pytest is resolved
+    vllm_model = LLM(model,
+                     max_model_len=1024,
+                     additional_config={
+                         'ascend_compilation_config': {
+                             'enable_graph_rewriter': True,
+                             'enable_quantization_fusion': True}
+                     })
+
+    vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
+    del vllm_model
+    torch.npu.empty_cache()
+
+    vllm_model = LLM(model, enforce_eager=True, max_model_len=1024)
+    vllm_eager_outputs = vllm_model.generate(prompts, sampling_params)
+    del vllm_model
+    torch.npu.empty_cache()
+
+    vllm_aclgraph_outputs_list = []
+    for output in vllm_aclgraph_outputs:
+        vllm_aclgraph_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    vllm_eager_outputs_list = []
+    for output in vllm_eager_outputs:
+        vllm_eager_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs_list,
+        outputs_1_lst=vllm_aclgraph_outputs_list,
+        name_0="vllm_eager_outputs",
+        name_1="vllm_aclgraph_outputs",
+    )
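Note: `CustomizeCompilationInterface` follows the standard `torch.compile` custom-backend contract, where the backend receives a `torch.fx.GraphModule` plus example inputs and returns a callable. A minimal sketch of that contract, independent of this commit (the `inspecting_backend` name and the toy model are illustrative only, not part of vllm-ascend):

import torch

def inspecting_backend(gm: torch.fx.GraphModule, example_inputs):
    # A real rewriter would transform gm here; this backend only lists the
    # call targets, the same information the op-name assertions above rely on.
    for node in gm.graph.nodes:
        if node.op == "call_function":
            print(node.target)
    # Returning the GraphModule runs the traced graph as-is.
    return gm

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
compiled = torch.compile(model, backend=inspecting_backend)
compiled(torch.randn(2, 8))

Because the backend sees the whole traced graph, substring checks over `node.target` are enough to confirm that a fusion fired.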

vllm_ascend/compilation/quant_fusion_pass.py

Lines changed: 2 additions & 2 deletions
@@ -72,8 +72,8 @@ def __init__(self, vllm_config):

     def __call__(self, graph: torch.fx.Graph):
         self.begin()
-        self.dump_graph(graph, "before_ascend_quant_fusion")
+        self.dump_graph(graph, "before_ascend_quant_fusion_pass")
         for pattern, replace in self.patterns:
             replace_pattern(graph, pattern, replace)
-        self.dump_graph(graph, "after_ascend_quant_fusion")
+        self.dump_graph(graph, "after_ascend_quant_fusion_pass")
         self.end_and_log()
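Note: the `replace_pattern(graph, pattern, replace)` call above is a pattern rewrite in the spirit of `torch.fx.replace_pattern`, which swaps every match of a traced pattern subgraph for a replacement subgraph. A minimal sketch of that mechanism using a generic add + mm to addmm fusion (the toy model and ops are stand-ins; per the test file, the actual pass fuses `npu_add_rms_norm` followed by `npu_quantize` into `npu_add_rms_norm_quant`):

import torch
from torch.fx import symbolic_trace, replace_pattern

class ToyModel(torch.nn.Module):
    def forward(self, bias, a, b):
        return torch.add(bias, torch.mm(a, b))

def pattern(bias, a, b):
    return torch.add(bias, torch.mm(a, b))

def replacement(bias, a, b):
    # One fused kernel computing bias + a @ b.
    return torch.addmm(bias, a, b)

gm = symbolic_trace(ToyModel())
replace_pattern(gm, pattern, replacement)  # rewrites add + mm into addmm
gm.recompile()

bias, a, b = torch.randn(4, 4), torch.randn(4, 4), torch.randn(4, 4)
# The rewritten graph contains the fused op and still matches the original math.
assert any("addmm" in str(node.target) for node in gm.graph.nodes)
assert torch.allclose(gm(bias, a, b), bias + a @ b)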
