@@ -63,9 +63,7 @@ def generate_inputs_outputs(self) -> Tuple[List[torch.Tensor], List[torch.Tensor
             ref_C_list.append(ref_C)
         return A_list, B_list, C_list, ref_C_list
 
-    def test_grouped_gemm(
-        self, atol=1e-2, rtol=1e-2, check_accuracy=True, check_performance=False
-    ):
+    def test_grouped_gemm(self, atol=1e-2, rtol=1e-2, check_accuracy=True, check_performance=False):
 
         WARM_ITERS = 10
         ITERS = 1000
@@ -91,11 +89,14 @@ def test_grouped_gemm(
                 get_multi_stream_cublas_workspace(),
                 layout=layout,
                 m_splits=self.m_splits,
-                accumulate=self.accumulate
+                accumulate=self.accumulate,
             )
             torch.cuda.synchronize()
 
-            print(f'\n === Accuracy Testing with Layout:{layout} GemmType:{os.getenv("NVTE_USE_CUTLASS_GROUPGEMM", "0")}')
+            print(
+                "\n === Accuracy Testing with"
+                f" Layout:{layout} GemmType:{os.getenv('NVTE_USE_CUTLASS_GROUPGEMM', '0')}"
+            )
             if check_accuracy:
 
                 alpha = 1.0
@@ -140,7 +141,7 @@ def test_grouped_gemm(
                 get_multi_stream_cublas_workspace(),
                 layout=layout,
                 m_splits=self.m_splits,
-                accumulate=self.accumulate
+                accumulate=self.accumulate,
             )
 
             torch.cuda.synchronize()
@@ -154,7 +155,7 @@ def test_grouped_gemm(
                 get_multi_stream_cublas_workspace(),
                 layout=layout,
                 m_splits=self.m_splits,
-                accumulate=self.accumulate
+                accumulate=self.accumulate,
             )
             torch.cuda.synchronize()
             end_time = time.perf_counter()
@@ -193,12 +194,12 @@ def run_grouped_gemm(group_config, check_performance, transa, transb, accumulate
                 [4096, 768, 2048],
                 [4096, 768, 2048],
                 [4096, 768, 2048],
-                [4096, 768, 2048]
+                [4096, 768, 2048],
             ],
             "accumulate": False,
             "check_performance": True,
             "transa": False,
-            "transb": True
+            "transb": True,
         },
         {
             "group_config": [
@@ -217,16 +218,16 @@ def run_grouped_gemm(group_config, check_performance, transa, transb, accumulate
                 [2048, 768, 2048],
                 [2048, 768, 2048],
                 [2048, 768, 2048],
-                [2048, 768, 2048]
+                [2048, 768, 2048],
             ],
             "accumulate": False,
             "check_performance": True,
             "transa": False,
-            "transb": True
-        }
+            "transb": True,
+        },
     ]
 }
-
+
 for i, case in enumerate(config_data["configs"]):
     group_config = [tuple(x) for x in case["group_config"]]
     accumulate = case.get("accumulate", False)