31 changes: 30 additions & 1 deletion aiter/mla.py
@@ -276,7 +276,7 @@ def mla_decode_fwd(
     if num_kv_splits is None:
         num_kv_splits = get_cu_num()
     if nhead == 16 or (
-        nhead == 128 and q.dtype == dtypes.fp8 and kv_buffer.dtype == dtypes.fp8
+        (nhead == 128 or nhead == 32) and q.dtype == dtypes.fp8 and kv_buffer.dtype == dtypes.fp8
     ):
         # Natively supported cases
         pass
@@ -307,6 +307,35 @@ def mla_decode_fwd(
         else None
     )
 
+    def print_ptr(name, t: torch.Tensor):
+        addr = t.data_ptr()
+        size = t.numel() * t.element_size()
+        print(
+            f"{name}: [{hex(addr)}, {hex(addr + size)})",
+            f" numel={t.numel()}, elem_size={t.element_size()}, shape={t.shape}",
+        )
+
+    print("=== Persistent Mode Tensor Info ===")
+    for name, t in [
+        ("q", q),
+        ("kv_buffer", kv_buffer),
+        ("qo_indptr", qo_indptr),
+        ("kv_indptr", kv_indptr),
+        ("kv_indices", kv_indices),
+        ("kv_last_page_lens", kv_last_page_lens),
+        ("num_kv_splits_indptr", num_kv_splits_indptr),
+        ("work_meta_data", work_meta_data),
+        ("work_indptr", work_indptr),
+        ("work_info_set", work_info_set),
+        ("logits", logits),
+        ("attn_lse", attn_lse),
+        ("o", o),
+    ]:
+        if t is not None:
+            print_ptr(name, t)
+
+    print("====================================")
+
     aiter.mla_decode_stage1_asm_fwd(
         q,
         kv_buffer,
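
For readability, the head-count gate changed above (and mirrored in aiter/ops/attention.py below) can be read as the following predicate. This is a minimal sketch: is_natively_supported is an illustrative name, not aiter API, and the dtypes import is assumed to match this file's own usage.

import torch
from aiter import dtypes  # assumed import, matching aiter/mla.py's usage

def is_natively_supported(nhead: int, q: torch.Tensor, kv_buffer: torch.Tensor) -> bool:
    # 16 heads are natively supported for any dtype; 32 and 128 heads
    # additionally require fp8 q and an fp8 KV cache.
    return nhead == 16 or (
        nhead in (32, 128) and q.dtype == dtypes.fp8 and kv_buffer.dtype == dtypes.fp8
    )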
2 changes: 1 addition & 1 deletion aiter/ops/attention.py
@@ -815,7 +815,7 @@ def get_mla_metadata_info_v1(
     max_qo_tiles_per_batch = (
         int(math.ceil(max_seqlen_qo * num_head_qo / 128))
         if num_head_qo == 16
-        or (num_head_qo == 128 and kv_dtype == dtypes.fp8 and q_dtype == dtypes.fp8)
+        or ((num_head_qo == 128 or num_head_qo == 32) and kv_dtype == dtypes.fp8 and q_dtype == dtypes.fp8)
         else int(math.ceil(max_seqlen_qo * num_head_qo / 16))
     )
     batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size
2 changes: 1 addition & 1 deletion aiter/test_common.py
@@ -196,7 +196,7 @@ def run_iters_rotate(num_iters, func, rotate_args):
 def run_perftest(
     func,
     *args,
-    num_iters=101,
+    num_iters=2,
     num_warmup=2,
     testGraph=False,
     num_rotate_args=0,
30 changes: 16 additions & 14 deletions csrc/kernels/mla/metadata/v1_2_device.cuh
@@ -371,7 +371,6 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
                               torch::Tensor& reduce_final_map,
                               torch::Tensor& reduce_partial_map)
 {
-    constexpr int32_t kPackedQoLenPerWg = 128;
     const hipStream_t stream = at::hip::getCurrentHIPStream();
 
     hipDevice_t dev;
@@ -395,6 +394,7 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
         (kv_dtype == at::ScalarType::Float8_e4m3fnuz || kv_dtype == at::ScalarType::Float8_e4m3fn);
 
     const bool natively_supported = (num_heads == 16) ||
+                                    ((num_heads == 32) && q_is_fp8 && kv_is_fp8) ||
                                     ((num_heads == 128) && q_is_fp8 && kv_is_fp8);
 
     if((natively_supported == false) && (num_heads % 16 == 0))
@@ -404,7 +404,7 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
         num_batches *= qk_batch_ratio;
     }
 
-    TORCH_CHECK((num_heads == 16) || (num_heads == 128),
+    TORCH_CHECK((num_heads == 16) || (num_heads == 128) || ((num_heads == 32) && q_is_fp8 && kv_is_fp8),
                __func__,
-               ": only supports #heads in [16, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
+               ": only supports #heads in [16, 128], 32 with fp8 q and kv, or (#head, uni_seqlen_qo) = (16*N, 1) where "
                "N is in [2, 8).")
@@ -436,15 +436,17 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
     params.qk_batch_ratio = qk_batch_ratio;
 
     // launch kernel
-    MLA_METADATA_DISPATCHER(
-        max_seqlen_qo * num_heads_per_head_k,
-        kPackedQoLenPerWg,
-        params.uni_seqlen_qo,
-        topk,
-        dispatch_mla_metadata_v1_2_device<kPackedQoLenPerWg, kQoSplits, kUniSeqlenQo, kIsSparse>(
-            params,
-            stream,
-            max_seqlen_qo,
-            dev_prop.warpSize,
-            dev_prop.maxSharedMemoryPerMultiProcessor));
+    MLA_NUM_HEADS_DISPATCHER(
+        num_heads_per_head_k,
+        MLA_METADATA_DISPATCHER(
+            max_seqlen_qo * num_heads_per_head_k,
+            kPackedQoLenPerWg,
+            params.uni_seqlen_qo,
+            topk,
+            dispatch_mla_metadata_v1_2_device<kPackedQoLenPerWg, kQoSplits, kUniSeqlenQo, kIsSparse>(
+                params,
+                stream,
+                max_seqlen_qo,
+                dev_prop.warpSize,
+                dev_prop.maxSharedMemoryPerMultiProcessor)));
 }
21 changes: 21 additions & 0 deletions csrc/kernels/mla/metadata/v1_comm.cuh
@@ -384,3 +384,24 @@ private:
         MLA_UNI_SEQLEN_DISPATCHER((UNI_SEQLEN_QO), __VA_ARGS__); \
     } \
 }
+
+#define MLA_NUM_HEADS_CASE(C_NUM_HEADS, ...) \
+    case C_NUM_HEADS: \
+    { \
+        constexpr int32_t kPackedQoLenPerWg = C_NUM_HEADS; \
+        __VA_ARGS__; \
+        break; \
+    }
+
+#define MLA_NUM_HEADS_DISPATCHER(NUM_HEADS, ...) \
+    switch (NUM_HEADS) \
+    { \
+        MLA_NUM_HEADS_CASE(32, __VA_ARGS__); \
+        MLA_NUM_HEADS_CASE(64, __VA_ARGS__); \
+        default: \
+        { \
+            constexpr int32_t kPackedQoLenPerWg = 128; \
+            __VA_ARGS__; \
+            break; \
+        } \
+    }
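
The net effect of MLA_NUM_HEADS_DISPATCHER is to turn the runtime head count into the compile-time kPackedQoLenPerWg that v1_2_device.cuh previously hard-coded to 128. A minimal Python analogue of the selection, for illustration only (packed_qo_len_per_wg is not an aiter function):

# 32 and 64 heads get a matching packed-QO length per workgroup; every
# other head count falls through to the previous hard-coded 128.
def packed_qo_len_per_wg(num_heads: int) -> int:
    return num_heads if num_heads in (32, 64) else 128

assert packed_qo_len_per_wg(32) == 32
assert packed_qo_len_per_wg(64) == 64
assert packed_qo_len_per_wg(16) == 128  # default case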
45 changes: 27 additions & 18 deletions csrc/py_itfs_cu/asm_mla.cu
@@ -173,24 +173,24 @@ void mla_decode_stage1_asm_fwd(
     args.ptr_RP = output.data_ptr(); //final output
 
 
-    // std::cout << "mla args" << std::endl;
-    // std::cout << "ptr_R: " << args.ptr_R << std::endl;
-    // std::cout << "ptr_LSE: " << args.ptr_LSE << std::endl;
-    // std::cout << "ptr_Q: " << args.ptr_Q << std::endl;
-    // std::cout << "ptr_KV: " << args.ptr_KV << std::endl;
-    // std::cout << "ptr_LTP: " << args.ptr_LTP << std::endl;
-    // std::cout << "ptr_LTD: " << args.ptr_LTD << std::endl;
-    // std::cout << "ptr_LTL: " << args.ptr_LTL << std::endl;
-    // std::cout << "scalar: " << args.scalar << std::endl;
-    // std::cout << "s_MQA: " << args.s_MQA << std::endl;
-    // std::cout << "s_kv_split: " << args.s_kv_split << std::endl;
-    // std::cout << "s_Q_Bs: " << args.s_Q_Bs << std::endl;
-    // std::cout << "s_Bs: " << args.s_Bs << std::endl;
-    // std::cout << "s_log2_plen: " << args.s_log2_plen << std::endl;
-    // std::cout << "ptr_RP: " << args.ptr_RP << std::endl;
-    // std::cout << "ptr_QTP: " << args.ptr_QTP << std::endl;
-    // std::cout << "ptr_STP: " << args.ptr_STP << std::endl;
-    // std::cout << "out_16_nosplit: " << args.out_16_nosplit << std::endl;
+    std::cout << "mla args" << std::endl;
+    std::cout << "ptr_R: " << args.ptr_R << std::endl;
+    std::cout << "ptr_LSE: " << args.ptr_LSE << std::endl;
+    std::cout << "ptr_Q: " << args.ptr_Q << std::endl;
+    std::cout << "ptr_KV: " << args.ptr_KV << std::endl;
+    std::cout << "ptr_LTP: " << args.ptr_LTP << std::endl;
+    std::cout << "ptr_LTD: " << args.ptr_LTD << std::endl;
+    std::cout << "ptr_LTL: " << args.ptr_LTL << std::endl;
+    std::cout << "scalar: " << args.scalar << std::endl;
+    std::cout << "s_MQA: " << args.s_MQA << std::endl;
+    std::cout << "s_kv_split: " << args.s_kv_split << std::endl;
+    std::cout << "s_Q_Bs: " << args.s_Q_Bs << std::endl;
+    std::cout << "s_Bs: " << args.s_Bs << std::endl;
+    std::cout << "s_log2_plen: " << args.s_log2_plen << std::endl;
+    std::cout << "ptr_RP: " << args.ptr_RP << std::endl;
+    std::cout << "ptr_QTP: " << args.ptr_QTP << std::endl;
+    std::cout << "ptr_STP: " << args.ptr_STP << std::endl;
+    std::cout << "out_16_nosplit: " << args.out_16_nosplit << std::endl;
 
     const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Q));
     const hipStream_t stream = at::hip::getCurrentHIPStream();
@@ -287,6 +287,15 @@ void mla_decode_stage1_asm_fwd(
                 config_max_seqlen_q = 0;
                 sub_Q = 64;
             }
+        } else if (q_type == "fp8" && kv_type == "fp8") {
+            if((max_seqlen_q == 4) && persistent){
+                config_max_seqlen_q = 4;
+                sub_Q = 128;
+            } else {
+                TORCH_CHECK(false, __func__,
+                            ": fp8/fp8 with gqa_ratio=32 only supports decode_qlen=4 in persistent mode. "
+                            "Got decode_qlen=", max_seqlen_q, ", persistent=", persistent);
+            }
         }
     } else if (gqa_ratio == 64){
         if (q_type == "bf16" && kv_type == "bf16"){
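
A sketch of the constraint the new fp8/fp8 branch enforces for gqa_ratio == 32, in illustrative Python (select_fp8_gqa32_config is not an aiter function; it just mirrors the TORCH_CHECK above): only decode_qlen == 4 in persistent mode maps to a kernel config.

def select_fp8_gqa32_config(max_seqlen_q: int, persistent: bool) -> dict:
    # Mirrors the asm_mla.cu guard: one supported shape, everything else fails.
    if max_seqlen_q == 4 and persistent:
        return {"config_max_seqlen_q": 4, "sub_Q": 128}
    raise ValueError(
        f"fp8/fp8 with gqa_ratio=32 only supports decode_qlen=4 in persistent "
        f"mode; got decode_qlen={max_seqlen_q}, persistent={persistent}"
    )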
Binary file added hsa/gfx950/mla/mla.co
1 change: 1 addition & 0 deletions hsa/gfx950/mla/mla_asm.csv
@@ -19,3 +19,4 @@ fp8,fp8,1,1,0,1,1,_ZN5aiter40mla_pfl_qh192_vh128_m32x8_n128x1_causal1E,mla_pfl_q
 fp8,fp8,1,1,0,1,0,_ZN5aiter40mla_pfl_qh192_vh128_m32x8_n128x1_causal0E,mla_pfl_qh192_vh128_m32x8_n128x1_causal0.co
 bf16,bf16,32,0,0,0,0,_ZN5aiter39mla_a16w16_qh16_m64x1_n16x1_coex0_mask1E,MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co
 bf16,bf16,64,0,0,0,0,_ZN5aiter39mla_a16w16_qh16_m64x1_n16x1_coex0_mask1E,MLA_A16W16_1TG_4W_64mx1_16nx1_Coex0_Msk1_QH16.co
+fp8,fp8,32,1,4,0,0,mla_kernel_func,mla.co
5 changes: 3 additions & 2 deletions op_tests/test_mla_persistent.py
@@ -278,13 +278,14 @@ def test_mla(
     reduce_partial_map = torch.empty(
         reduce_partial_map_size, dtype=reduce_partial_map_type, device="cuda"
     )
+    print("max_seqlen_qo: ", max_seqlen_qo)
 
     meta = aiter.get_mla_metadata_v1(
         qo_indptr,
         kv_indptr,
         nhead // nhead_kv,
         nhead_kv,
-        True,
+        False,
         work_meta_data,
         work_info_set,
         work_indptr,
@@ -358,7 +359,7 @@ def test_absorb_decode_fp8():
         kv_lora_rank,
         qk_rope_head_dim,
         dtype=out_dtype,
-        is_causal=True,
+        is_causal=False,
         q_scale=None,
         kv_scale=kv_scale,
     )