Skip to content

Commit adf49e3

Browse files
authored
Optimization: EPV1 dispatch & combine kernel (#128)
- optimize dispatch recv / combine inter-node phase for better CU utilization
- separate dispatch staging buffer copy / combine all phase into a single kernel to use more CUs
1 parent 6c364d9 commit adf49e3

6 files changed

Lines changed: 684 additions & 124 deletions

File tree

examples/ops/dispatch_combine/test_dispatch_combine_internode.py

Lines changed: 149 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ def __init__(
5858
hidden_dim=7168,
5959
scale_dim=32,
6060
scale_type_size=4,
61-
max_num_inp_token_per_rank=max_tokens,
61+
max_num_inp_token_per_rank=(max_tokens + 63) // 64 * 64,
6262
num_experts_per_rank=16,
6363
num_experts_per_token=8,
6464
warp_num_per_block=8,
@@ -129,16 +129,16 @@ def _allgather_with_token_num_padding(self, input, max_token_num):
129129
dist.all_gather(output, padded_input)
130130
return output
131131

132-
def gen_test_data(self, use_max_token_num=False):
132+
def gen_test_data(self, max_num_token, use_max_token_num=False):
133133
# gen num_tokens
134134
if use_max_token_num:
135135
num_token = torch.tensor(
136-
[self.config.max_num_inp_token_per_rank for i in range(self.world_size)]
136+
[max_num_token for i in range(self.world_size)]
137137
).to(self.device)
138138
else:
139139
num_token = torch.randint(
140140
1,
141-
self.config.max_num_inp_token_per_rank + 1,
141+
max_num_token + 1,
142142
[self.world_size],
143143
generator=self.rng,
144144
device=self.device,
@@ -158,19 +158,21 @@ def gen_test_data(self, use_max_token_num=False):
158158
device=self.device,
159159
)
160160
# argsort gives us a random permutation, take first K columns
161-
indices = torch.argsort(random_vals, dim=1)[:, : self.config.num_experts_per_token]
161+
indices = torch.argsort(random_vals, dim=1)[
162+
:, : self.config.num_experts_per_token
163+
]
162164
all_rank_indices.append(indices.to(torch.int32))
163165

164166
# num_total_experts = self.config.num_experts_per_rank * self.config.world_size
165167
# num_nodes = self.config.world_size // self.config.gpu_per_node
166168

167169
# even_indices = (
168170
# torch.arange(
169-
# self.config.max_num_inp_token_per_rank
171+
# max_num_token
170172
# * self.config.num_experts_per_token,
171173
# device="cuda",
172174
# ).view(
173-
# self.config.max_num_inp_token_per_rank,
175+
# max_num_token,
174176
# self.config.num_experts_per_token,
175177
# )
176178
# % 256
@@ -420,7 +422,10 @@ def test_dispatch_combine(self):
420422
for i in range(5000):
421423
if self.rank == 0:
422424
print(f"Round {i} begin")
423-
test_data = self.gen_test_data(use_max_token_num=False)
425+
test_data = self.gen_test_data(
426+
max_num_token=self.config.max_num_inp_token_per_rank,
427+
use_max_token_num=False,
428+
)
424429
if self.rank == 0:
425430
print(f"Round {i} gen test_data done")
426431
self.run_test_once(op, test_data, error_round, i)
@@ -443,7 +448,11 @@ def stress_dispatch_combine(self):
443448
if self.rank == 0:
444449
print("Stress Test")
445450
test_data_list = [
446-
self.gen_test_data(use_max_token_num=False) for i in range(num_test_data)
451+
self.gen_test_data(
452+
max_num_token=self.config.max_num_inp_token_per_rank,
453+
use_max_token_num=False,
454+
)
455+
for i in range(num_test_data)
447456
]
448457
for i in tqdm(range(5000)):
449458
(
@@ -480,7 +489,10 @@ def stress_dispatch_combine(self):
480489

481490
if self.rank == 0:
482491
print("Stress Test with CUDA Graph")
483-
test_data = self.gen_test_data(use_max_token_num=False)
492+
test_data = self.gen_test_data(
493+
max_num_token=self.config.max_num_inp_token_per_rank,
494+
use_max_token_num=False,
495+
)
484496
(
485497
all_rank_num_token,
486498
all_rank_indices,
@@ -520,7 +532,7 @@ def stress_dispatch_combine(self):
520532

521533
del op
522534

523-
def run_bench_once(self, op, test_data, repeat=10):
535+
def run_bench_once(self, max_num_token, op, test_data, repeat=10):
524536
num_events = 2 * repeat + 1
525537
events = [torch.cuda.Event(enable_timing=True) for i in range(num_events)]
526538

@@ -559,9 +571,7 @@ def run_bench_once(self, op, test_data, repeat=10):
559571
)
560572
torch.cuda.synchronize()
561573

562-
total_rdma_recv_num_token = (
563-
self.config.max_num_inp_token_per_rank * self.config.world_size // 8
564-
)
574+
total_rdma_recv_num_token = max_num_token * self.config.world_size // 8
565575
print(
566576
f"rank {self.rank} recv {total_recv_num_token} tokens {total_rdma_recv_num_token} rdma tokens"
567577
)
@@ -598,7 +608,7 @@ def run_bench_once(self, op, test_data, repeat=10):
598608
element_size = all_rank_input[self.rank].element_size()
599609
total_bytes = total_recv_num_token * self.config.hidden_dim * element_size
600610
ll_mode_scale = (
601-
self.config.max_num_inp_token_per_rank
611+
max_num_token
602612
* self.config.num_experts_per_token
603613
/ (total_recv_num_token + 1) # avoid division by zero
604614
)
@@ -635,9 +645,11 @@ def run_bench_once(self, op, test_data, repeat=10):
635645
ll_mode_scale,
636646
)
637647

638-
def bench_dispatch_combine(self):
648+
def bench_dispatch_combine(self, max_num_token):
639649
op = mori.ops.EpDispatchCombineOp(self.config)
640-
test_data = self.gen_test_data(use_max_token_num=True)
650+
test_data = self.gen_test_data(
651+
max_num_token=max_num_token, use_max_token_num=True
652+
)
641653

642654
repeat = 50
643655
disp_duration_us_list = []
@@ -664,7 +676,7 @@ def bench_dispatch_combine(self):
664676
comb_rdma_bandwidth,
665677
comb_bandwidth,
666678
ll_mode_scale,
667-
) = self.run_bench_once(op, test_data, repeat)
679+
) = self.run_bench_once(max_num_token, op, test_data, repeat)
668680

