
Commit bf3f45e

Author: Kaniel_Zhou
Commit message: fix cleancode
1 parent 5bb7c55 commit bf3f45e


tests/python/deepep/test_fused_deep_moe.py

Lines changed: 45 additions & 33 deletions
@@ -324,7 +324,7 @@ def test(
             w2_weight_scale.clone().detach(),
         )
 
-    if rank == 0:
+    if args.debug and rank == 0:
         print("=== Check fused weights ===")
         print("w13_f:", w13_f.shape, w13_f.dtype, w13_f.device)
         print("w13s_f:", w13s_f.shape, w13s_f.dtype, w13s_f.device)
@@ -338,7 +338,8 @@ def test(
         start, end = r * experts_per_rank, (r + 1) * experts_per_rank
         tokens_per_rank[r] = ((topk_idx >= start) & (topk_idx < end)).sum()
 
-    print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True)
+    if args.debug:
+        print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True)
 
     # ====== ensure topk_weights is defined (fix missing var) ======
     topk_weights = torch.randn(
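
For context, the print gated here reports how many (token, top-k slot) pairs fall into each rank's contiguous expert range. A minimal single-process sketch of that counting logic; the sizes below are made-up assumptions, not the test's real configuration:

# Standalone sketch of the per-rank token counting used in the hunk above.
# num_ranks, experts_per_rank, num_tokens, num_topk are illustrative values.
import torch

num_ranks, experts_per_rank, num_tokens, num_topk = 2, 4, 8, 3
num_experts = num_ranks * experts_per_rank

# Random top-k expert assignments, one row per token.
topk_idx = torch.randint(0, num_experts, (num_tokens, num_topk))

tokens_per_rank = torch.zeros(num_ranks, dtype=torch.long)
for r in range(num_ranks):
    start, end = r * experts_per_rank, (r + 1) * experts_per_rank
    # Count every (token, slot) pair routed to an expert owned by rank r.
    tokens_per_rank[r] = ((topk_idx >= start) & (topk_idx < end)).sum()

print(tokens_per_rank)  # sums to num_tokens * num_topk
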
@@ -379,15 +380,15 @@ def test(
     if args.topk_drop_col >= 0 and args.topk_drop_col < num_topk:
         topk_idx_dropped[:, args.topk_drop_col] = -1
         topk_weights_dropped[:, args.topk_drop_col] = 0
-
-    print(
-        f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}",
-        flush=True,
-    )
-    print(
-        f"[DEBUG] [rank {rank}] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}",
-        flush=True,
-    )
+    if args.debug:
+        print(
+            f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}",
+            flush=True,
+        )
+        print(
+            f"[DEBUG] [rank {rank}] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}",
+            flush=True,
+        )
 
     # print drop ratio
     drop_ratio = (topk_idx_dropped == -1).float().mean().item()
@@ -407,12 +408,15 @@ def test(
     gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
     dist.all_reduce(gbl_num_tokens_per_expert, group=group)
 
-    print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}")
-    if rank == 0:
-        print(
-            f"[Rank {rank}] gbl_num_tokens_per_expert: {gbl_num_tokens_per_expert.tolist()}"
-        )
-    base_prefix_sum = num_tokens_per_expert.clone()
+
+    if args.debug:
+        print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}")
+        if rank == 0:
+            print(
+                f"[Rank {rank}] gbl_num_tokens_per_expert: {gbl_num_tokens_per_expert.tolist()}"
+            )
+
+    local_expert_token_count = num_tokens_per_expert.clone()
 
     # ----- Baseline -----
     baseline_output, base_ep_recv_count = baseline_test(
@@ -459,22 +463,22 @@ def test(
     assert avg_diff < 1e-4, f"[Rank {rank}] Mismatch detected! diff={avg_diff}"
 
     # ----- Compare Recv Count -----
-    global_base_prefix_sum = [
-        torch.zeros_like(base_prefix_sum) for _ in range(num_ranks)
+    all_expert_token_counts = [
+        torch.zeros_like(local_expert_token_count) for _ in range(num_ranks)
     ]
-    dist.all_gather(global_base_prefix_sum, base_prefix_sum)
+    dist.all_gather(all_expert_token_counts, local_expert_token_count)
 
-    global_base_prefix_sum = torch.stack(global_base_prefix_sum, dim=0)
+    all_expert_token_counts = torch.stack(all_expert_token_counts, dim=0)
 
-    if rank == 0:
+    if args.debug and rank == 0:
         print(
-            f"[DEBUG] Global base_prefix_sum (before transpose):\n{global_base_prefix_sum}"
+            f"[DEBUG] Global local_expert_token_count (before transpose):\n{all_expert_token_counts}"
         )
 
-    transposed_base_prefix_sum = global_base_prefix_sum.T
-    if rank == 0:
-        print(f"[DEBUG] Transposed base_prefix_sum:\n{transposed_base_prefix_sum}")
-        print(f"[DEBUG] Transposed base_prefix_sum: {transposed_base_prefix_sum.shape}")
+    transposed_base_prefix_sum = all_expert_token_counts.T
+    if args.debug and rank == 0:
+        print(f"[DEBUG] Transposed local_expert_token_count:\n{transposed_base_prefix_sum}")
+        print(f"[DEBUG] Transposed local_expert_token_count: {transposed_base_prefix_sum.shape}")
 
     experts_per_rank = num_experts // dist.get_world_size()
     start_expert = rank * experts_per_rank
@@ -484,14 +488,16 @@ def test(
     expected_recv = transposed_base_prefix_sum[start_expert:end_expert].reshape(-1)
     fused_recv = fused_ep_recv_count
 
-    print(f"expected_recv: {expected_recv}")
-    print(f"fused_recv: {fused_recv}")
+    if args.debug:
+        print(f"expected_recv: {expected_recv}")
+        print(f"fused_recv: {fused_recv}")
 
     diff = (expected_recv - fused_recv).abs()
-    print(
-        f"[Rank {rank}] diff (experts {start_expert}~{end_expert-1}): {diff.cpu().numpy()}",
-        flush=True,
-    )
+    if args.debug:
+        print(
+            f"[Rank {rank}] diff (experts {start_expert}~{end_expert-1}): {diff.cpu().numpy()}",
+            flush=True,
+        )
 
     max_recv_count_diff = diff.max().item()
     mean_recv_count_diff = diff.mean().item()
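
The comparison in these two hunks reshapes the gathered per-rank counts so each rank can read off how many tokens it should receive for each (local expert, source rank) pair. A minimal single-process sketch of that reshaping, with made-up counts standing in for the all_gather result:

# Single-process sketch of the recv-count check: stack per-rank expert token
# counts, transpose to [expert, rank], then slice the experts owned by one
# rank. The 2-rank / 4-expert numbers here are illustrative assumptions only.
import torch

num_ranks, num_experts = 2, 4
experts_per_rank = num_experts // num_ranks

# Row r plays the role of rank r's local num_tokens_per_expert
# (what dist.all_gather collects into all_expert_token_counts).
all_expert_token_counts = torch.tensor([[3, 0, 2, 1],   # counted on rank 0
                                        [1, 4, 0, 2]])  # counted on rank 1

# Transpose to [expert, rank]: row e says how many tokens each source rank
# sends to expert e.
transposed = all_expert_token_counts.T

rank = 1
start_expert, end_expert = rank * experts_per_rank, (rank + 1) * experts_per_rank
expected_recv = transposed[start_expert:end_expert].reshape(-1)
print(expected_recv)  # tensor([2, 0, 1, 2]) -> what rank 1 should receive
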
@@ -597,6 +603,12 @@ def str_to_bool(value):
         default=-1,
         help="If >=0, drop this specific top-k column (set index to -1 for testing).",
     )
+    parser.add_argument(
+        "--debug",
+        action="store_true",
+        default=False,
+        help="Enable debug logging.",
+    )
 
     args = parser.parse_args()
     num_processes = args.num_processes
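
For reference, the new flag follows the standard argparse store_true pattern, and every diagnostic print in the hunks above is now gated on it. A self-contained sketch of the same pattern (not the test's full argument list):

# Minimal sketch of the pattern this commit applies: a --debug store_true
# flag that gates verbose prints. Only --debug is taken from the diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--debug",
    action="store_true",
    default=False,
    help="Enable debug logging.",
)
args = parser.parse_args()

if args.debug:
    print("[DEBUG] verbose diagnostics enabled", flush=True)
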
