1616# limitations under the License.
1717#
1818
19+ import pytest
1920import torch
2021import torch .nn as nn
2122import torch_npu
2223import random
2324import copy
2425from vllm .config import VllmConfig
26+ from vllm import LLM , SamplingParams
2527from vllm_ascend .compilation .quant_fusion_pass import AscendQuantFusionPass
2628from vllm_ascend .compilation .graph_rewrite_pass_manager import GraphRewritePassManager
2729from vllm_ascend .quantization .w8a8 import quant_per_tensor
30+ from tests .e2e .model_utils import check_outputs_equal
2831
32+ NUM_TOKENS = [4 , 32 , 57 ]
33+ HIDDEN_SIZES = [128 , 512 , 1024 , 2048 , 4096 ]
34+ MODELS = ["Qwen/Qwen3-30B-A3B" ]
2935
3036class ModelWithRMSNormQuant (nn .Module ):
3137 def __init__ (self , hidden_size , eps = 1e-6 , quant_config = None , prefix = "" ):
@@ -49,40 +55,111 @@ def forward(self, x):
4955
5056
5157class CustomizeCompilationInterface :
52- def __init__ (self , vllm_config ):
58+ def __init__ (self , vllm_config , checking_fusion_pass : str = "torch.ops" ):
5359 self .vllm_config = vllm_config
5460 self .graph_rewriter_manager = GraphRewritePassManager ()
5561 self .graph_rewriter_manager .configure (vllm_config )
62+ self .checking_string_for_fusion_pass = checking_fusion_pass
63+
64+ def string_checking_for_op_name (self , gm : torch .fx .GraphModule , op_names : list [str ]) -> bool :
65+ for op_name in op_names :
66+ if not any (op_name in node .target for node in gm .graph .nodes ):
67+ return False
68+ return True
5669
5770 def __call__ (self , gm : torch .fx .GraphModule , example_inputs ):
71+ for pass_name , (op_names , _ ) in self .checking_string_for_fusion_pass .items ():
72+ assert self .string_checking_for_op_name (gm , op_names ), f"Expected to find { op_names } in the graph, but not found."
5873 gm = self .graph_rewriter_manager (gm )
5974 gm .recompile ()
75+ for pass_name , (_ , replace_op_names ) in self .checking_string_for_fusion_pass .items ():
76+ assert self .string_checking_for_op_name (gm , replace_op_names ), f"Expected to find { replace_op_names } in the graph after pass { pass_name } , but not found."
6077 return gm
6178
62-
63- def test_fusion_pass (
64- num_tokens : int = 20 ,
65- hidden_size : int = 4096 ,
66- ):
79+ @pytest .mark .parametrize ("num_tokens" , NUM_TOKENS )
80+ @pytest .mark .parametrize ("hidden_size" , HIDDEN_SIZES )
81+ def test_quant_fusion_pass (
82+ num_tokens : int ,
83+ hidden_size : int ,
84+ ) -> None :
85+ checking_string_for_fusion_pass = {
86+ "quant_fusion_pass" : (["torch.ops.npu.npu_add_rms_norm" , "torch.ops.npu_quantize" ],["torch.ops.npu.npu_add_rms_norm_quant" ])
87+ }
6788 # Create a random input tensor
6889 input_tensor = torch .randn (num_tokens , hidden_size )
6990 vllm_config = VllmConfig ()
7091 # Open the compilation fusion config and enable the graph rewriter on quantization
7192 vllm_config .additional_config .ascend_compilation_config .enable_graph_rewriter = True
7293 vllm_config .additional_config .ascend_compilation_config .enable_quantization_fusion = True
73- compilation_interface = CustomizeCompilationInterface (vllm_config )
94+
95+ # 1. Checking if the pass is added to the pass manager when related config is enabled
96+ compilation_interface = CustomizeCompilationInterface (vllm_config , checking_fusion_pass = checking_string_for_fusion_pass )
97+ quant_fusion_pass_found = False
7498 for pass_ in compilation_interface .graph_rewriter_manager .passes :
75-
99+ if isinstance (pass_ , AscendQuantFusionPass ):
100+ quant_fusion_pass_found = True
101+ break
102+ assert quant_fusion_pass_found , "AscendQuantFusionPass not found in the pass manager"
76103
104+ # 2, Check if the pass is applied correctly,the checking process happens in the `__call__` method of `CustomizeCompilationInterface`
77105 # Initialize the model with RMSNorm quantization
78106 model = ModelWithRMSNormQuant (hidden_size = hidden_size )
79107 new_model = copy .deepcopy (model )
80- compiled_model = torch .compile (model , backend = CustomizeCompilationInterface ( vllm_config ) )
108+ compiled_model = torch .compile (model , backend = compilation_interface )
81109 for i in range (3 ):
82110 output = compiled_model (input_tensor )
83- # Check if the output is as expected
111+
112+ # 3. Check if the output is as expected, we use the original model to get the reference output
84113 reference_output = model (input_tensor )
85114 compiled_output = compiled_model (input_tensor )
86115 assert torch .allclose (reference_output [0 ], compiled_output [0 ]), "Outputs do not match"
87116
88117 print ("Test passed successfully!" )
118+
119+ @pytest .mark .parametrize ("model" , MODELS )
120+ @pytest .mark .parametrize ("max_tokens" , [32 ])
121+ def test_whole_model_with_quant_fusion_pass (
122+ model : str ,
123+ max_tokens : int ,
124+ ):
125+ prompts = [
126+ "Hello, my name is" , "The president of the United States is" ,
127+ "The capital of France is" , "The future of AI is"
128+ ]
129+
130+ sampling_params = SamplingParams (max_tokens = max_tokens , temperature = 0.0 )
131+ # TODO: change to use vllmrunner when the registry of custom op is solved
132+ # while running pytest
133+ vllm_model = LLM (model ,
134+ max_model_len = 1024 ,
135+ additional_config = {
136+ 'ascend_compilation_config' : {
137+ 'enable_graph_rewriter' : True ,
138+ 'enable_quantization_fusion' : True }
139+ })
140+
141+ vllm_aclgraph_outputs = vllm_model .generate (prompts , sampling_params )
142+ del vllm_model
143+ torch .npu .empty_cache ()
144+
145+ vllm_model = LLM (model , enforce_eager = True , max_model_len = 1024 )
146+ vllm_eager_outputs = vllm_model .generate (prompts , sampling_params )
147+ del vllm_model
148+ torch .npu .empty_cache ()
149+
150+ vllm_aclgraph_outputs_list = []
151+ for output in vllm_aclgraph_outputs :
152+ vllm_aclgraph_outputs_list .append (
153+ (output .outputs [0 ].index , output .outputs [0 ].text ))
154+
155+ vllm_eager_outputs_list = []
156+ for output in vllm_eager_outputs :
157+ vllm_eager_outputs_list .append (
158+ (output .outputs [0 ].index , output .outputs [0 ].text ))
159+
160+ check_outputs_equal (
161+ outputs_0_lst = vllm_eager_outputs_list ,
162+ outputs_1_lst = vllm_aclgraph_outputs_list ,
163+ name_0 = "vllm_eager_outputs" ,
164+ name_1 = "vllm_aclgraph_outputs" ,
165+ )
0 commit comments