@@ -324,7 +324,7 @@ def test(
         w2_weight_scale.clone().detach(),
     )
 
-    if rank == 0:
+    if args.debug and rank == 0:
         print("=== Check fused weights ===")
         print("w13_f:", w13_f.shape, w13_f.dtype, w13_f.device)
         print("w13s_f:", w13s_f.shape, w13s_f.dtype, w13s_f.device)
@@ -338,7 +338,8 @@ def test(
         start, end = r * experts_per_rank, (r + 1) * experts_per_rank
         tokens_per_rank[r] = ((topk_idx >= start) & (topk_idx < end)).sum()
 
-    print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True)
+    if args.debug:
+        print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True)
 
     # ====== ensure topk_weights is defined (fix missing var) ======
     topk_weights = torch.randn(
@@ -379,15 +380,15 @@ def test(
     if args.topk_drop_col >= 0 and args.topk_drop_col < num_topk:
         topk_idx_dropped[:, args.topk_drop_col] = -1
         topk_weights_dropped[:, args.topk_drop_col] = 0
-
-    print(
-        f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}",
-        flush=True,
-    )
-    print(
-        f"[DEBUG] [rank {rank}] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}",
-        flush=True,
-    )
+    if args.debug:
+        print(
+            f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}",
+            flush=True,
+        )
+        print(
+            f"[DEBUG] [rank {rank}] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}",
+            flush=True,
+        )
 
     # print drop ratio
     drop_ratio = (topk_idx_dropped == -1).float().mean().item()
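For reference, the drop behavior being gated here is easy to reproduce in isolation. The sketch below is a minimal standalone version of the fixed-column drop and the drop-ratio computation; the shapes and the drop_col value are illustrative, not taken from the test.

# Standalone sketch of the fixed-column top-k drop (illustrative shapes).
import torch

num_tokens, num_topk, drop_col = 4, 3, 1
topk_idx = torch.randint(0, 8, (num_tokens, num_topk))  # expert id per (token, slot)
topk_weights = torch.rand(num_tokens, num_topk)         # routing weight per slot

# Dropping a slot means: expert index -1, weight 0, as in the test above.
topk_idx[:, drop_col] = -1
topk_weights[:, drop_col] = 0

# Drop ratio = fraction of (token, slot) pairs that are routed nowhere.
drop_ratio = (topk_idx == -1).float().mean().item()
print(drop_ratio)  # 1 / num_topk for a single fully dropped column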
@@ -407,12 +408,15 @@ def test(
     gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
     dist.all_reduce(gbl_num_tokens_per_expert, group=group)
 
-    print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}")
-    if rank == 0:
-        print(
-            f"[Rank {rank}] gbl_num_tokens_per_expert: {gbl_num_tokens_per_expert.tolist()}"
-        )
-    base_prefix_sum = num_tokens_per_expert.clone()
+
+    if args.debug:
+        print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}")
+        if rank == 0:
+            print(
+                f"[Rank {rank}] gbl_num_tokens_per_expert: {gbl_num_tokens_per_expert.tolist()}"
+            )
+
+    local_expert_token_count = num_tokens_per_expert.clone()
 
     # ----- Baseline -----
     baseline_output, base_ep_recv_count = baseline_test(
@@ -459,22 +463,22 @@ def test(
     assert avg_diff < 1e-4, f"[Rank {rank}] Mismatch detected! diff={avg_diff}"
 
     # ----- Compare Recv Count -----
-    global_base_prefix_sum = [
-        torch.zeros_like(base_prefix_sum) for _ in range(num_ranks)
+    all_expert_token_counts = [
+        torch.zeros_like(local_expert_token_count) for _ in range(num_ranks)
     ]
-    dist.all_gather(global_base_prefix_sum, base_prefix_sum)
+    dist.all_gather(all_expert_token_counts, local_expert_token_count)
 
-    global_base_prefix_sum = torch.stack(global_base_prefix_sum, dim=0)
+    all_expert_token_counts = torch.stack(all_expert_token_counts, dim=0)
 
-    if rank == 0:
+    if args.debug and rank == 0:
         print(
-            f"[DEBUG] Global base_prefix_sum (before transpose):\n{global_base_prefix_sum}"
+            f"[DEBUG] Global local_expert_token_count (before transpose):\n{all_expert_token_counts}"
         )
 
-    transposed_base_prefix_sum = global_base_prefix_sum.T
-    if rank == 0:
-        print(f"[DEBUG] Transposed base_prefix_sum:\n{transposed_base_prefix_sum}")
-        print(f"[DEBUG] Transposed base_prefix_sum: {transposed_base_prefix_sum.shape}")
+    transposed_base_prefix_sum = all_expert_token_counts.T
+    if args.debug and rank == 0:
+        print(f"[DEBUG] Transposed local_expert_token_count:\n{transposed_base_prefix_sum}")
+        print(f"[DEBUG] Transposed local_expert_token_count: {transposed_base_prefix_sum.shape}")
 
     experts_per_rank = num_experts // dist.get_world_size()
     start_expert = rank * experts_per_rank
@@ -484,14 +488,16 @@ def test(
     expected_recv = transposed_base_prefix_sum[start_expert:end_expert].reshape(-1)
     fused_recv = fused_ep_recv_count
 
-    print(f"expected_recv: {expected_recv}")
-    print(f"fused_recv: {fused_recv}")
+    if args.debug:
+        print(f"expected_recv: {expected_recv}")
+        print(f"fused_recv: {fused_recv}")
 
     diff = (expected_recv - fused_recv).abs()
-    print(
-        f"[Rank {rank}] diff (experts {start_expert}~{end_expert - 1}): {diff.cpu().numpy()}",
-        flush=True,
-    )
+    if args.debug:
+        print(
+            f"[Rank {rank}] diff (experts {start_expert}~{end_expert - 1}): {diff.cpu().numpy()}",
+            flush=True,
+        )
 
     max_recv_count_diff = diff.max().item()
     mean_recv_count_diff = diff.mean().item()
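The count comparison above rests on a simple identity: after all_gather and stack, every rank holds the same (num_ranks, num_experts) matrix of per-rank, per-expert token counts, and its transpose has one row per expert with one column per source rank. Slicing out the rows owned by the current rank and flattening yields the expected receive counts. A self-contained sketch with made-up counts and no process group:

# counts[r][e] = tokens rank r routes to global expert e (made-up numbers).
import torch

num_ranks, num_experts = 2, 4
experts_per_rank = num_experts // num_ranks
counts = torch.tensor([[3, 0, 2, 1],
                       [1, 4, 0, 2]])

transposed = counts.T  # rows: experts, columns: source ranks

for rank in range(num_ranks):
    start = rank * experts_per_rank
    end = start + experts_per_rank
    # One entry per (local expert, source rank), the layout compared
    # against fused_ep_recv_count in the test above.
    expected_recv = transposed[start:end].reshape(-1)
    print(rank, expected_recv.tolist())
# rank 0 owns experts 0-1 -> [3, 1, 0, 4]; rank 1 owns experts 2-3 -> [2, 0, 1, 2]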
@@ -597,6 +603,12 @@ def str_to_bool(value):
     default=-1,
     help="If >=0, drop this specific top-k column (set index to -1 for testing).",
 )
+parser.add_argument(
+    "--debug",
+    action="store_true",
+    default=False,
+    help="Enable debug logging.",
+)
 
 args = parser.parse_args()
 num_processes = args.num_processes
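Since --debug is a standard store_true switch, all of the prints above are now off by default and opt-in per run. A quick standalone check of the parsing behavior (the parser is re-created here just for illustration):

# Verify the flag's default-off / opt-in behavior in isolation.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--debug",
    action="store_true",
    default=False,
    help="Enable debug logging.",
)

assert parser.parse_args([]).debug is False          # off unless requested
assert parser.parse_args(["--debug"]).debug is True  # enabled explicitly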