Skip to content

Commit cf385ec

Browse files
Remove uses of AITER_ASM_DIR
Embed code objects into the binary. Use hipRegisterFatBinary so that this works seamlessly across multiple GPUs. Make the CFG tables read-only and the AiterAsmKernels statically allocated.
1 parent 1c564c2 commit cf385ec

26 files changed

+1271
-1079
lines changed

aiter/jit/optCompilerConfig.json

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@
170170
"extra_ldflags": "None",
171171
"extra_include": [],
172172
"verbose": "False",
173-
"blob_gen_cmd": "''"
173+
"blob_gen_cmd": "f'{AITER_META_DIR}/hsa/codegen.py --code-objects -m custom_all_reduce --glob=all_reduce*.co --output_dir {{}}'"
174174
},
175175
"module_quick_all_reduce": {
176176
"srcs": [
@@ -397,7 +397,7 @@
397397
"extra_ldflags": "None",
398398
"extra_include": [],
399399
"verbose": "False",
400-
"blob_gen_cmd": "''"
400+
"blob_gen_cmd": "f'{AITER_META_DIR}/hsa/codegen.py --code-objects -m flatmm_fp8gemm_blockscale --output_dir {{}}'"
401401
},
402402
"module_gemm_a8w8_blockscale_bpreshuffle_asm": {
403403
"srcs": [
@@ -421,7 +421,7 @@
421421
"extra_ldflags": "None",
422422
"extra_include": [],
423423
"verbose": "False",
424-
"blob_gen_cmd": "''"
424+
"blob_gen_cmd": "f'{AITER_META_DIR}/hsa/codegen.py --code-objects -m fp8gemm_blockscale_mi350 --glob=f8_block_scale_mi350*.co --output_dir {{}}'"
425425
},
426426
"module_moe_asm": {
427427
"srcs": [
@@ -444,7 +444,8 @@
444444
"blob_gen_cmd": [
445445
"f'{AITER_META_DIR}/hsa/codegen.py -m fmoe_2stages --output_dir {{}}'",
446446
"f'{AITER_META_DIR}/hsa/codegen.py -m fmoe --output_dir {{}}'",
447-
"f'{AITER_META_DIR}/hsa/codegen.py -m topksoftmax --output_dir {{}}'"
447+
"f'{AITER_META_DIR}/hsa/codegen.py -m topksoftmax --output_dir {{}}'",
448+
"f'{AITER_META_DIR}/hsa/codegen.py --code-objects -m fmoe --glob=fmoe_*16.co --glob=fmoe/gelu/fmoe_int8_g1u0_subGU_*.co --glob=fmoe/silu/fmoe_int8_g1u0_subGU_*.co --glob=fmoe_int4fp8_g1u1_subGU_*.co --glob=fmoe_int8_g1u0_smf.co --output_dir {{}}'"
448449
]
449450
},
450451
"module_moe_ck2stages": {
@@ -535,7 +536,10 @@
535536
"f'{CK_DIR}/example/ck_tile/02_layernorm2d'"
536537
],
537538
"verbose": "False",
538-
"blob_gen_cmd": "f'{CK_DIR}/example/ck_tile/02_layernorm2d/generate.py --api fwd --gen_blobs --working_path {{}}'"
539+
"blob_gen_cmd": [
540+
"f'{CK_DIR}/example/ck_tile/02_layernorm2d/generate.py --api fwd --gen_blobs --working_path {{}}'",
541+
"f'{AITER_META_DIR}/hsa/codegen.py --code-objects -m norm --glob=layer_norm.co --glob=layer_norm_qnt.co --output_dir {{}}'"
542+
]
539543
},
540544
"module_pos_encoding": {
541545
"srcs": [
@@ -1130,7 +1134,10 @@
11301134
"extra_ldflags": "None",
11311135
"extra_include": [],
11321136
"verbose": "False",
1133-
"blob_gen_cmd": "''"
1137+
"blob_gen_cmd": [
1138+
"f'{AITER_META_DIR}/hsa/codegen.py --code-objects -m topk_per_row_decode --output_dir {{}}'",
1139+
"f'{AITER_META_DIR}/hsa/codegen.py --code-objects -m topk_per_row_prefill --output_dir {{}}'"
1140+
]
11341141
},
11351142
"module_mla_metadata": {
11361143
"srcs": [

csrc/cpp_itfs/mha_bwd.cu

Lines changed: 63 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
#include <string>
66

77
namespace aiter {
8-
std::tuple<int, int> get_padded_hdim(int hdim_q, int hdim_v, std::string arch_id)
8+
std::tuple<int, int> get_padded_hdim(int hdim_q, int hdim_v, GPUArchId arch_id)
99
{
10-
if(hdim_q == 192 && hdim_v == 128 && arch_id == "gfx950")
10+
if(hdim_q == 192 && hdim_v == 128 && arch_id == GPUArchId::gfx950)
1111
return std::make_tuple(hdim_q, hdim_v);
1212
assert(hdim_q == hdim_v);
1313
if(hdim_q <= 64)
@@ -27,53 +27,44 @@ std::tuple<int, int> get_padded_hdim(int hdim_q, int hdim_v, std::string arch_id
2727
return std::make_tuple(hdim_q, hdim_v);
2828
}
2929

30-
std::tuple<std::string, std::string, std::string> get_heuristic_kernel(std::string data_type,
31-
std::string arch_id,
32-
int seqlen_q,
33-
int seqlen_k,
34-
int hdim_q,
35-
int hdim_v,
36-
int mask_type,
37-
bool atomic32,
38-
int bf16_cvt,
39-
bool mode,
40-
CFG* pre_cfgs,
41-
CFG* cfgs,
42-
CFG* post_cfgs)
30+
std::tuple<const CFG::Entry*, const CFG::Entry*, const CFG::Entry*>
31+
get_heuristic_kernel(std::string data_type,
32+
GPUArchId arch_id,
33+
int seqlen_q,
34+
int seqlen_k,
35+
int hdim_q,
36+
int hdim_v,
37+
int mask_type,
38+
bool atomic32,
39+
int bf16_cvt,
40+
bool mode,
41+
const CFG* pre_cfgs,
42+
const CFG* cfgs,
43+
const CFG* post_cfgs)
4344
{
4445
auto [padded_hdim_q, padded_hdim_v] = get_padded_hdim(hdim_q, hdim_v, arch_id);
4546
int pddv = (padded_hdim_q != hdim_q) || (padded_hdim_v != hdim_v);
4647
int pssk;
4748
int ts_kv = 0;
4849

49-
std::string preProcessingKernelName = "";
50-
std::string dQdKdVKernelName = "";
51-
std::string postProcessingKernelName = "";
50+
const CFG::Entry* preProcessingCfg = nullptr;
51+
const CFG::Entry* dQdKdVCfg = nullptr;
52+
const CFG::Entry* postProcessingCfg = nullptr;
5253

53-
for(const auto& el : *pre_cfgs)
54+
for(const auto& cfg : pre_cfgs->get_configs_for_arch(arch_id))
5455
{
55-
if(el.first.find(arch_id) != 0)
56-
continue;
57-
const auto& cfg = el.second;
58-
5956
if((cfg.dtype == data_type) && (cfg.hdim_v == padded_hdim_v) && (cfg.mode == mode))
6057
{
61-
preProcessingKernelName = el.first;
58+
preProcessingCfg = &cfg;
6259
break;
6360
}
6461
}
6562

66-
for(const auto& el : *cfgs)
63+
for(const auto& cfg : cfgs->get_configs_for_arch(arch_id))
6764
{
68-
if(el.first.find(arch_id) != 0)
69-
{
70-
continue;
71-
}
72-
const auto& cfg = el.second;
73-
7465
if((cfg.dtype == data_type) && (cfg.hdim_q == padded_hdim_q) &&
7566
(cfg.hdim_v == padded_hdim_v) && (cfg.mask == mask_type) && (cfg.atomic32 == atomic32) &&
76-
((arch_id == "gfx950") || ((data_type == "fp16") || (cfg.bf16_cvt == bf16_cvt))) &&
67+
((arch_id == GPUArchId::gfx950) || ((data_type == "fp16") || (cfg.bf16_cvt == bf16_cvt))) &&
7768
(cfg.mode == mode))
7869
{
7970
int tmp_ts_kv = 0;
@@ -82,7 +73,7 @@ std::tuple<std::string, std::string, std::string> get_heuristic_kernel(std::stri
8273
ts_kv = cfg.ts;
8374
tmp_ts_kv = ts_kv;
8475
if(cfg.atomic32 == 0 &&
85-
((arch_id == "gfx942") || (el.first.find("recompile") != std::string::npos)))
76+
((arch_id == GPUArchId::gfx942) || (cfgs->get_kernel_name_for_config(&cfg).find("recompile") != std::string::npos)))
8677
{
8778

8879
tmp_ts_kv = 64;
@@ -91,38 +82,34 @@ std::tuple<std::string, std::string, std::string> get_heuristic_kernel(std::stri
9182
}
9283
if((cfg.pssk == pssk) && (cfg.pddv == pddv))
9384
{
94-
dQdKdVKernelName = el.first;
85+
dQdKdVCfg = &cfg;
9586
break;
9687
}
9788
else if((cfg.pssk >= pssk) && (cfg.pddv >= pddv))
9889
{
99-
dQdKdVKernelName = el.first;
90+
dQdKdVCfg = &cfg;
10091
}
10192
}
10293
}
10394

10495
if(!post_cfgs)
10596
{
106-
return std::make_tuple(preProcessingKernelName, dQdKdVKernelName, postProcessingKernelName);
97+
return std::make_tuple(preProcessingCfg, dQdKdVCfg, nullptr);
10798
}
10899

109-
for(const auto& el : *post_cfgs)
100+
for(const auto& cfg : post_cfgs->get_configs_for_arch(arch_id))
110101
{
111-
if(el.first.find(arch_id) != 0)
112-
continue;
113-
const auto& cfg = el.second;
114-
115102
if((cfg.hdim_q == padded_hdim_q) && (cfg.mode == mode) &&
116-
((arch_id == "gfx950") || ((data_type == "fp16") || (cfg.bf16_cvt == bf16_cvt))))
103+
((arch_id == GPUArchId::gfx950) || ((data_type == "fp16") || (cfg.bf16_cvt == bf16_cvt))))
117104
{
118105
if((cfg.dtype == data_type) || (atomic32 == 0))
119106
{
120-
postProcessingKernelName = el.first;
107+
postProcessingCfg = &cfg;
121108
break;
122109
}
123110
}
124111
}
125-
return std::make_tuple(preProcessingKernelName, dQdKdVKernelName, postProcessingKernelName);
112+
return std::make_tuple(preProcessingCfg, dQdKdVCfg, postProcessingCfg);
126113
}
127114

128115
float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
@@ -236,19 +223,19 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
236223

237224
float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
238225
{
239-
std::string arch_id = get_gpu_arch();
226+
auto arch_id = get_gpu_arch();
240227

241228
if((!a.use_asm_v3) || (a.hdim_q % 8 != 0) || (a.hdim_v % 8 != 0) || (a.has_dbias) ||
242229
(a.bias_type != 0) || (a.has_dropout) || (a.is_deterministic) ||
243-
((arch_id != "gfx942") && (arch_id != "gfx950")))
230+
((arch_id != GPUArchId::gfx942) && (arch_id != GPUArchId::gfx950)))
244231
{
245232
return -1;
246233
}
247234

248235
auto pre_cfgs = &cfg_fmha_bwd_odo;
249236
auto dqdkdv_cfgs = &cfg_fmha_bwd_dqdkdv;
250237
auto post_cfgs = [&]() {
251-
if(arch_id == "gfx950")
238+
if(arch_id == GPUArchId::gfx950)
252239
{
253240
if(a.v3_atomic_fp32)
254241
{
@@ -267,29 +254,29 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
267254
}
268255
else
269256
{
270-
return static_cast<CFG*>(nullptr);
257+
return static_cast<const CFG*>(nullptr);
271258
}
272259
}
273260
}();
274261

275262
bool need_post_processing =
276-
((arch_id == "gfx950") && (a.hdim_q != 64)) || (a.v3_atomic_fp32 == 1);
277-
278-
auto [pre_kernel, dqdkdv_kernel, post_kernel] = get_heuristic_kernel(a.data_type,
279-
arch_id,
280-
a.seqlen_q,
281-
a.seqlen_k,
282-
a.hdim_q,
283-
a.hdim_v,
284-
a.mask_type,
285-
a.v3_atomic_fp32,
286-
a.v3_bf16_cvt,
287-
a.is_group_mode,
288-
pre_cfgs,
289-
dqdkdv_cfgs,
290-
post_cfgs);
291-
292-
if((pre_kernel == "") || (dqdkdv_kernel == "") || (need_post_processing && (post_kernel == "")))
263+
((arch_id == GPUArchId::gfx950) && (a.hdim_q != 64)) || (a.v3_atomic_fp32 == 1);
264+
265+
auto [pre_cfg, dqdkdv_cfg, post_cfg] = get_heuristic_kernel(a.data_type,
266+
arch_id,
267+
a.seqlen_q,
268+
a.seqlen_k,
269+
a.hdim_q,
270+
a.hdim_v,
271+
a.mask_type,
272+
a.v3_atomic_fp32,
273+
a.v3_bf16_cvt,
274+
a.is_group_mode,
275+
pre_cfgs,
276+
dqdkdv_cfgs,
277+
post_cfgs);
278+
279+
if((pre_cfg == nullptr) || (dqdkdv_cfg == nullptr) || (need_post_processing && (post_cfg == nullptr)))
293280
{
294281
return -1;
295282
}
@@ -298,76 +285,18 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
298285
int ts_kv;
299286
int ts_dq;
300287
int arg_size;
301-
302-
AiterAsmKernel* impl_ptr_pre = nullptr;
303-
AiterAsmKernel* impl_ptr_dqdkdv = nullptr;
304-
AiterAsmKernel* impl_ptr_post = nullptr;
305-
static std::unordered_map<std::string, std::unique_ptr<AiterAsmKernel>> impl_ptr_map;
306-
307-
auto it_pre = pre_cfgs->find(pre_kernel);
308-
if(it_pre != pre_cfgs->end())
309-
{
310-
const auto& cfg = it_pre->second;
311-
const char* name = cfg.knl_name.c_str();
312-
const char* co_name = cfg.co_name.c_str();
313-
ts_odo = cfg.ts;
314-
315-
auto result = impl_ptr_map.emplace(name, nullptr);
316-
if(result.second)
317-
{
318-
result.first->second = std::make_unique<AiterAsmKernel>(name, co_name);
319-
}
320-
321-
impl_ptr_pre = result.first->second.get();
322-
}
323-
else
324-
{
325-
return -1;
326-
}
327-
328-
auto it_dqdkdv = dqdkdv_cfgs->find(dqdkdv_kernel);
329-
if(it_dqdkdv != dqdkdv_cfgs->end())
330-
{
331-
const auto& cfg = it_dqdkdv->second;
332-
const char* name = cfg.knl_name.c_str();
333-
const char* co_name = cfg.co_name.c_str();
334-
ts_kv = cfg.ts;
335-
336-
auto result = impl_ptr_map.emplace(name, nullptr);
337-
if(result.second)
338-
{
339-
result.first->second = std::make_unique<AiterAsmKernel>(name, co_name);
340-
}
341-
342-
impl_ptr_dqdkdv = result.first->second.get();
343-
}
344-
else
345-
{
346-
return -1;
347-
}
348-
349-
if(need_post_processing)
288+
AiterAsmKernel<>* impl_ptr_pre = nullptr;
289+
AiterAsmKernel<>* impl_ptr_dqdkdv = nullptr;
290+
AiterAsmKernel<>* impl_ptr_post = nullptr;
291+
292+
impl_ptr_pre = pre_cfgs->load_kernel_for_config(pre_cfg);
293+
ts_odo = pre_cfg->ts;
294+
impl_ptr_dqdkdv = dqdkdv_cfgs->load_kernel_for_config(dqdkdv_cfg);
295+
ts_kv = dqdkdv_cfg->ts;
296+
if(post_cfg != nullptr)
350297
{
351-
auto it_post = post_cfgs->find(post_kernel);
352-
if(it_post != post_cfgs->end())
353-
{
354-
const auto& cfg = it_post->second;
355-
const char* name = cfg.knl_name.c_str();
356-
const char* co_name = cfg.co_name.c_str();
357-
ts_dq = cfg.ts;
358-
359-
auto result = impl_ptr_map.emplace(name, nullptr);
360-
if(result.second)
361-
{
362-
result.first->second = std::make_unique<AiterAsmKernel>(name, co_name);
363-
}
364-
365-
impl_ptr_post = result.first->second.get();
366-
}
367-
else
368-
{
369-
return -1;
370-
}
298+
impl_ptr_post = post_cfgs->load_kernel_for_config(post_cfg);
299+
ts_dq = post_cfg->ts;
371300
}
372301

373302
if(a.v3_api_check)

0 commit comments

Comments
 (0)