Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions aiter/jit/optCompilerConfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@
"f'{AITER_CSRC_DIR}/pybind/mha_bwd_asm_pybind.cu'"
],
"flags_extra_cc": [
"'-DONLY_FAV3=1'"
"'-DFAV3_ON=1'"
],
"flags_extra_hip": [
"'-DCK_TILE_FMHA_FWD_FAST_EXP2=1'"
Expand All @@ -939,7 +939,7 @@
"f'{AITER_CSRC_DIR}/pybind/mha_varlen_bwd_asm_pybind.cu'"
],
"flags_extra_cc": [
"'-DONLY_FAV3=1'"
"'-DFAV3_ON=1'"
],
"flags_extra_hip": [
"'-DCK_TILE_FMHA_FWD_FAST_EXP2=1'"
Expand Down Expand Up @@ -982,7 +982,10 @@
"f'{AITER_CSRC_DIR}/cpp_itfs/mha_bwd.cu'",
"f'{AITER_CSRC_DIR}/pybind/mha_bwd_pybind.cu'"
],
"flags_extra_cc": [],
"flags_extra_cc": [
"'-DFAV3_ON=1'",
"'-DFAV2_ON=1'"
],
"flags_extra_hip": [
"'-DCK_TILE_FMHA_FWD_FAST_EXP2=1'",
"f'-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get(\"CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT\", 2)}'"
Expand All @@ -1005,7 +1008,10 @@
"f'{AITER_CSRC_DIR}/cpp_itfs/mha_bwd.cu'",
"f'{AITER_CSRC_DIR}/pybind/mha_varlen_bwd_pybind.cu'"
],
"flags_extra_cc": [],
"flags_extra_cc": [
"'-DFAV3_ON=1'",
"'-DFAV2_ON=1'"
],
"flags_extra_hip": [
"'-DCK_TILE_FMHA_FWD_FAST_EXP2=1'",
"f'-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get(\"CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT\", 2)}'"
Expand All @@ -1027,7 +1033,9 @@
"f'{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_split.cu'",
"f'{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_batch_prefill.cu'"
],
"flags_extra_cc": [],
"flags_extra_cc": [
"'-DFAV2_ON=1'"
],
"flags_extra_hip": [
"'-DCK_TILE_FMHA_FWD_FAST_EXP2=1'",
"f'-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get(\"CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT\", 2)}'",
Expand All @@ -1051,7 +1059,9 @@
"srcs": [
"f'{AITER_CSRC_DIR}/cpp_itfs/mha_bwd.cu'"
],
"flags_extra_cc": [],
"flags_extra_cc": [
"'-DFAV2_ON=1'"
],
"flags_extra_hip": [
"f'-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get(\"CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT\", 0)}'"
],
Expand Down
4 changes: 2 additions & 2 deletions aiter/ops/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ def cmdGenFunc_mha_bwd(
return {
"md_name": md_name,
"blob_gen_cmd": blob_gen_cmd,
"flags_extra_cc": ["'-DONLY_FAV3=0'"],
"flags_extra_cc": ["'-DFAV3_ON=1'", "'-DFAV2_ON=1'"],
}


Expand Down Expand Up @@ -929,7 +929,7 @@ def cmdGenFunc_mha_varlen_bwd(
return {
"md_name": md_name,
"blob_gen_cmd": blob_gen_cmd,
"flags_extra_cc": ["'-DONLY_FAV3=0'"],
"flags_extra_cc": ["'-DFAV3_ON=1'", "'-DFAV2_ON=1'"],
}


Expand Down
25 changes: 21 additions & 4 deletions csrc/cpp_itfs/mha_bwd.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
#include "mha_bwd.h"
#include "aiter_hip_common.h"
#include "asm_fmha_v3_bwd_configs.hpp"
#include "mha_bwd.h"
#include <memory>
#include <string>

// Build-flag defaults: when the build defines neither FAV3_ON nor FAV2_ON,
// enable both. If the build defines only one of them, the other is left
// undefined, and an undefined identifier evaluates to 0 inside `#if`.
#if !defined(FAV3_ON) && !defined(FAV2_ON)
#define FAV3_ON 1
#define FAV2_ON 1
#endif

// The v3 asm config header is only included when FAV3_ON is non-zero.
#if FAV3_ON
#include "asm_fmha_v3_bwd_configs.hpp"
#endif

namespace aiter {
#if FAV3_ON
std::tuple<int, int> get_padded_hdim(int hdim_q, int hdim_v, std::string arch_id)
{
if(hdim_q == 192 && hdim_v == 128 && arch_id == "gfx950")
Expand Down Expand Up @@ -125,10 +134,16 @@ std::tuple<std::string, std::string, std::string> get_heuristic_kernel(std::stri
return std::make_tuple(preProcessingKernelName, dQdKdVKernelName, postProcessingKernelName);
}

#endif

float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
{
#if FAV3_ON
float asm_ret = fmha_v3_bwd(a, s);
#if ONLY_FAV3
#else
float asm_ret = -1;
#endif
#if !FAV2_ON
return asm_ret;
#else
fmha_bwd_traits traits{a.hdim_q,
Expand Down Expand Up @@ -226,14 +241,15 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
/* drop_seed_offset */ a.drop_seed_offset,
};

if(asm_ret == -1)
if(asm_ret == -1 && !a.v3_api_check)
{
return fmha_bwd(traits, ck_args, s);
}
return asm_ret;
#endif
}

#if FAV3_ON
float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
{
std::string arch_id = get_gpu_arch();
Expand Down Expand Up @@ -534,5 +550,6 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
[=](const ck_tile::stream_config& s_) { dqdkdv_kernel_launch(); },
[=](const ck_tile::stream_config& s_) { post_kernel_launch(); });
}
#endif

} // namespace aiter
12 changes: 9 additions & 3 deletions csrc/cpp_itfs/mha_fwd.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
#include "mha_fwd.h"
#include "aiter_hip_common.h"
#include "mha_fwd.h"
#include <memory>
#include <string>

// Build-flag defaults: when the build defines neither FAV3_ON nor FAV2_ON,
// enable both. If the build defines only one of them, the other is left
// undefined, and an undefined identifier evaluates to 0 inside `#if`.
#if !defined(FAV3_ON) && !defined(FAV2_ON)
#define FAV3_ON 1
#define FAV2_ON 1
#endif

// The v3 asm config header is only included when FAV3_ON is non-zero.
#if FAV3_ON
#include "asm_fmha_v3_fwd_configs.hpp"
#endif
#include <memory>
#include <string>

namespace aiter {
#if FAV3_ON
Expand Down
2 changes: 1 addition & 1 deletion op_tests/cpp/mha/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def cmdGenFunc_mha_bwd(ck_exclude: bool):
f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d bwd --receipt 600 --output_dir {{}}",
]
blob_gen_cmd.extend(BWD_CODEGEN_CMD)
flags_extra_cc = ["-DONLY_FAV3"] if ck_exclude else []
flags_extra_cc = ["-DFAV3_ON=1"] if ck_exclude else ["-DFAV3_ON=1", "-DFAV2_ON=1"]
return {
"md_name": "libmha_bwd",
"blob_gen_cmd": blob_gen_cmd,
Expand Down