[Test] Testing the generalization of fused moe #167
@@ -12,16 +12,33 @@
from utils import bench, calc_diff, hash_tensor, init_dist

torch_npu.npu.config.allow_internal_format = True
test_topk_minus1 = False
small_bs_flag = False


# ======================== Weight Initialization ========================
def init_base_weights():
    w13_weight = torch.randint(-16, 16, [16, 4096, 7168]).to(torch.int8)
    w2_weight = torch.randint(-16, 16, [16, 7168, 2048]).to(torch.int8)
    w13_weight_scale = (torch.rand([16, 4096, 1]) * 0.0004 + 0.0015).bfloat16()
    w2_weight_scale = (torch.rand([16, 7168, 1]) * 0.0004 + 0.0015).bfloat16()
def init_base_weights(
    num_local_experts, hidden_in=7168, hidden_mid=4096, hidden_out=2048
):
    """
    Initialize the weights for each local expert.
    `num_local_experts`: Number of experts per rank = `num_experts` // `num_ranks`
    `hidden_in`: Input dimension (default 7168)
    `hidden_mid`: Intermediate layer dimension (default 4096)
    `hidden_out`: Output dimension (default 2048)
    """

    w13_weight = torch.randint(
        -16, 16, [num_local_experts, hidden_mid, hidden_in], dtype=torch.int8
    )
    w2_weight = torch.randint(
        -16, 16, [num_local_experts, hidden_in, hidden_out], dtype=torch.int8
    )

    w13_weight_scale = (
        torch.rand([num_local_experts, hidden_mid, 1]) * 0.0004 + 0.0015
    ).bfloat16()
    w2_weight_scale = (
        torch.rand([num_local_experts, hidden_in, 1]) * 0.0004 + 0.0015
    ).bfloat16()

    return w13_weight, w13_weight_scale, w2_weight, w2_weight_scale

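As a quick sanity check of the parameterized initializer above (assuming the new `init_base_weights` shown in this hunk is in scope; the expert count of 4 is an arbitrary example chosen for illustration):

```python
import torch

# Shapes follow [num_local_experts, hidden_mid, hidden_in] / [num_local_experts, hidden_in, hidden_out].
w13, w13_scale, w2, w2_scale = init_base_weights(num_local_experts=4, hidden_in=7168)

assert w13.shape == (4, 4096, 7168) and w13.dtype == torch.int8
assert w2.shape == (4, 7168, 2048) and w2.dtype == torch.int8
assert w13_scale.shape == (4, 4096, 1) and w13_scale.dtype == torch.bfloat16
assert w2_scale.shape == (4, 7168, 1) and w2_scale.dtype == torch.bfloat16
```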
@@ -238,6 +255,7 @@ def test(
    ), "Too many ranks (exceeding test precision limit)"

    x = torch.rand((num_tokens, hidden), dtype=torch.bfloat16, device="npu") * 10 - 5

    # ----- Routing(topk_idx) -----
    if args.active_ranks:
        try:

@@ -289,7 +307,10 @@ def test(
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]

    # ----- Weights -----
    w13_weight, w13_weight_scale, w2_weight, w2_weight_scale = init_base_weights()
    w13_weight, w13_weight_scale, w2_weight, w2_weight_scale = init_base_weights(
        num_local_experts=num_local_experts,
        hidden_in=hidden,
    )
    w13, w13_scale, w2, w2_scale = init_baseline_weights(
        w13_weight.clone().detach(),
        w13_weight_scale.clone().detach(),

@@ -316,96 +337,112 @@ def test(
    for r in range(num_ranks):
        start, end = r * experts_per_rank, (r + 1) * experts_per_rank
        tokens_per_rank[r] = ((topk_idx >= start) & (topk_idx < end)).sum()
    print(f"Tokens per rank: {tokens_per_rank}")

    # ----- Random drop -----
    if args.drop_prob > 0:
        drop_mask = torch.rand_like(topk_idx, dtype=torch.float32) < args.drop_prob
        topk_idx = topk_idx.masked_fill(drop_mask, -1)
        for i in range(num_tokens):
            if (topk_idx[i] == -1).all():
                topk_idx[i, 0] = torch.topk(scores[i], 1, largest=True)[1].item()
    print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True)

    # ====== ensure topk_weights is defined (fix missing var) ======
    topk_weights = torch.randn(
        (num_tokens, num_topk), dtype=torch.float32, device="npu"
    ).abs()

    # ====== cumulative stats and flags ======
    cumulative_local_expert_recv_stats = torch.zeros(
        (num_local_experts,), dtype=torch.int, device="npu"
        (num_local_experts,), dtype=torch.int32, device="npu"
    )
    return_recv_hook = False

    hidden_states = x

    if small_bs_flag and rank == 0:
        # Test with a small batch size of 1
        x = x[:1, :]
        topk_idx = topk_idx[:1, :]
        topk_weights = topk_weights[:1, :]

    if test_topk_minus1:
        topk_idx_minus1 = topk_idx.clone()
        topk_idx_minus1[:, -2:-1] = -1
        topk_weights_minus1 = topk_weights.clone()
        topk_weights_minus1[:, -2:-1] = 0
        # ----- Baseline -----
        baseline_output, base_ep_recv_count = baseline_test(
            buffer2,
            x,
            topk_idx,
            num_tokens,
            num_experts,
            cumulative_local_expert_recv_stats,
            return_recv_hook,
            w13,
            w13_scale,
            w2,
            w2_scale,
            topk_weights_minus1,
        )
        # ----- Fused -----
        fused_output, fused_ep_recv_count = buffer.fused_deep_moe(
            x,
            topk_idx_minus1,
            topk_weights,
            w13_f,
            w13s_f,
            w2_f,
            w2s_f,
            num_tokens,
            num_experts,
            0,
        )
    # ----- Random or fixed drop -----
    if args.topk_drop_prob > 0 or args.topk_drop_col >= 0:
        topk_idx_dropped = topk_idx.clone()
        topk_weights_dropped = topk_weights.clone()

        # Random drop (based on probability)
        if args.topk_drop_prob > 0:
            drop_mask = (
                torch.rand_like(topk_idx, dtype=torch.float32) < args.topk_drop_prob
            )
            topk_idx_dropped = topk_idx.clone()
            topk_idx_dropped = topk_idx_dropped.masked_fill(drop_mask, -1)

            # Guarantee that each token has at least one valid expert.
            for i in range(num_tokens):
                if (topk_idx_dropped[i] == -1).all():
                    topk_idx_dropped[i, 0] = torch.topk(scores[i], 1, largest=True)[
                        1
                    ].item()

            # Construct topk_weights_dropped
            invalid_mask = topk_idx_dropped == -1
            topk_weights_dropped = topk_weights_dropped.masked_fill(invalid_mask, 0.0)

        # Fixed column drop (for the test_topk_minus1 scenario)
        if args.topk_drop_col >= 0 and args.topk_drop_col < num_topk:
            topk_idx_dropped[:, args.topk_drop_col] = -1
            topk_weights_dropped[:, args.topk_drop_col] = 0

            print(
                f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}",
                flush=True,
            )
            print(
                f"[DEBUG] [rank {rank}] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}",
                flush=True,
            )

        # print drop ratio
        drop_ratio = (topk_idx_dropped == -1).float().mean().item()
        if rank == 0:
            print(
                f"[DEBUG] [rank {rank}] topk dropped ratio = {drop_ratio*100:.2f}%",
                flush=True,
            )
    else:
        # ----- Baseline -----
        baseline_output, base_ep_recv_count = baseline_test(
            buffer2,
            x,
            topk_idx,
            num_tokens,
            num_experts,
            cumulative_local_expert_recv_stats,
            return_recv_hook,
            w13,
            w13_scale,
            w2,
            w2_scale,
            topk_weights,
        )
        topk_idx_dropped = topk_idx
        topk_weights_dropped = topk_weights

        # ----- Fused -----
        fused_output, fused_ep_recv_count = buffer.fused_deep_moe(
            x,
            topk_idx,
            topk_weights,
            w13_f,
            w13s_f,
            w2_f,
            w2s_f,
            num_tokens,
            num_experts,
            0,
    # Expert meta
    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="npu")
    for i in range(num_experts):
        num_tokens_per_expert[i] = (topk_idx_dropped == i).sum()
    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
    dist.all_reduce(gbl_num_tokens_per_expert, group=group)

    print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}")
    if rank == 0:
        print(
            f"[Rank {rank}] gbl_num_tokens_per_expert: {gbl_num_tokens_per_expert.tolist()}"
        )
    base_prefix_sum = num_tokens_per_expert.clone()


    # ----- Baseline -----
    baseline_output, base_ep_recv_count = baseline_test(
        buffer2,
        x,
        topk_idx,
        num_tokens,
        num_experts,
        cumulative_local_expert_recv_stats,
        return_recv_hook,
        w13,
        w13_scale,
        w2,
        w2_scale,
        topk_weights_dropped,
    )

    # ----- Fused -----
    fused_output, fused_ep_recv_count = buffer.fused_deep_moe(
        x,
        topk_idx_dropped,
        topk_weights,
        w13_f,
        w13s_f,
        w2_f,
        w2s_f,
        num_tokens,
        num_experts,
        0,
    )

    # ----- Compare Outputs -----
    max_diff = torch.max(torch.abs(fused_output - baseline_output)).item()

@@ -415,30 +452,63 @@
    print(
        f"[Rank {rank}] baseline_avg={baseline_output_avg:.6e}, fused_avg={fused_output_avg:.6e}, "
        f"max_diff={max_diff:.6e}, avg_diff={avg_diff:.6e}"
        f"max_diff={max_diff:.6e}, avg_diff={avg_diff:.6e}",
        flush=True,
    )

    assert avg_diff < 1e-4, f"[Rank {rank}] Mismatch detected! diff={avg_diff}"

    # ----- Compare RecvCount -----
    recv_count_diff = (
        from_inclusive_prefix_sum(base_ep_recv_count) - fused_ep_recv_count
    ).abs()
    max_recv_count_diff = recv_count_diff.max().item()
    mean_recv_count_diff = recv_count_diff.mean().item()
    # ----- Compare Recv Count -----
    global_base_prefix_sum = [
        torch.zeros_like(base_prefix_sum) for _ in range(num_ranks)
    ]
    dist.all_gather(global_base_prefix_sum, base_prefix_sum)

    global_base_prefix_sum = torch.stack(global_base_prefix_sum, dim=0)

    if rank == 0:
        print(
            f"[DEBUG] Global base_prefix_sum (before transpose):\n{global_base_prefix_sum}"
        )

    transposed_base_prefix_sum = global_base_prefix_sum.T
    if rank == 0:
        print(f"[DEBUG] Transposed base_prefix_sum:\n{transposed_base_prefix_sum}")
        print(f"[DEBUG] Transposed base_prefix_sum: {transposed_base_prefix_sum.shape}")

    experts_per_rank = num_experts // dist.get_world_size()
    start_expert = rank * experts_per_rank
    end_expert = start_expert + experts_per_rank

    # shape [experts_per_rank * num_ranks]
    expected_recv = transposed_base_prefix_sum[start_expert:end_expert].reshape(-1)
    fused_recv = fused_ep_recv_count

    print(f"expected_recv: {expected_recv}")
    print(f"fused_recv: {fused_recv}")

    diff = (expected_recv - fused_recv).abs()

Collaborator: Lack of assertion; it is recommended to add the following code:

    assert torch.all(diff == 0), (
        f"Recv count mismatch on rank {rank}. Max difference: {diff.max().item()}",
        f"\nExpected:\n{expected_recv}\nActual:\n{fused_recv}"
    )

Contributor (Author): This check already exists below.
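For context, a small self-contained illustration of the quantity being compared here; the tensor values, and the 2 ranks × 2 experts-per-rank sizes, are made up for this sketch, and only the gather/transpose/slice pattern mirrors the test:

```python
import torch

# Hypothetical gathered counts: row r is rank r's num_tokens_per_expert over all
# global experts (what each rank contributes to global_base_prefix_sum).
global_base_prefix_sum = torch.tensor([[3, 1, 0, 2],   # counted on rank 0
                                       [1, 4, 2, 0]])  # counted on rank 1
transposed = global_base_prefix_sum.T  # rows index experts, columns index source ranks

num_ranks, experts_per_rank = 2, 2
for rank in range(num_ranks):
    start, end = rank * experts_per_rank, (rank + 1) * experts_per_rank
    # Tokens each source rank sends to the experts hosted on `rank`, flattened to
    # [experts_per_rank * num_ranks] like expected_recv in the test.
    expected_recv = transposed[start:end].reshape(-1)
    print(rank, expected_recv.tolist())
# rank 0 -> [3, 1, 1, 4]
# rank 1 -> [0, 2, 2, 0]
```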
    print(
        f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}"
        f"[Rank {rank}] diff (experts {start_expert}~{end_expert-1}): {diff.cpu().numpy()}",
        flush=True,
    )

    if not test_topk_minus1:
        assert (
            max_recv_count_diff < 1e-4
        ), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}"
    max_recv_count_diff = diff.max().item()
    mean_recv_count_diff = diff.mean().item()
    print(
        f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}",
        flush=True,
    )
    assert (
        max_recv_count_diff < 1e-4
    ), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}"


# ======================== Distributed Entry ========================
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    group2 = dist.new_group(list(range(16)))
    group2 = dist.new_group(list(range(num_ranks)))

    shared_expert_rank_num = int(os.getenv("MOE_SHARED_EXPERT_RANK_NUM", 0))
    num_tokens, hidden = args.num_tokens, args.hidden
    num_topk, num_experts = args.num_topk, args.num_experts

@@ -511,34 +581,25 @@ def str_to_bool(value):
        "--active-ranks",
        type=str,
        default="",
        help="Comma-separated list of ranks that will receive tokens. "
        'Example: "0,1,3". If empty, all ranks may receive tokens.',
        help='Comma-separated list of ranks that will receive tokens. Example: "0,1,3". If empty, all ranks may receive tokens.',
    )
    parser.add_argument(
        "--drop-prob",
        "--topk-drop-prob",
        dest="topk_drop_prob",
        type=float,
        default=0.0,
        help="Probability of dropping an individual top-k index (set to -1). "
        "Guaranteed that each token keeps at least one valid expert.",
    )

    parser.add_argument(
        "--minus1-flag", type=str_to_bool, default=False, help="bool flag, True/False"
        help="Probability of randomly dropping a top-k index (set to -1).",
    )

    parser.add_argument(
        "--small-bs-flag",
        type=str_to_bool,
        default=False,
        help="define small bs on certain rank",
        "--topk-drop-col",
        dest="topk_drop_col",
        type=int,
        default=-1,
        help="If >=0, drop this specific top-k column (set index to -1 for testing).",
    )

    args = parser.parse_args()

    num_processes = args.num_processes
    test_topk_minus1 = args.minus1_flag
    small_bs_flag = args.small_bs_flag

    torch.multiprocessing.spawn(
        test_loop, args=(num_processes, args), nprocs=num_processes
    )
Review comment: There are a large number of print statements in the code (including detailed tensor prints for rank 0). These print statements should be removed or placed under strict debugging conditions (e.g., behind `if DEBUG_MODE:` or the logging system's `logging.debug`) when the code is deployed to production, as they can affect performance and generate excessive output.

Reply: add debug mode
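A minimal sketch of what that gating could look like; `DEBUG_MODE`, the `FUSED_MOE_TEST_DEBUG` environment variable, and `debug_print` are illustrative names, not part of this PR:

```python
import os

# Illustrative switch: enable diagnostics only when explicitly requested.
DEBUG_MODE = os.getenv("FUSED_MOE_TEST_DEBUG", "0") == "1"


def debug_print(msg: str) -> None:
    # Emit diagnostic output only when debugging is enabled.
    if DEBUG_MODE:
        print(msg, flush=True)


# Example use, replacing an unconditional print in the test body:
# debug_print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}")
```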