diff --git a/.github/workflows/pr-test-npu.yml b/.github/workflows/pr-test-npu.yml index 6c5942a6..94aba705 100644 --- a/.github/workflows/pr-test-npu.yml +++ b/.github/workflows/pr-test-npu.yml @@ -72,7 +72,12 @@ jobs: HCCL_BUFFSIZE: 2000 run: | python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py - python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --minus1-flag True --small-bs-flag True + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 3 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 2 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 3 --num-experts 16 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 4 --topk-drop-col 1 --num-experts 32 test-build-deepep: if: (github.repository == 'sgl-project/sgl-kernel-npu' || github.event_name == 'pull_request') && @@ -128,7 +133,12 @@ jobs: HCCL_BUFFSIZE: 2000 run: | python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py - python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --minus1-flag True --small-bs-flag True + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 3 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 2 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 3 --num-experts 16 + python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 4 --topk-drop-col 1 --num-experts 32 finish: if: always() diff --git a/tests/python/deepep/test_fused_deep_moe.py b/tests/python/deepep/test_fused_deep_moe.py index 436cc77b..aefc353f 100644 --- a/tests/python/deepep/test_fused_deep_moe.py +++ b/tests/python/deepep/test_fused_deep_moe.py @@ -12,16 +12,33 @@ from utils import bench, calc_diff, hash_tensor, init_dist torch_npu.npu.config.allow_internal_format = True -test_topk_minus1 = False -small_bs_flag = False # ======================== Weight Initialization ======================== -def init_base_weights(): - w13_weight = torch.randint(-16, 16, [16, 4096, 7168]).to(torch.int8) - w2_weight = torch.randint(-16, 16, [16, 7168, 2048]).to(torch.int8) - w13_weight_scale = (torch.rand([16, 4096, 1]) * 0.0004 + 0.0015).bfloat16() - w2_weight_scale = (torch.rand([16, 7168, 1]) * 0.0004 + 0.0015).bfloat16() +def init_base_weights( + num_local_experts, hidden_in=7168, hidden_mid=4096, hidden_out=2048 +): + """ + Initialize the weights for each local expert. 
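+
+    Weights are int8 values drawn from [-16, 16); scales are per-row bfloat16
+    values drawn from [0.0015, 0.0019).
+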
+    `num_local_experts`: Number of experts per rank = `num_experts` // `num_ranks`
+    `hidden_in`: Input dimension (default 7168)
+    `hidden_mid`: Intermediate layer dimension (default 4096)
+    `hidden_out`: Output dimension (default 2048)
+    """
+
+    w13_weight = torch.randint(
+        -16, 16, [num_local_experts, hidden_mid, hidden_in], dtype=torch.int8
+    )
+    w2_weight = torch.randint(
+        -16, 16, [num_local_experts, hidden_in, hidden_out], dtype=torch.int8
+    )
+
+    w13_weight_scale = (
+        torch.rand([num_local_experts, hidden_mid, 1]) * 0.0004 + 0.0015
+    ).bfloat16()
+    w2_weight_scale = (
+        torch.rand([num_local_experts, hidden_in, 1]) * 0.0004 + 0.0015
+    ).bfloat16()
 
     return w13_weight, w13_weight_scale, w2_weight, w2_weight_scale
 
@@ -238,6 +255,7 @@ def test(
     ), "Too many ranks (exceeding test precision limit)"
 
     x = torch.rand((num_tokens, hidden), dtype=torch.bfloat16, device="npu") * 10 - 5
+
     # ----- Routing(topk_idx) -----
     if args.active_ranks:
         try:
@@ -289,7 +307,10 @@ def test(
     topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
 
     # ----- Weights -----
-    w13_weight, w13_weight_scale, w2_weight, w2_weight_scale = init_base_weights()
+    w13_weight, w13_weight_scale, w2_weight, w2_weight_scale = init_base_weights(
+        num_local_experts=num_local_experts,
+        hidden_in=hidden,
+    )
     w13, w13_scale, w2, w2_scale = init_baseline_weights(
         w13_weight.clone().detach(),
         w13_weight_scale.clone().detach(),
@@ -303,7 +324,7 @@ def test(
         w2_weight_scale.clone().detach(),
     )
 
-    if rank == 0:
+    if args.debug and rank == 0:
         print("=== Check fused weights ===")
         print("w13_f:", w13_f.shape, w13_f.dtype, w13_f.device)
         print("w13s_f:", w13s_f.shape, w13s_f.dtype, w13s_f.device)
@@ -316,96 +337,115 @@ def test(
     for r in range(num_ranks):
         start, end = r * experts_per_rank, (r + 1) * experts_per_rank
         tokens_per_rank[r] = ((topk_idx >= start) & (topk_idx < end)).sum()
-    print(f"Tokens per rank: {tokens_per_rank}")
-
-    # ----- Random drop -----
-    if args.drop_prob > 0:
-        drop_mask = torch.rand_like(topk_idx, dtype=torch.float32) < args.drop_prob
-        topk_idx = topk_idx.masked_fill(drop_mask, -1)
-        for i in range(num_tokens):
-            if (topk_idx[i] == -1).all():
-                topk_idx[i, 0] = torch.topk(scores[i], 1, largest=True)[1].item()
+    if args.debug:
+        print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True)
 
+    # ----- Top-k weights -----
     topk_weights = torch.randn(
         (num_tokens, num_topk), dtype=torch.float32, device="npu"
     ).abs()
+
+    # ----- Cumulative stats and flags -----
     cumulative_local_expert_recv_stats = torch.zeros(
-        (num_local_experts,), dtype=torch.int, device="npu"
+        (num_local_experts,), dtype=torch.int32, device="npu"
     )
     return_recv_hook = False
-    hidden_states = x
-
-    if small_bs_flag and rank == 0:
-        # Test with a small batch size of 1
-        x = x[:1, :]
-        topk_idx = topk_idx[:1, :]
-        topk_weights = topk_weights[:1, :]
-
-    if test_topk_minus1:
-        topk_idx_minus1 = topk_idx.clone()
-        topk_idx_minus1[:, -2:-1] = -1
-        topk_weights_minus1 = topk_weights.clone()
-        topk_weights_minus1[:, -2:-1] = 0
-        # ----- Baseline -----
-        baseline_output, base_ep_recv_count = baseline_test(
-            buffer2,
-            x,
-            topk_idx,
-            num_tokens,
-            num_experts,
-            cumulative_local_expert_recv_stats,
-            return_recv_hook,
-            w13,
-            w13_scale,
-            w2,
-            w2_scale,
-            topk_weights_minus1,
-        )
-        # ----- Fused -----
-        fused_output, fused_ep_recv_count = buffer.fused_deep_moe(
-            x,
-            topk_idx_minus1,
-            topk_weights,
-            w13_f,
-            w13s_f,
-            w2_f,
-            w2s_f,
-            num_tokens,
-            num_experts,
-            0,
-        )
+    # ----- Random or fixed drop -----
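+    # Example: with --topk-drop-col 3, entry 3 of every topk_idx row is set to
+    # -1 (e.g. [5, 2, 9, 14, ...] -> [5, 2, 9, -1, ...]) and the matching
+    # topk_weights entry is zeroed; --topk-drop-prob drops entries at random.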
+    if args.topk_drop_prob > 0 or args.topk_drop_col >= 0:
+        topk_idx_dropped = topk_idx.clone()
+        topk_weights_dropped = topk_weights.clone()
+
+        # Random drop (based on probability)
+        if args.topk_drop_prob > 0:
+            drop_mask = (
+                torch.rand_like(topk_idx, dtype=torch.float32) < args.topk_drop_prob
+            )
+            topk_idx_dropped = topk_idx_dropped.masked_fill(drop_mask, -1)
+
+            # Guarantee that each token keeps at least one valid expert.
+            for i in range(num_tokens):
+                if (topk_idx_dropped[i] == -1).all():
+                    topk_idx_dropped[i, 0] = torch.topk(scores[i], 1, largest=True)[
+                        1
+                    ].item()
+
+            # Zero the weights of the dropped entries.
+            invalid_mask = topk_idx_dropped == -1
+            topk_weights_dropped = topk_weights_dropped.masked_fill(invalid_mask, 0.0)
+
+        # Fixed-column drop (the former test_topk_minus1 scenario)
+        if 0 <= args.topk_drop_col < num_topk:
+            topk_idx_dropped[:, args.topk_drop_col] = -1
+            topk_weights_dropped[:, args.topk_drop_col] = 0
+            if args.debug:
+                print(
+                    f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}",
+                    flush=True,
+                )
+                print(
+                    f"[DEBUG] [rank {rank}] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}",
+                    flush=True,
+                )
+
+        # Report the overall drop ratio.
+        drop_ratio = (topk_idx_dropped == -1).float().mean().item()
+        if args.debug and rank == 0:
+            print(
+                f"[DEBUG] [rank {rank}] topk dropped ratio = {drop_ratio * 100:.2f}%",
+                flush=True,
+            )
     else:
-        # ----- Baseline -----
-        baseline_output, base_ep_recv_count = baseline_test(
-            buffer2,
-            x,
-            topk_idx,
-            num_tokens,
-            num_experts,
-            cumulative_local_expert_recv_stats,
-            return_recv_hook,
-            w13,
-            w13_scale,
-            w2,
-            w2_scale,
-            topk_weights,
-        )
+        topk_idx_dropped = topk_idx
+        topk_weights_dropped = topk_weights
+
+    # ----- Expert meta (per-expert token counts after drops) -----
+    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="npu")
+    for i in range(num_experts):
+        num_tokens_per_expert[i] = (topk_idx_dropped == i).sum()
+    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
+    dist.all_reduce(gbl_num_tokens_per_expert, group=group)
+
+    if args.debug:
+        print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}")
+        if rank == 0:
+            print(
+                f"[Rank {rank}] gbl_num_tokens_per_expert: {gbl_num_tokens_per_expert.tolist()}"
+            )
 
-        # ----- Fused -----
-        fused_output, fused_ep_recv_count = buffer.fused_deep_moe(
-            x,
-            topk_idx,
-            topk_weights,
-            w13_f,
-            w13s_f,
-            w2_f,
-            w2s_f,
-            num_tokens,
-            num_experts,
-            0,
-        )
+    local_expert_token_count = num_tokens_per_expert.clone()
+
+    # ----- Baseline -----
+    # The baseline keeps the original topk_idx but uses the dropped weights:
+    # a dropped entry contributes zero to the reference output.
+    baseline_output, base_ep_recv_count = baseline_test(
+        buffer2,
+        x,
+        topk_idx,
+        num_tokens,
+        num_experts,
+        cumulative_local_expert_recv_stats,
+        return_recv_hook,
+        w13,
+        w13_scale,
+        w2,
+        w2_scale,
+        topk_weights_dropped,
+    )
+
+    # ----- Fused -----
+    # The fused kernel receives the dropped topk_idx (-1 entries) with the
+    # original weights, exercising its handling of invalid expert indices.
+    fused_output, fused_ep_recv_count = buffer.fused_deep_moe(
+        x,
+        topk_idx_dropped,
+        topk_weights,
+        w13_f,
+        w13s_f,
+        w2_f,
+        w2s_f,
+        num_tokens,
+        num_experts,
+        0,
+    )
 
     # ----- Compare Outputs -----
     max_diff = torch.max(torch.abs(fused_output - baseline_output)).item()
@@ -415,30 +455,69 @@ def test(
 
     print(
         f"[Rank {rank}] baseline_avg={baseline_output_avg:.6e}, fused_avg={fused_output_avg:.6e}, "
-        f"max_diff={max_diff:.6e}, avg_diff={avg_diff:.6e}"
+        f"max_diff={max_diff:.6e}, avg_diff={avg_diff:.6e}",
+        flush=True,
     )
+    assert avg_diff < 1e-4, f"[Rank {rank}] Mismatch detected! diff={avg_diff}"
 
-    # ----- Compare RecvCount -----
-    recv_count_diff = (
-        from_inclusive_prefix_sum(base_ep_recv_count) - fused_ep_recv_count
-    ).abs()
-    max_recv_count_diff = recv_count_diff.max().item()
-    mean_recv_count_diff = recv_count_diff.mean().item()
+    # ----- Compare Recv Count -----
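+    # Rebuild the expected recv counts directly from topk_idx_dropped: gather
+    # every rank's per-expert token counts, stack them into
+    # [num_ranks, num_experts], transpose to [num_experts, num_ranks], and take
+    # the rows owned by this rank's local experts.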
diff={avg_diff}" - # ----- Compare RecvCount ----- - recv_count_diff = ( - from_inclusive_prefix_sum(base_ep_recv_count) - fused_ep_recv_count - ).abs() - max_recv_count_diff = recv_count_diff.max().item() - mean_recv_count_diff = recv_count_diff.mean().item() + # ----- Compare Recv Count ----- + all_expert_token_counts = [ + torch.zeros_like(local_expert_token_count) for _ in range(num_ranks) + ] + dist.all_gather(all_expert_token_counts, local_expert_token_count) + + all_expert_token_counts = torch.stack(all_expert_token_counts, dim=0) + + if args.debug and rank == 0: + print( + f"[DEBUG] Global local_expert_token_count (before transpose):\n{all_expert_token_counts}" + ) + + transposed_base_prefix_sum = all_expert_token_counts.T + if args.debug and rank == 0: + print( + f"[DEBUG] Transposed local_expert_token_count:\n{transposed_base_prefix_sum}" + ) + print( + f"[DEBUG] Transposed local_expert_token_count: {transposed_base_prefix_sum.shape}" + ) + + experts_per_rank = num_experts // dist.get_world_size() + start_expert = rank * experts_per_rank + end_expert = start_expert + experts_per_rank + + # shape [experts_per_rank * num_ranks] + expected_recv = transposed_base_prefix_sum[start_expert:end_expert].reshape(-1) + fused_recv = fused_ep_recv_count + + if args.debug: + print(f"expected_recv: {expected_recv}") + print(f"fused_recv: {fused_recv}") + + diff = (expected_recv - fused_recv).abs() + if args.debug: + print( + f"[Rank {rank}] diff (experts {start_expert}~{end_expert-1}): {diff.cpu().numpy()}", + flush=True, + ) + + max_recv_count_diff = diff.max().item() + mean_recv_count_diff = diff.mean().item() print( - f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}" + f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}", + flush=True, ) - - if not test_topk_minus1: - assert ( - max_recv_count_diff < 1e-4 - ), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}" + assert ( + max_recv_count_diff < 1e-4 + ), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}" # ======================== Distributed Entry ======================== def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) - group2 = dist.new_group(list(range(16))) + group2 = dist.new_group(list(range(num_ranks))) + shared_expert_rank_num = int(os.getenv("MOE_SHARED_EXPERT_RANK_NUM", 0)) num_tokens, hidden = args.num_tokens, args.hidden num_topk, num_experts = args.num_topk, args.num_experts @@ -511,34 +590,31 @@ def str_to_bool(value): "--active-ranks", type=str, default="", - help="Comma-separated list of ranks that will receive tokens. " - 'Example: "0,1,3". If empty, all ranks may receive tokens.', + help='Comma-separated list of ranks that will receive tokens. Example: "0,1,3". If empty, all ranks may receive tokens.', ) parser.add_argument( - "--drop-prob", + "--topk-drop-prob", + dest="topk_drop_prob", type=float, default=0.0, - help="Probability of dropping an individual top-k index (set to -1). 
" - "Guaranteed that each token keeps at least one valid expert.", + help="Probability of randomly dropping a top-k index (set to -1).", ) - parser.add_argument( - "--minus1-flag", type=str_to_bool, default=False, help="bool flag, True/False" + "--topk-drop-col", + dest="topk_drop_col", + type=int, + default=-1, + help="If >=0, drop this specific top-k column (set index to -1 for testing).", ) - parser.add_argument( - "--small-bs-flag", - type=str_to_bool, + "--debug", + action="store_true", default=False, - help="define small bs on certain rank", + help="Enable debug logging.", ) args = parser.parse_args() - num_processes = args.num_processes - test_topk_minus1 = args.minus1_flag - small_bs_flag = args.small_bs_flag - torch.multiprocessing.spawn( test_loop, args=(num_processes, args), nprocs=num_processes )