669681
for i in range(repeat):
670682
disp_duration_output = [torch.zeros(1) for _ in range(self.world_size)]
@@ -821,14 +833,29 @@ def collect_metrics(per_round_data):
821833

822834
del op
823835

836+
return (disp_bw, disp_rdma_bw, disp_ll_bw, disp_lat), (
837+
comb_bw,
838+
comb_rdma_bw,
839+
comb_ll_bw,
840+
comb_lat,
841+
)
824842

825-
def test_dispatch_combine(
826-
local_rank, num_node, gpu_per_node, max_tokens, kernel_type, num_qp, cmd="test"
843+
844+
def sweep_bench_dispatch_combine(
845+
local_rank,
846+
num_node,
847+
gpu_per_node,
848+
max_tokens,
849+
kernel_type,
850+
num_qp,
851+
sweep_token_interval,
827852
):
828853
world_size = num_node * gpu_per_node
829854
node_rank = int(os.environ["RANK"])
830855
global_rank = node_rank * gpu_per_node + local_rank
831-
856+
sweep_token_interval = int(sweep_token_interval)
857+
if sweep_token_interval <= 0:
858+
raise ValueError(f"sweep_token_interval must >= 1, got {sweep_token_interval}")
832859
test_case = EpDispatchCombineTestCase(
833860
global_rank,
834861
gpu_per_node,
@@ -840,32 +867,122 @@ def test_dispatch_combine(
840867
# torch.float8_e4m3fnuz,
841868
)
842869
test_case.setup()
843-
if cmd == "test":
844-
test_case.test_dispatch_combine()
845-
elif cmd == "bench":
846-
test_case.bench_dispatch_combine()
847-
elif cmd == "stress":
848-
test_case.stress_dispatch_combine()
870+
871+
num_iters = (max_tokens + sweep_token_interval - 1) // sweep_token_interval
872+
max_token_list = [i * sweep_token_interval for i in range(num_iters)]
873+
874+
disp_lat_min_list = []
875+
disp_lat_max_list = []
876+
comb_lat_min_list = []
877+
comb_lat_max_list = []
878+
for max_token in max_token_list:
879+
if max_token == 0:
880+
max_token = 1
881+
disp_stats, comb_stats = test_case.bench_dispatch_combine(max_token)
882+
disp_bw, disp_rdma_bw, disp_ll_bw, disp_lat = disp_stats
883+
comb_bw, comb_rdma_bw, comb_ll_bw, comb_lat = comb_stats
884+
885+
disp_lat_min_list.append(disp_lat[0])
886+
comb_lat_min_list.append(comb_lat[0])
887+
disp_lat_max_list.append(disp_lat[1])
888+
comb_lat_max_list.append(comb_lat[1])
889+
890+
if local_rank == 0:
891+
import matplotlib.pyplot as plt
892+
893+
plt.figure()
894+
# plt.plot(max_token_list, disp_lat_min_list, label='Dispatch Min')
895+
# plt.plot(max_token_list, comb_lat_min_list, label='Combine Min')
896+
# plt.plot(max_token_list, disp_lat_max_list, label='Dispatch Max')
897+
# plt.plot(max_token_list, comb_lat_max_list, label='Combine Max')
898+
plt.plot(
899+
max_token_list,
900+
[max - min for max, min in zip(disp_lat_max_list, disp_lat_min_list)],
901+
label="Dispatch Max-Min",
902+
)
903+
plt.plot(
904+
max_token_list,
905+
[max - min for max, min in zip(comb_lat_max_list, comb_lat_min_list)],
906+
label="Combine Max-Min",
907+
)
908+
plt.xticks([i * 16 for i in range(max_tokens // 16)])
909+
plt.title("Dispatch / Combine Max-Min Latency (us)")
910+
plt.xlabel("# of Tokens")
911+
plt.ylabel("Latency (us)")
912+
plt.grid(True)
913+
plt.legend()
914+
plt.tight_layout()
915+
plt.savefig("dispatch_combine_perf_maxmin.png", dpi=300, bbox_inches="tight")
916+
test_case.cleanup()
917+
918+
919+
def test_dispatch_combine(
920+
local_rank,
921+
num_node,
922+
gpu_per_node,
923+
max_tokens,
924+
kernel_type,
925+
num_qp,
926+
cmd="test",
927+
sweep_token_interval=64,
928+
):
929+
world_size = num_node * gpu_per_node
930+
node_rank = int(os.environ["RANK"])
931+
global_rank = node_rank * gpu_per_node + local_rank
932+
933+
if cmd in ("test", "bench", "stress"):
934+
test_case = EpDispatchCombineTestCase(
935+
global_rank,
936+
gpu_per_node,
937+
world_size,
938+
max_tokens,
939+
kernel_type,
940+
num_qp,
941+
torch.bfloat16,
942+
# torch.float8_e4m3fnuz,
943+
)
944+
test_case.setup()
945+
if cmd == "test":
946+
test_case.test_dispatch_combine()
947+
elif cmd == "bench":
948+
test_case.bench_dispatch_combine(max_tokens)
949+
elif cmd == "stress":
950+
test_case.stress_dispatch_combine()
951+
test_case.cleanup()
952+
elif cmd == "sweep_bench":
953+
sweep_bench_dispatch_combine(
954+
local_rank,
955+
num_node,
956+
gpu_per_node,
957+
max_tokens,
958+
kernel_type,
959+
num_qp,
960+
sweep_token_interval,
961+
)
849962
else:
850963
raise ValueError(f"unsupported command: {cmd}")
851964

852-
test_case.cleanup()
853-
854965

855966
parser = argparse.ArgumentParser(description="dispatch/combine internode test")
856967
parser.add_argument(
857968
"--cmd",
858969
type=str,
859970
default="test",
860-
choices=["test", "bench", "stress"],
861-
help="Available subcommands: test, bench, stress",
971+
choices=["test", "bench", "stress", "sweep_bench"],
972+
help="Available subcommands: test, bench, stress, sweep_bench",
862973
)
863974
parser.add_argument(
864975
"--max-tokens",
865976
type=int,
866977
default=4096,
867978
help="Maximum number of input tokens per rank (default: 4096)",
868979
)
980+
parser.add_argument(
981+
"--sweep-token-interval",
982+
type=int,
983+
default=2,
984+
help="Number of token interval when sweep bench",
985+
)
869986
parser.add_argument(
870987
"--kernel-type",
871988
type=str,
@@ -896,6 +1013,7 @@ def test_dispatch_combine(
8961013
args_cli.kernel_type,
8971014
args_cli.num_qp,
8981015
args_cli.cmd,
1016+
args_cli.sweep_token_interval,
8991017
),
9001018
nprocs=gpu_per_node,
9011019
join=True,

include/mori/ops/dispatch_combine/dispatch_combine.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,7 @@ class EpDispatchCombineHandle {
165165
public:
166166
// Number of tokens on this rank and size of scale data type, updated at each round of inference
167167
index_t curRankNumToken{0};
168+
index_t multiProcessorCount{0};
168169

169170
public:
170171
// Config

include/mori/utils/hip_helper.hpp

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,38 @@
1+
// Copyright © Advanced Micro Devices, Inc. All rights reserved.
2+
//
3+
// MIT License
4+
//
5+
// Permission is hereby granted, free of charge, to any person obtaining a copy
6+
// of this software and associated documentation files (the "Software"), to deal
7+
// in the Software without restriction, including without limitation the rights
8+
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
// copies of the Software, and to permit persons to whom the Software is
10+
// furnished to do so, subject to the following conditions:
11+
//
12+
// The above copyright notice and this permission notice shall be included in all
13+
// copies or substantial portions of the Software.
14+
//
15+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
// SOFTWARE.
22+
#pragma once
23+
24+
#include <hip/hip_runtime.h>
25+
26+
namespace mori {
27+
inline int GetMultiProcessorCount(int device) {
28+
hipDeviceProp_t prop;
29+
HIP_RUNTIME_CHECK(hipGetDeviceProperties(&prop, device));
30+
return prop.multiProcessorCount;
31+
}
32+
33+
inline int GetCurDeviceMultiProcessorCount() {
34+
int device = 0;
35+
HIP_RUNTIME_CHECK(hipGetDevice(&device));
36+
return GetMultiProcessorCount(device);
37+
}
38+
} // namespace mori

0 commit comments

Comments
 (0)