@@ -58,7 +58,7 @@ def __init__(
5858 hidden_dim = 7168 ,
5959 scale_dim = 32 ,
6060 scale_type_size = 4 ,
61- max_num_inp_token_per_rank = max_tokens ,
61+ max_num_inp_token_per_rank = ( max_tokens + 63 ) // 64 * 64 ,
6262 num_experts_per_rank = 16 ,
6363 num_experts_per_token = 8 ,
6464 warp_num_per_block = 8 ,
@@ -129,16 +129,16 @@ def _allgather_with_token_num_padding(self, input, max_token_num):
129129 dist .all_gather (output , padded_input )
130130 return output
131131
132- def gen_test_data (self , use_max_token_num = False ):
132+ def gen_test_data (self , max_num_token , use_max_token_num = False ):
133133 # gen num_tokens
134134 if use_max_token_num :
135135 num_token = torch .tensor (
136- [self . config . max_num_inp_token_per_rank for i in range (self .world_size )]
136+ [max_num_token for i in range (self .world_size )]
137137 ).to (self .device )
138138 else :
139139 num_token = torch .randint (
140140 1 ,
141- self . config . max_num_inp_token_per_rank + 1 ,
141+ max_num_token + 1 ,
142142 [self .world_size ],
143143 generator = self .rng ,
144144 device = self .device ,
@@ -158,19 +158,21 @@ def gen_test_data(self, use_max_token_num=False):
158158 device = self .device ,
159159 )
160160 # argsort gives us a random permutation, take first K columns
161- indices = torch .argsort (random_vals , dim = 1 )[:, : self .config .num_experts_per_token ]
161+ indices = torch .argsort (random_vals , dim = 1 )[
162+ :, : self .config .num_experts_per_token
163+ ]
162164 all_rank_indices .append (indices .to (torch .int32 ))
163165
164166 # num_total_experts = self.config.num_experts_per_rank * self.config.world_size
165167 # num_nodes = self.config.world_size // self.config.gpu_per_node
166168
167169 # even_indices = (
168170 # torch.arange(
169- # self.config.max_num_inp_token_per_rank
171+ # max_num_token
170172 # * self.config.num_experts_per_token,
171173 # device="cuda",
172174 # ).view(
173- # self.config.max_num_inp_token_per_rank ,
175+ # max_num_token ,
174176 # self.config.num_experts_per_token,
175177 # )
176178 # % 256
@@ -420,7 +422,10 @@ def test_dispatch_combine(self):
420422 for i in range (5000 ):
421423 if self .rank == 0 :
422424 print (f"Round { i } begin" )
423- test_data = self .gen_test_data (use_max_token_num = False )
425+ test_data = self .gen_test_data (
426+ max_num_token = self .config .max_num_inp_token_per_rank ,
427+ use_max_token_num = False ,
428+ )
424429 if self .rank == 0 :
425430 print (f"Round { i } gen test_data done" )
426431 self .run_test_once (op , test_data , error_round , i )
@@ -443,7 +448,11 @@ def stress_dispatch_combine(self):
443448 if self .rank == 0 :
444449 print ("Stress Test" )
445450 test_data_list = [
446- self .gen_test_data (use_max_token_num = False ) for i in range (num_test_data )
451+ self .gen_test_data (
452+ max_num_token = self .config .max_num_inp_token_per_rank ,
453+ use_max_token_num = False ,
454+ )
455+ for i in range (num_test_data )
447456 ]
448457 for i in tqdm (range (5000 )):
449458 (
@@ -480,7 +489,10 @@ def stress_dispatch_combine(self):
480489
481490 if self .rank == 0 :
482491 print ("Stress Test with CUDA Graph" )
483- test_data = self .gen_test_data (use_max_token_num = False )
492+ test_data = self .gen_test_data (
493+ max_num_token = self .config .max_num_inp_token_per_rank ,
494+ use_max_token_num = False ,
495+ )
484496 (
485497 all_rank_num_token ,
486498 all_rank_indices ,
@@ -520,7 +532,7 @@ def stress_dispatch_combine(self):
520532
521533 del op
522534
523- def run_bench_once (self , op , test_data , repeat = 10 ):
535+ def run_bench_once (self , max_num_token , op , test_data , repeat = 10 ):
524536 num_events = 2 * repeat + 1
525537 events = [torch .cuda .Event (enable_timing = True ) for i in range (num_events )]
526538
@@ -559,9 +571,7 @@ def run_bench_once(self, op, test_data, repeat=10):
559571 )
560572 torch .cuda .synchronize ()
561573
562- total_rdma_recv_num_token = (
563- self .config .max_num_inp_token_per_rank * self .config .world_size // 8
564- )
574+ total_rdma_recv_num_token = max_num_token * self .config .world_size // 8
565575 print (
566576 f"rank { self .rank } recv { total_recv_num_token } tokens { total_rdma_recv_num_token } rdma tokens"
567577 )
@@ -598,7 +608,7 @@ def run_bench_once(self, op, test_data, repeat=10):
598608 element_size = all_rank_input [self .rank ].element_size ()
599609 total_bytes = total_recv_num_token * self .config .hidden_dim * element_size
600610 ll_mode_scale = (
601- self . config . max_num_inp_token_per_rank
611+ max_num_token
602612 * self .config .num_experts_per_token
603613 / (total_recv_num_token + 1 ) # avoid division by zero
604614 )
@@ -635,9 +645,11 @@ def run_bench_once(self, op, test_data, repeat=10):
635645 ll_mode_scale ,
636646 )
637647
638- def bench_dispatch_combine (self ):
648+ def bench_dispatch_combine (self , max_num_token ):
639649 op = mori .ops .EpDispatchCombineOp (self .config )
640- test_data = self .gen_test_data (use_max_token_num = True )
650+ test_data = self .gen_test_data (
651+ max_num_token = max_num_token , use_max_token_num = True
652+ )
641653
642654 repeat = 50
643655 disp_duration_us_list = []
@@ -664,7 +676,7 @@ def bench_dispatch_combine(self):
664676 comb_rdma_bandwidth ,
665677 comb_bandwidth ,
666678 ll_mode_scale ,
667- ) = self .run_bench_once (op , test_data , repeat )
679+ ) = self .run_bench_once (max_num_token , op , test_data , repeat )
668680
669681 for i in range (repeat ):
670682 disp_duration_output = [torch .zeros (1 ) for _ in range (self .world_size )]
@@ -821,14 +833,29 @@ def collect_metrics(per_round_data):
821833
822834 del op
823835
836+ return (disp_bw , disp_rdma_bw , disp_ll_bw , disp_lat ), (
837+ comb_bw ,
838+ comb_rdma_bw ,
839+ comb_ll_bw ,
840+ comb_lat ,
841+ )
824842
825- def test_dispatch_combine (
826- local_rank , num_node , gpu_per_node , max_tokens , kernel_type , num_qp , cmd = "test"
843+
def sweep_bench_dispatch_combine(
    local_rank,
    num_node,
    gpu_per_node,
    max_tokens,
    kernel_type,
    num_qp,
    sweep_token_interval,
):
    """Benchmark dispatch/combine across a sweep of per-rank token counts.

    One ``EpDispatchCombineTestCase`` is built for ``max_tokens`` and then
    ``bench_dispatch_combine`` is run once per sweep point. Sweep points are
    ``i * sweep_token_interval`` for ``i`` in ``range(ceil(max_tokens /
    sweep_token_interval))``; the 0-token point is benchmarked as 1 token.
    On every process whose ``local_rank`` is 0, the per-point latency spread
    is plotted and written to ``dispatch_combine_perf_maxmin.png`` in the
    current working directory.

    Args:
        local_rank: Index of this process on its node.
        num_node: Number of nodes taking part.
        gpu_per_node: Number of processes (GPUs) per node.
        max_tokens: Upper bound of the sweep; also passed to the test case.
        kernel_type: Forwarded unchanged to ``EpDispatchCombineTestCase``.
        num_qp: Forwarded unchanged to ``EpDispatchCombineTestCase``.
        sweep_token_interval: Step between sweep points; coerced with ``int``.

    Raises:
        KeyError: If the ``RANK`` environment variable is not set.
        ValueError: If ``sweep_token_interval`` is not a positive integer.
    """
    world_size = num_node * gpu_per_node
    node_rank = int(os.environ["RANK"])
    global_rank = node_rank * gpu_per_node + local_rank
    sweep_token_interval = int(sweep_token_interval)
    if sweep_token_interval <= 0:
        raise ValueError(
            f"sweep_token_interval must be >= 1, got {sweep_token_interval}"
        )
    test_case = EpDispatchCombineTestCase(
        global_rank,
        gpu_per_node,
        world_size,
        max_tokens,
        kernel_type,
        num_qp,
        torch.bfloat16,
        # torch.float8_e4m3fnuz,
    )
    test_case.setup()

    num_iters = (max_tokens + sweep_token_interval - 1) // sweep_token_interval
    # Clamp the 0-token point to 1 in the list itself so the x values used for
    # plotting are exactly the token counts that were benchmarked.
    max_token_list = [max(1, i * sweep_token_interval) for i in range(num_iters)]

    disp_lat_min_list = []
    disp_lat_max_list = []
    comb_lat_min_list = []
    comb_lat_max_list = []
    for max_token in max_token_list:
        disp_stats, comb_stats = test_case.bench_dispatch_combine(max_token)
        # Only the latency entry (4th element) of each stats tuple is used.
        # NOTE(review): indices 0 and 1 of that entry are treated as min and
        # max latency -- confirm against bench_dispatch_combine's return value.
        disp_lat = disp_stats[3]
        comb_lat = comb_stats[3]

        disp_lat_min_list.append(disp_lat[0])
        comb_lat_min_list.append(comb_lat[0])
        disp_lat_max_list.append(disp_lat[1])
        comb_lat_max_list.append(comb_lat[1])

    if local_rank == 0:
        # Imported lazily so only the plotting processes need matplotlib.
        import matplotlib.pyplot as plt

        disp_spread = [
            hi - lo for hi, lo in zip(disp_lat_max_list, disp_lat_min_list)
        ]
        comb_spread = [
            hi - lo for hi, lo in zip(comb_lat_max_list, comb_lat_min_list)
        ]

        plt.figure()
        plt.plot(max_token_list, disp_spread, label="Dispatch Max-Min")
        plt.plot(max_token_list, comb_spread, label="Combine Max-Min")
        plt.xticks([i * 16 for i in range(max_tokens // 16)])
        plt.title("Dispatch / Combine Max-Min Latency (us)")
        plt.xlabel("# of Tokens")
        plt.ylabel("Latency (us)")
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig("dispatch_combine_perf_maxmin.png", dpi=300, bbox_inches="tight")
    test_case.cleanup()
917+
918+
def test_dispatch_combine(
    local_rank,
    num_node,
    gpu_per_node,
    max_tokens,
    kernel_type,
    num_qp,
    cmd="test",
    sweep_token_interval=64,
):
    """Per-process entry point that runs the sub-command selected by ``cmd``.

    ``"test"``, ``"bench"`` and ``"stress"`` build an
    ``EpDispatchCombineTestCase``, call ``setup()``, run the matching method
    and then call ``cleanup()``. ``"sweep_bench"`` hands everything to
    ``sweep_bench_dispatch_combine``, which manages its own test case.

    Args:
        local_rank: Index of this process on its node.
        num_node: Number of nodes taking part.
        gpu_per_node: Number of processes (GPUs) per node.
        max_tokens: Maximum number of input tokens per rank.
        kernel_type: Forwarded unchanged to the test case.
        num_qp: Forwarded unchanged to the test case.
        cmd: One of ``"test"``, ``"bench"``, ``"stress"``, ``"sweep_bench"``.
        sweep_token_interval: Sweep step; only used by ``"sweep_bench"``.

    Raises:
        KeyError: If the ``RANK`` environment variable is not set.
        ValueError: If ``cmd`` is not a supported command.
    """
    world_size = num_node * gpu_per_node
    node_rank = int(os.environ["RANK"])
    global_rank = node_rank * gpu_per_node + local_rank

    # The sweep path owns its test case, so delegate and finish early.
    if cmd == "sweep_bench":
        sweep_bench_dispatch_combine(
            local_rank,
            num_node,
            gpu_per_node,
            max_tokens,
            kernel_type,
            num_qp,
            sweep_token_interval,
        )
        return

    # Reject unknown commands before any test-case resources are created.
    if cmd not in ("test", "bench", "stress"):
        raise ValueError(f"unsupported command: {cmd}")

    test_case = EpDispatchCombineTestCase(
        global_rank,
        gpu_per_node,
        world_size,
        max_tokens,
        kernel_type,
        num_qp,
        torch.bfloat16,
        # torch.float8_e4m3fnuz,
    )
    test_case.setup()
    if cmd == "bench":
        test_case.bench_dispatch_combine(max_tokens)
    elif cmd == "stress":
        test_case.stress_dispatch_combine()
    else:
        test_case.test_dispatch_combine()
    test_case.cleanup()
851964
852- test_case .cleanup ()
853-
854965
855966parser = argparse .ArgumentParser (description = "dispatch/combine internode test" )
856967parser .add_argument (
857968 "--cmd" ,
858969 type = str ,
859970 default = "test" ,
860- choices = ["test" , "bench" , "stress" ],
861- help = "Available subcommands: test, bench, stress" ,
971+ choices = ["test" , "bench" , "stress" , "sweep_bench" ],
972+ help = "Available subcommands: test, bench, stress, sweep_bench " ,
862973)
863974parser .add_argument (
864975 "--max-tokens" ,
865976 type = int ,
866977 default = 4096 ,
867978 help = "Maximum number of input tokens per rank (default: 4096)" ,
868979)
# Step between successive token counts benchmarked by the sweep_bench command.
parser.add_argument(
    "--sweep-token-interval",
    type=int,
    default=2,
    help="Token-count step between sweep_bench data points (default: 2)",
)
869986parser .add_argument (
870987 "--kernel-type" ,
871988 type = str ,
@@ -896,6 +1013,7 @@ def test_dispatch_combine(
8961013 args_cli .kernel_type ,
8971014 args_cli .num_qp ,
8981015 args_cli .cmd ,
1016+ args_cli .sweep_token_interval ,
8991017 ),
9001018 nprocs = gpu_per_node ,
9011019 join = True ,
0 commit comments