14 changes: 12 additions & 2 deletions .github/workflows/pr-test-npu.yml
@@ -72,7 +72,12 @@ jobs:
HCCL_BUFFSIZE: 2000
run: |
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --minus1-flag True --small-bs-flag True
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 3
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 2
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 3 --num-experts 16
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 4 --topk-drop-col 1 --num-experts 32

test-build-deepep:
if: (github.repository == 'sgl-project/sgl-kernel-npu' || github.event_name == 'pull_request') &&
@@ -128,7 +133,12 @@ jobs:
HCCL_BUFFSIZE: 2000
run: |
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --minus1-flag True --small-bs-flag True
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-col 3
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 2
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 3 --num-experts 16
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_fused_deep_moe.py --num-tokens 4 --topk-drop-col 1 --num-experts 32

finish:
if: always()
293 changes: 177 additions & 116 deletions tests/python/deepep/test_fused_deep_moe.py
@@ -12,16 +12,33 @@
from utils import bench, calc_diff, hash_tensor, init_dist

torch_npu.npu.config.allow_internal_format = True
test_topk_minus1 = False
small_bs_flag = False


# ======================== Weight Initialization ========================
def init_base_weights():
w13_weight = torch.randint(-16, 16, [16, 4096, 7168]).to(torch.int8)
w2_weight = torch.randint(-16, 16, [16, 7168, 2048]).to(torch.int8)
w13_weight_scale = (torch.rand([16, 4096, 1]) * 0.0004 + 0.0015).bfloat16()
w2_weight_scale = (torch.rand([16, 7168, 1]) * 0.0004 + 0.0015).bfloat16()
def init_base_weights(
num_local_experts, hidden_in=7168, hidden_mid=4096, hidden_out=2048
):
"""
Initialize the weights for each local expert.
`num_local_experts`: Number of experts per rank = `num_experts` // `num_ranks`
`hidden_in`: Input dimension (default 7168)
`hidden_mid`: Intermediate layer dimension (default 4096)
`hidden_out`: Output dimension (default 2048)
"""

w13_weight = torch.randint(
-16, 16, [num_local_experts, hidden_mid, hidden_in], dtype=torch.int8
)
w2_weight = torch.randint(
-16, 16, [num_local_experts, hidden_in, hidden_out], dtype=torch.int8
)

w13_weight_scale = (
torch.rand([num_local_experts, hidden_mid, 1]) * 0.0004 + 0.0015
).bfloat16()
w2_weight_scale = (
torch.rand([num_local_experts, hidden_in, 1]) * 0.0004 + 0.0015
).bfloat16()

return w13_weight, w13_weight_scale, w2_weight, w2_weight_scale

@@ -238,6 +255,7 @@ def test(
), "Too many ranks (exceeding test precision limit)"

x = torch.rand((num_tokens, hidden), dtype=torch.bfloat16, device="npu") * 10 - 5

# ----- Routing(topk_idx) -----
if args.active_ranks:
try:
@@ -289,7 +307,10 @@
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]

# ----- Weights -----
w13_weight, w13_weight_scale, w2_weight, w2_weight_scale = init_base_weights()
w13_weight, w13_weight_scale, w2_weight, w2_weight_scale = init_base_weights(
num_local_experts=num_local_experts,
hidden_in=hidden,
)
w13, w13_scale, w2, w2_scale = init_baseline_weights(
w13_weight.clone().detach(),
w13_weight_scale.clone().detach(),
@@ -316,96 +337,112 @@
for r in range(num_ranks):
start, end = r * experts_per_rank, (r + 1) * experts_per_rank
tokens_per_rank[r] = ((topk_idx >= start) & (topk_idx < end)).sum()
print(f"Tokens per rank: {tokens_per_rank}")

# ----- Random drop -----
if args.drop_prob > 0:
drop_mask = torch.rand_like(topk_idx, dtype=torch.float32) < args.drop_prob
topk_idx = topk_idx.masked_fill(drop_mask, -1)
for i in range(num_tokens):
if (topk_idx[i] == -1).all():
topk_idx[i, 0] = torch.topk(scores[i], 1, largest=True)[1].item()
print(f"[DEBUG] Tokens per rank: {tokens_per_rank}", flush=True)

# ====== ensure topk_weights is defined (fix missing var) ======
topk_weights = torch.randn(
(num_tokens, num_topk), dtype=torch.float32, device="npu"
).abs()

# ====== cumulative stats and flags ======
cumulative_local_expert_recv_stats = torch.zeros(
(num_local_experts,), dtype=torch.int, device="npu"
(num_local_experts,), dtype=torch.int32, device="npu"
)
return_recv_hook = False

hidden_states = x

if small_bs_flag and rank == 0:
# Test with a small batch size of 1
x = x[:1, :]
topk_idx = topk_idx[:1, :]
topk_weights = topk_weights[:1, :]

if test_topk_minus1:
topk_idx_minus1 = topk_idx.clone()
topk_idx_minus1[:, -2:-1] = -1
topk_weights_minus1 = topk_weights.clone()
topk_weights_minus1[:, -2:-1] = 0
# ----- Baseline -----
baseline_output, base_ep_recv_count = baseline_test(
buffer2,
x,
topk_idx,
num_tokens,
num_experts,
cumulative_local_expert_recv_stats,
return_recv_hook,
w13,
w13_scale,
w2,
w2_scale,
topk_weights_minus1,
)
# ----- Fused -----
fused_output, fused_ep_recv_count = buffer.fused_deep_moe(
x,
topk_idx_minus1,
topk_weights,
w13_f,
w13s_f,
w2_f,
w2s_f,
num_tokens,
num_experts,
0,
)
# ----- Random or fixed drop -----
if args.topk_drop_prob > 0 or args.topk_drop_col >= 0:
topk_idx_dropped = topk_idx.clone()
topk_weights_dropped = topk_weights.clone()

# Random drop (based on probability)
if args.topk_drop_prob > 0:
drop_mask = (
torch.rand_like(topk_idx, dtype=torch.float32) < args.topk_drop_prob
)
topk_idx_dropped = topk_idx.clone()
topk_idx_dropped = topk_idx_dropped.masked_fill(drop_mask, -1)

# Guarantee that each token has at least one valid expert.
for i in range(num_tokens):
if (topk_idx_dropped[i] == -1).all():
topk_idx_dropped[i, 0] = torch.topk(scores[i], 1, largest=True)[
1
].item()

# Construct topk_weights_dropped
invalid_mask = topk_idx_dropped == -1
topk_weights_dropped = topk_weights_dropped.masked_fill(invalid_mask, 0.0)

# Fixed column drop (for the test_topk_minus1 scenario)
if args.topk_drop_col >= 0 and args.topk_drop_col < num_topk:
topk_idx_dropped[:, args.topk_drop_col] = -1
topk_weights_dropped[:, args.topk_drop_col] = 0

print(
f"[DEBUG] [rank {rank}] topk_idx_dropped (after fixed-column drop):\n{topk_idx_dropped.cpu().numpy()}",
flush=True,
)
print(
f"[DEBUG] [rank {rank}] topk_weights_dropped (after fixed-column drop):\n{topk_weights_dropped.cpu().numpy()}",
flush=True,
)

# print drop ratio
drop_ratio = (topk_idx_dropped == -1).float().mean().item()
if rank == 0:
print(
f"[DEBUG] [rank {rank}] topk dropped ratio = {drop_ratio*100:.2f}%",
flush=True,
)
else:
# ----- Baseline -----
baseline_output, base_ep_recv_count = baseline_test(
buffer2,
x,
topk_idx,
num_tokens,
num_experts,
cumulative_local_expert_recv_stats,
return_recv_hook,
w13,
w13_scale,
w2,
w2_scale,
topk_weights,
)
topk_idx_dropped = topk_idx
topk_weights_dropped = topk_weights

# ----- Fused -----
fused_output, fused_ep_recv_count = buffer.fused_deep_moe(
x,
topk_idx,
topk_weights,
w13_f,
w13s_f,
w2_f,
w2s_f,
num_tokens,
num_experts,
0,
# Expert meta
num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="npu")
for i in range(num_experts):
num_tokens_per_expert[i] = (topk_idx_dropped == i).sum()
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
dist.all_reduce(gbl_num_tokens_per_expert, group=group)

print(f"[Rank {rank}] num_tokens_per_expert: {num_tokens_per_expert.tolist()}")
Collaborator:
There are a large number of print statements in this code, including detailed tensor dumps on Rank 0. They should be removed, or gated behind a strict debug condition (e.g., if DEBUG_MODE: or logging.debug), before this is used in production, since they can hurt performance and generate excessive output.

Contributor Author:
Will add a debug mode.
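
A minimal sketch of the kind of gate the reviewer asks for; the environment variable name FUSED_DEEP_MOE_DEBUG and the debug_print helper below are hypothetical and not part of this PR:

    import os

    # Hypothetical switch (not part of this PR): enable verbose test output only on request.
    DEBUG_MODE = os.getenv("FUSED_DEEP_MOE_DEBUG", "0") == "1"


    def debug_print(msg: str) -> None:
        # No-op unless the debug switch is set, keeping CI logs small.
        if DEBUG_MODE:
            print(msg, flush=True)


    # Usage: debug_print(f"[DEBUG] Tokens per rank: {tokens_per_rank}")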

if rank == 0:
print(
f"[Rank {rank}] gbl_num_tokens_per_expert: {gbl_num_tokens_per_expert.tolist()}"
)
base_prefix_sum = num_tokens_per_expert.clone()
Collaborator:
The name "base_prefix_sum" is too generic and does not say what it counts. If it counts the tokens sent to each expert, rename it to something like "local_expert_send_counts" or "local_expert_token_counts" so the code is self-explanatory.

Contributor Author:
Will change it to local_expert_token_counts.


# ----- Baseline -----
baseline_output, base_ep_recv_count = baseline_test(
buffer2,
x,
topk_idx,
num_tokens,
num_experts,
cumulative_local_expert_recv_stats,
return_recv_hook,
w13,
w13_scale,
w2,
w2_scale,
topk_weights_dropped,
)

# ----- Fused -----
fused_output, fused_ep_recv_count = buffer.fused_deep_moe(
x,
topk_idx_dropped,
topk_weights,
w13_f,
w13s_f,
w2_f,
w2s_f,
num_tokens,
num_experts,
0,
)

# ----- Compare Outputs -----
max_diff = torch.max(torch.abs(fused_output - baseline_output)).item()
@@ -415,30 +452,63 @@ def test(

print(
f"[Rank {rank}] baseline_avg={baseline_output_avg:.6e}, fused_avg={fused_output_avg:.6e}, "
f"max_diff={max_diff:.6e}, avg_diff={avg_diff:.6e}"
f"max_diff={max_diff:.6e}, avg_diff={avg_diff:.6e}",
flush=True,
)

assert avg_diff < 1e-4, f"[Rank {rank}] Mismatch detected! diff={avg_diff}"

# ----- Compare RecvCount -----
recv_count_diff = (
from_inclusive_prefix_sum(base_ep_recv_count) - fused_ep_recv_count
).abs()
max_recv_count_diff = recv_count_diff.max().item()
mean_recv_count_diff = recv_count_diff.mean().item()
# ----- Compare Recv Count -----
global_base_prefix_sum = [
torch.zeros_like(base_prefix_sum) for _ in range(num_ranks)
]
dist.all_gather(global_base_prefix_sum, base_prefix_sum)

global_base_prefix_sum = torch.stack(global_base_prefix_sum, dim=0)

if rank == 0:
print(
f"[DEBUG] Global base_prefix_sum (before transpose):\n{global_base_prefix_sum}"
)

transposed_base_prefix_sum = global_base_prefix_sum.T
if rank == 0:
print(f"[DEBUG] Transposed base_prefix_sum:\n{transposed_base_prefix_sum}")
print(f"[DEBUG] Transposed base_prefix_sum: {transposed_base_prefix_sum.shape}")

experts_per_rank = num_experts // dist.get_world_size()
start_expert = rank * experts_per_rank
end_expert = start_expert + experts_per_rank

# shape [experts_per_rank * num_ranks]
expected_recv = transposed_base_prefix_sum[start_expert:end_expert].reshape(-1)
fused_recv = fused_ep_recv_count

print(f"expected_recv: {expected_recv}")
print(f"fused_recv: {fused_recv}")

diff = (expected_recv - fused_recv).abs()
Collaborator:
An assertion is missing here; it is recommended to add something like:

    assert torch.all(diff == 0), (
        f"Recv count mismatch on rank {rank}. Max difference: {diff.max().item()}"
        f"\nExpected:\n{expected_recv}\nActual:\n{fused_recv}"
    )

Contributor Author:
This check already exists below.

print(
f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}"
f"[Rank {rank}] diff (experts {start_expert}~{end_expert-1}): {diff.cpu().numpy()}",
flush=True,
)

if not test_topk_minus1:
assert (
max_recv_count_diff < 1e-4
), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}"
max_recv_count_diff = diff.max().item()
mean_recv_count_diff = diff.mean().item()
print(
f"[Rank {rank}] Difference between base and fused recv_count -> max: {max_recv_count_diff}, mean: {mean_recv_count_diff}",
flush=True,
)
assert (
max_recv_count_diff < 1e-4
), f"[Rank {rank}] Mismatch detected! diff={max_recv_count_diff}"


# ======================== Distributed Entry ========================
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
group2 = dist.new_group(list(range(16)))
group2 = dist.new_group(list(range(num_ranks)))

shared_expert_rank_num = int(os.getenv("MOE_SHARED_EXPERT_RANK_NUM", 0))
num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts
@@ -511,34 +581,25 @@ def str_to_bool(value):
"--active-ranks",
type=str,
default="",
help="Comma-separated list of ranks that will receive tokens. "
'Example: "0,1,3". If empty, all ranks may receive tokens.',
help='Comma-separated list of ranks that will receive tokens. Example: "0,1,3". If empty, all ranks may receive tokens.',
)
parser.add_argument(
"--drop-prob",
"--topk-drop-prob",
dest="topk_drop_prob",
type=float,
default=0.0,
help="Probability of dropping an individual top-k index (set to -1). "
"Guaranteed that each token keeps at least one valid expert.",
)

parser.add_argument(
"--minus1-flag", type=str_to_bool, default=False, help="bool flag, True/False"
help="Probability of randomly dropping a top-k index (set to -1).",
)

parser.add_argument(
"--small-bs-flag",
type=str_to_bool,
default=False,
help="define small bs on certain rank",
"--topk-drop-col",
dest="topk_drop_col",
type=int,
default=-1,
help="If >=0, drop this specific top-k column (set index to -1 for testing).",
)

args = parser.parse_args()

num_processes = args.num_processes
test_topk_minus1 = args.minus1_flag
small_bs_flag = args.small_bs_flag

torch.multiprocessing.spawn(
test_loop, args=(num_processes, args), nprocs=num_processes
)
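
For reference, the new arguments combine as exercised by the pr-test-npu.yml changes above (the same invocations, with $GITHUB_WORKSPACE dropped so the paths are relative to the repository root):

    python3 tests/python/deepep/test_fused_deep_moe.py --num-tokens 2
    python3 tests/python/deepep/test_fused_deep_moe.py --topk-drop-prob 0.3
    python3 tests/python/deepep/test_fused_deep_moe.py --num-tokens 2 --topk-drop-col 3 --num-experts 16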