 )


+def _get_number_of_gpu_sm() -> int:
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available")
+    device_props = torch.cuda.get_device_properties(0)
+    return device_props.multi_processor_count
+
+
 def _str_1d_tensor(t: torch.Tensor) -> str:
     sl = [f"{x:7.4f}" for x in t.tolist()]
     if len(sl) > 5:
@@ -48,6 +55,7 @@ def _str_1d_tensor(t: torch.Tensor) -> str:
 def _do_test_all_to_all(
     pgi: ProcessGroupInfo,
     dp_size: int,
+    max_sm_count: int,
     moe: MoEConfig,
     internode: bool,
 ) -> None:
@@ -79,6 +87,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )
     else:
         ata = AllToAll.intranode(
@@ -99,6 +108,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )

     # Generate the same test data on all ranks
@@ -283,6 +293,7 @@ def _worker_test_all_to_all(
     dp_size: int,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
     moe_config: MoEConfig,
     internode: bool,
 ) -> None:
@@ -295,16 +306,21 @@ def _worker_test_all_to_all(
         in_dtype=getattr(torch, in_dtype),
         out_dtype=getattr(torch, out_dtype),
     )
-    _do_test_all_to_all(pgi, dp_size, moe_config, internode)
+    _do_test_all_to_all(pgi, dp_size, max_sm_count, moe_config, internode)

     nvshmem_finalize()


 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
+@pytest.mark.parametrize(
+    "max_sm_count", [_get_number_of_gpu_sm(), _get_number_of_gpu_sm() // 2]
+)
 @pytest.mark.parametrize("internode", [True, False])
-def test_all_to_all_4_gpu(in_dtype: str, out_dtype: str, internode: bool) -> None:
+def test_all_to_all_4_gpu(
+    in_dtype: str, out_dtype: str, max_sm_count: int, internode: bool
+) -> None:
     world_size = 4
     dp_size = 2
     parallel_launch(
@@ -313,6 +329,7 @@ def test_all_to_all_4_gpu(in_dtype: str, out_dtype: str, internode: bool) -> Non
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         small_moe,
         internode,
     )
@@ -322,13 +339,15 @@ def _worker_test_all_to_all_multi_node(
     pgi: ProcessGroupInfo,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
 ) -> None:
     dp_size = 4
     _worker_test_all_to_all(
         pgi,
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         medium_moe,
         True,
     )
@@ -338,4 +357,7 @@ def _worker_test_all_to_all_multi_node(
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
 def test_all_to_all_multi_node(in_dtype: str, out_dtype: str) -> None:
-    parallel_launch_from_env(_worker_test_all_to_all_multi_node, in_dtype, out_dtype)
+    max_sm_count = _get_number_of_gpu_sm()
+    parallel_launch_from_env(
+        _worker_test_all_to_all_multi_node, in_dtype, out_dtype, max_sm_count
+    )
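For reference, a minimal standalone sketch of what the new `max_sm_count` parametrization evaluates to on a given machine, assuming PyTorch with a visible CUDA device; device index 0 and the SM count in the comment are illustrative.

```python
import torch

# Mirrors the _get_number_of_gpu_sm() helper added in this diff: query the
# streaming-multiprocessor count of GPU 0 and derive the two values the test
# parametrizes over (the full SM count and half of it).
if torch.cuda.is_available():
    sm_count = torch.cuda.get_device_properties(0).multi_processor_count
    print(f"max_sm_count candidates: {sm_count}, {sm_count // 2}")
    # e.g. on an A100 (108 SMs) this prints "max_sm_count candidates: 108, 54"
```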