55#include < string>
66
77namespace aiter {
8- std::tuple<int , int > get_padded_hdim (int hdim_q, int hdim_v, std::string arch_id)
8+ std::tuple<int , int > get_padded_hdim (int hdim_q, int hdim_v, GPUArchId arch_id)
99{
10- if (hdim_q == 192 && hdim_v == 128 && arch_id == " gfx950" )
10+ if (hdim_q == 192 && hdim_v == 128 && arch_id == GPUArchId:: gfx950)
1111 return std::make_tuple (hdim_q, hdim_v);
1212 assert (hdim_q == hdim_v);
1313 if (hdim_q <= 64 )
@@ -27,53 +27,44 @@ std::tuple<int, int> get_padded_hdim(int hdim_q, int hdim_v, std::string arch_id
2727 return std::make_tuple (hdim_q, hdim_v);
2828}
2929
30- std::tuple<std::string, std::string, std::string> get_heuristic_kernel (std::string data_type,
31- std::string arch_id,
32- int seqlen_q,
33- int seqlen_k,
34- int hdim_q,
35- int hdim_v,
36- int mask_type,
37- bool atomic32,
38- int bf16_cvt,
39- bool mode,
40- CFG* pre_cfgs,
41- CFG* cfgs,
42- CFG* post_cfgs)
30+ std::tuple<const CFG::Entry*, const CFG::Entry*, const CFG::Entry*>
31+ get_heuristic_kernel (std::string data_type,
32+ GPUArchId arch_id,
33+ int seqlen_q,
34+ int seqlen_k,
35+ int hdim_q,
36+ int hdim_v,
37+ int mask_type,
38+ bool atomic32,
39+ int bf16_cvt,
40+ bool mode,
41+ const CFG* pre_cfgs,
42+ const CFG* cfgs,
43+ const CFG* post_cfgs)
4344{
4445 auto [padded_hdim_q, padded_hdim_v] = get_padded_hdim (hdim_q, hdim_v, arch_id);
4546 int pddv = (padded_hdim_q != hdim_q) || (padded_hdim_v != hdim_v);
4647 int pssk;
4748 int ts_kv = 0 ;
4849
49- std::string preProcessingKernelName = " " ;
50- std::string dQdKdVKernelName = " " ;
51- std::string postProcessingKernelName = " " ;
50+ const CFG::Entry* preProcessingCfg = nullptr ;
51+ const CFG::Entry* dQdKdVCfg = nullptr ;
52+ const CFG::Entry* postProcessingCfg = nullptr ;
5253
53- for (const auto & el : * pre_cfgs)
54+ for (const auto & cfg : pre_cfgs-> get_configs_for_arch (arch_id) )
5455 {
55- if (el.first .find (arch_id) != 0 )
56- continue ;
57- const auto & cfg = el.second ;
58-
5956 if ((cfg.dtype == data_type) && (cfg.hdim_v == padded_hdim_v) && (cfg.mode == mode))
6057 {
61- preProcessingKernelName = el. first ;
58+ preProcessingCfg = &cfg ;
6259 break ;
6360 }
6461 }
6562
66- for (const auto & el : * cfgs)
63+ for (const auto & cfg : cfgs-> get_configs_for_arch (arch_id) )
6764 {
68- if (el.first .find (arch_id) != 0 )
69- {
70- continue ;
71- }
72- const auto & cfg = el.second ;
73-
7465 if ((cfg.dtype == data_type) && (cfg.hdim_q == padded_hdim_q) &&
7566 (cfg.hdim_v == padded_hdim_v) && (cfg.mask == mask_type) && (cfg.atomic32 == atomic32) &&
76- ((arch_id == " gfx950" ) || ((data_type == " fp16" ) || (cfg.bf16_cvt == bf16_cvt))) &&
67+ ((arch_id == GPUArchId:: gfx950) || ((data_type == " fp16" ) || (cfg.bf16_cvt == bf16_cvt))) &&
7768 (cfg.mode == mode))
7869 {
7970 int tmp_ts_kv = 0 ;
@@ -82,7 +73,7 @@ std::tuple<std::string, std::string, std::string> get_heuristic_kernel(std::stri
8273 ts_kv = cfg.ts ;
8374 tmp_ts_kv = ts_kv;
8475 if (cfg.atomic32 == 0 &&
85- ((arch_id == " gfx942" ) || (el. first .find (" recompile" ) != std::string::npos)))
76+ ((arch_id == GPUArchId:: gfx942) || (cfgs-> get_kernel_name_for_config (&cfg) .find (" recompile" ) != std::string::npos)))
8677 {
8778
8879 tmp_ts_kv = 64 ;
@@ -91,38 +82,34 @@ std::tuple<std::string, std::string, std::string> get_heuristic_kernel(std::stri
9182 }
9283 if ((cfg.pssk == pssk) && (cfg.pddv == pddv))
9384 {
94- dQdKdVKernelName = el. first ;
85+ dQdKdVCfg = &cfg ;
9586 break ;
9687 }
9788 else if ((cfg.pssk >= pssk) && (cfg.pddv >= pddv))
9889 {
99- dQdKdVKernelName = el. first ;
90+ dQdKdVCfg = &cfg ;
10091 }
10192 }
10293 }
10394
10495 if (!post_cfgs)
10596 {
106- return std::make_tuple (preProcessingKernelName, dQdKdVKernelName, postProcessingKernelName );
97+ return std::make_tuple (preProcessingCfg, dQdKdVCfg, nullptr );
10798 }
10899
109- for (const auto & el : * post_cfgs)
100+ for (const auto & cfg : post_cfgs-> get_configs_for_arch (arch_id) )
110101 {
111- if (el.first .find (arch_id) != 0 )
112- continue ;
113- const auto & cfg = el.second ;
114-
115102 if ((cfg.hdim_q == padded_hdim_q) && (cfg.mode == mode) &&
116- ((arch_id == " gfx950" ) || ((data_type == " fp16" ) || (cfg.bf16_cvt == bf16_cvt))))
103+ ((arch_id == GPUArchId:: gfx950) || ((data_type == " fp16" ) || (cfg.bf16_cvt == bf16_cvt))))
117104 {
118105 if ((cfg.dtype == data_type) || (atomic32 == 0 ))
119106 {
120- postProcessingKernelName = el. first ;
107+ postProcessingCfg = &cfg ;
121108 break ;
122109 }
123110 }
124111 }
125- return std::make_tuple (preProcessingKernelName, dQdKdVKernelName, postProcessingKernelName );
112+ return std::make_tuple (preProcessingCfg, dQdKdVCfg, postProcessingCfg );
126113}
127114
128115float mha_bwd (mha_bwd_args a, const ck_tile::stream_config& s)
@@ -236,19 +223,19 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
236223
237224float fmha_v3_bwd (mha_bwd_args a, const ck_tile::stream_config& s)
238225{
239- std::string arch_id = get_gpu_arch ();
226+ auto arch_id = get_gpu_arch ();
240227
241228 if ((!a.use_asm_v3 ) || (a.hdim_q % 8 != 0 ) || (a.hdim_v % 8 != 0 ) || (a.has_dbias ) ||
242229 (a.bias_type != 0 ) || (a.has_dropout ) || (a.is_deterministic ) ||
243- ((arch_id != " gfx942" ) && (arch_id != " gfx950" )))
230+ ((arch_id != GPUArchId:: gfx942) && (arch_id != GPUArchId:: gfx950)))
244231 {
245232 return -1 ;
246233 }
247234
248235 auto pre_cfgs = &cfg_fmha_bwd_odo;
249236 auto dqdkdv_cfgs = &cfg_fmha_bwd_dqdkdv;
250237 auto post_cfgs = [&]() {
251- if (arch_id == " gfx950" )
238+ if (arch_id == GPUArchId:: gfx950)
252239 {
253240 if (a.v3_atomic_fp32 )
254241 {
@@ -267,29 +254,29 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
267254 }
268255 else
269256 {
270- return static_cast <CFG*>(nullptr );
257+ return static_cast <const CFG*>(nullptr );
271258 }
272259 }
273260 }();
274261
275262 bool need_post_processing =
276- ((arch_id == " gfx950" ) && (a.hdim_q != 64 )) || (a.v3_atomic_fp32 == 1 );
277-
278- auto [pre_kernel, dqdkdv_kernel, post_kernel ] = get_heuristic_kernel (a.data_type ,
279- arch_id,
280- a.seqlen_q ,
281- a.seqlen_k ,
282- a.hdim_q ,
283- a.hdim_v ,
284- a.mask_type ,
285- a.v3_atomic_fp32 ,
286- a.v3_bf16_cvt ,
287- a.is_group_mode ,
288- pre_cfgs,
289- dqdkdv_cfgs,
290- post_cfgs);
291-
292- if ((pre_kernel == " " ) || (dqdkdv_kernel == " " ) || (need_post_processing && (post_kernel == " " )))
263+ ((arch_id == GPUArchId:: gfx950) && (a.hdim_q != 64 )) || (a.v3_atomic_fp32 == 1 );
264+
265+ auto [pre_cfg, dqdkdv_cfg, post_cfg ] = get_heuristic_kernel (a.data_type ,
266+ arch_id,
267+ a.seqlen_q ,
268+ a.seqlen_k ,
269+ a.hdim_q ,
270+ a.hdim_v ,
271+ a.mask_type ,
272+ a.v3_atomic_fp32 ,
273+ a.v3_bf16_cvt ,
274+ a.is_group_mode ,
275+ pre_cfgs,
276+ dqdkdv_cfgs,
277+ post_cfgs);
278+
279+ if ((pre_cfg == nullptr ) || (dqdkdv_cfg == nullptr ) || (need_post_processing && (post_cfg == nullptr )))
293280 {
294281 return -1 ;
295282 }
@@ -298,76 +285,18 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
298285 int ts_kv;
299286 int ts_dq;
300287 int arg_size;
301-
302- AiterAsmKernel* impl_ptr_pre = nullptr ;
303- AiterAsmKernel* impl_ptr_dqdkdv = nullptr ;
304- AiterAsmKernel* impl_ptr_post = nullptr ;
305- static std::unordered_map<std::string, std::unique_ptr<AiterAsmKernel>> impl_ptr_map;
306-
307- auto it_pre = pre_cfgs->find (pre_kernel);
308- if (it_pre != pre_cfgs->end ())
309- {
310- const auto & cfg = it_pre->second ;
311- const char * name = cfg.knl_name .c_str ();
312- const char * co_name = cfg.co_name .c_str ();
313- ts_odo = cfg.ts ;
314-
315- auto result = impl_ptr_map.emplace (name, nullptr );
316- if (result.second )
317- {
318- result.first ->second = std::make_unique<AiterAsmKernel>(name, co_name);
319- }
320-
321- impl_ptr_pre = result.first ->second .get ();
322- }
323- else
324- {
325- return -1 ;
326- }
327-
328- auto it_dqdkdv = dqdkdv_cfgs->find (dqdkdv_kernel);
329- if (it_dqdkdv != dqdkdv_cfgs->end ())
330- {
331- const auto & cfg = it_dqdkdv->second ;
332- const char * name = cfg.knl_name .c_str ();
333- const char * co_name = cfg.co_name .c_str ();
334- ts_kv = cfg.ts ;
335-
336- auto result = impl_ptr_map.emplace (name, nullptr );
337- if (result.second )
338- {
339- result.first ->second = std::make_unique<AiterAsmKernel>(name, co_name);
340- }
341-
342- impl_ptr_dqdkdv = result.first ->second .get ();
343- }
344- else
345- {
346- return -1 ;
347- }
348-
349- if (need_post_processing)
288+ AiterAsmKernel<>* impl_ptr_pre = nullptr ;
289+ AiterAsmKernel<>* impl_ptr_dqdkdv = nullptr ;
290+ AiterAsmKernel<>* impl_ptr_post = nullptr ;
291+
292+ impl_ptr_pre = pre_cfgs->load_kernel_for_config (pre_cfg);
293+ ts_odo = pre_cfg->ts ;
294+ impl_ptr_dqdkdv = dqdkdv_cfgs->load_kernel_for_config (dqdkdv_cfg);
295+ ts_kv = dqdkdv_cfg->ts ;
296+ if (post_cfg != nullptr )
350297 {
351- auto it_post = post_cfgs->find (post_kernel);
352- if (it_post != post_cfgs->end ())
353- {
354- const auto & cfg = it_post->second ;
355- const char * name = cfg.knl_name .c_str ();
356- const char * co_name = cfg.co_name .c_str ();
357- ts_dq = cfg.ts ;
358-
359- auto result = impl_ptr_map.emplace (name, nullptr );
360- if (result.second )
361- {
362- result.first ->second = std::make_unique<AiterAsmKernel>(name, co_name);
363- }
364-
365- impl_ptr_post = result.first ->second .get ();
366- }
367- else
368- {
369- return -1 ;
370- }
298+ impl_ptr_post = post_cfgs->load_kernel_for_config (post_cfg);
299+ ts_dq = post_cfg->ts ;
371300 }
372301
373302 if (a.v3_api_check )
0 commit comments