12 changes: 11 additions & 1 deletion aiter/mla.py
@@ -150,6 +150,8 @@ def mla_decode_fwd(
kv_indices,
kv_last_page_lens,
max_seqlen_q,
page_size=1,
nhead_kv=1,
sm_scale=None, # 1.0 / (qk_head_dim**0.5)
logit_cap=0.0,
num_kv_splits=None, # for experts only!!!
@@ -168,7 +170,11 @@ def mla_decode_fwd(
):
device = q.device
    assert logit_cap <= 0, f"{logit_cap=} is not supported yet"
num_page, page_size, nhead_kv, qk_head_dim = kv_buffer.shape
if kv_buffer.dtype != torch.uint8:
_, _, _, qk_head_dim = kv_buffer.shape
else:
_, _, qk_head_dim = q.shape

if sm_scale is None:
sm_scale = 1.0 / (qk_head_dim**0.5)

@@ -227,6 +233,8 @@ def mla_decode_fwd(
None,
None,
max_seqlen_q,
page_size,
nhead_kv,
sm_scale,
logits,
attn_lse,
@@ -319,6 +327,8 @@ def mla_decode_fwd(
work_indptr,
work_info_set,
max_seqlen_q,
page_size,
nhead_kv,
sm_scale,
logits,
attn_lse,
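
The first hunk above makes page_size and nhead_kv explicit keyword arguments (defaulting to 1) instead of unpacking them from kv_buffer.shape, because a quantized uint8 KV buffer no longer carries the head dimension in its last axis. A minimal Python sketch of the resulting head-dim selection; the tensor shapes below are hypothetical and used only for illustration:

import torch

def infer_qk_head_dim(q: torch.Tensor, kv_buffer: torch.Tensor) -> int:
    # Non-uint8 KV cache: [num_page, page_size, nhead_kv, qk_head_dim],
    # so the head dim is the last axis of the buffer.
    if kv_buffer.dtype != torch.uint8:
        _, _, _, qk_head_dim = kv_buffer.shape
    # uint8 (quantized) KV cache: take the head dim from q instead,
    # assumed here to be [num_tokens, nhead, qk_head_dim].
    else:
        _, _, qk_head_dim = q.shape
    return qk_head_dim

# Hypothetical shapes for illustration only.
q = torch.randn(4, 16, 576, dtype=torch.bfloat16)
kv_buffer = torch.randn(8, 64, 1, 576, dtype=torch.bfloat16)
qk_head_dim = infer_qk_head_dim(q, kv_buffer)
sm_scale = 1.0 / (qk_head_dim ** 0.5)  # default applied when sm_scale is None
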
8 changes: 7 additions & 1 deletion aiter/ops/attention.py
@@ -566,6 +566,8 @@ def mla_decode_stage1_asm_fwd(
work_indptr: Optional[torch.Tensor],
work_info_set: Optional[torch.Tensor],
max_seqlen_q: int,
page_size: int,
nhead_kv: int,
softmax_scale: float,
# [batch_size, num_kv_splits, num_heads, v_head_dim]
splitData: torch.Tensor,
@@ -854,6 +856,7 @@ def get_mla_metadata_info_v1(
def get_mla_metadata_v1(
seqlens_qo_indptr: torch.Tensor,
seqlens_kv_indptr: torch.Tensor,
kv_last_page_lens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
is_causal: bool,
@@ -863,6 +866,7 @@ def get_mla_metadata_v1(
reduce_indptr: torch.Tensor,
reduce_final_map: torch.Tensor,
reduce_partial_map: torch.Tensor,
page_size: int = 1,
kv_granularity: int = 16,
max_seqlen_qo: int = -1,
uni_seqlen_qo: int = -1,
@@ -876,11 +880,13 @@ def get_mla_metadata_v1(
"""
Inputs:
cumulated seqlens of q/o: (batch_size + 1), dtype torch.int32.
cumulated seqlens of k/v: (batch_size + 1), dtype torch.int32.
cumulated seqlens or page indices of k/v: (batch_size + 1), dtype torch.int32.
Length of last page of k/v: (batch_size), dtype torch.int32.
    num_heads_per_head_k: Equal to num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
is_causal: Whether causal mask is enabled.
    Options: Detailed settings for splitting. All of them are optional.
page_size: default=1. The size of a page.
kv_granularity: default=16. The granularity on kv sequence length when cutting batch.
max_seqlen_qo: default=-1. Used to check lds usage and save time. value less than 1 means unknown.
uni_seqlen_qo: default=-1. Sequence length of qo is uniform across batches. value less than 1 means the
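
The docstring above now distinguishes cumulative KV sequence lengths from page indices and adds kv_last_page_lens, because with paging the final page of each sequence may be only partially filled. The sketch below derives both inputs from raw per-batch KV lengths under the usual paged-KV convention; this derivation is an assumption made for illustration, not something the diff itself defines:

import torch

def build_paged_kv_metadata(seqlens_kv: torch.Tensor, page_size: int):
    # seqlens_kv: [batch_size] int32 KV lengths per sequence.
    # Pages occupied per sequence (ceil division).
    num_pages = (seqlens_kv + page_size - 1) // page_size
    # Indptr over pages: [batch_size + 1] cumulative page counts.
    kv_indptr = torch.zeros(seqlens_kv.numel() + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(num_pages, dim=0)
    # Valid tokens in the final page of each sequence
    # (a full page when the length divides evenly).
    kv_last_page_lens = ((seqlens_kv - 1) % page_size + 1).to(torch.int32)
    return kv_indptr, kv_last_page_lens

seqlens_kv = torch.tensor([5, 64, 130], dtype=torch.int32)
kv_indptr, kv_last_page_lens = build_paged_kv_metadata(seqlens_kv, page_size=64)
# kv_indptr == [0, 1, 2, 5]; kv_last_page_lens == [5, 64, 2]
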
2 changes: 2 additions & 0 deletions csrc/include/attention_asm_mla.h
@@ -15,6 +15,8 @@ void mla_decode_stage1_asm_fwd(
std::optional<torch::Tensor>& work_indptr, // metadata
std::optional<torch::Tensor>& work_info_set, // [batch_size+1]
int max_seqlen_q,
int page_size,
int nhead_kv,
float softmax_scale,
// following are output
torch::Tensor& splitData, //[batch_size, num_kv_splits, num_heads, v_head_dim]
6 changes: 4 additions & 2 deletions csrc/include/mla.h
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -37,6 +37,7 @@ static_assert(kSizeMlaPartialTileInfoInDw == 2);

void get_mla_metadata_v1(const torch::Tensor& seqlens_qo_indptr, // [batch size + 1]
const torch::Tensor& seqlens_kv_indptr, // [batch size + 1]
const torch::Tensor& kv_last_page_lens, // [batch size]
const int32_t num_heads_per_head_k,
const int32_t num_heads_k,
const bool is_causal,
@@ -46,13 +47,14 @@ void get_mla_metadata_v1(const torch::Tensor& seqlens_qo_indptr, // [batch size
torch::Tensor& reduce_indptr,
torch::Tensor& reduce_final_map,
torch::Tensor& reduce_partial_map,
const int32_t page_size,
const int32_t kv_granularity,
const int32_t max_seqlen_qo,
const int32_t uni_seqlen_qo,
const bool fast_mode,
const int32_t topk,
const int32_t max_split_per_batch,
const bool intra_batch_mode,
const bool intra_batch_mode,
const std::optional<at::ScalarType> dtype_q,
const std::optional<at::ScalarType> dtype_kv);

4 changes: 4 additions & 0 deletions csrc/include/rocm_ops.hpp
@@ -57,6 +57,8 @@ namespace py = pybind11;
py::arg("work_indptr"), \
py::arg("work_info_set"), \
py::arg("max_seqlen_q"), \
py::arg("page_size"), \
py::arg("nhead_kv"), \
py::arg("softmax_scale"), \
py::arg("splitData"), \
py::arg("splitLse"), \
@@ -1654,6 +1656,7 @@ namespace py = pybind11;
"get_mla_metadata_v1", \
py::arg("seqlens_qo_indptr"), \
py::arg("seqlens_kv_indptr"), \
py::arg("kv_last_page_lens"), \
py::arg("num_heads_per_head_k"), \
py::arg("num_heads_k"), \
py::arg("is_causal"), \
@@ -1663,6 +1666,7 @@ namespace py = pybind11;
py::arg("reduce_indptr"), \
py::arg("reduce_final_map"), \
py::arg("reduce_partial_map"), \
py::arg("page_size") = 1, \
py::arg("kv_granularity") = 16, \
py::arg("max_seqlen_qo") = -1, \
py::arg("uni_seqlen_qo") = -1, \
12 changes: 11 additions & 1 deletion csrc/kernels/mla/metadata.cu
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.

#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include "metadata/v1_0_device.cuh"
@@ -40,6 +40,7 @@
void get_mla_metadata_v1(
const torch::Tensor& seqlens_qo_indptr, // [batch size + 1]
const torch::Tensor& seqlens_kv_indptr, // [batch size + 1]
const torch::Tensor& kv_last_page_lens, // [batch size]
const int32_t num_heads_per_head_k,
const int32_t num_heads_k,
const bool is_causal,
@@ -49,6 +50,7 @@ void get_mla_metadata_v1(
torch::Tensor& reduce_indptr,
torch::Tensor& reduce_final_map,
torch::Tensor& reduce_partial_map,
const int32_t page_size,
const int32_t kv_granularity,
const int32_t max_seqlen_qo,
const int32_t uni_seqlen_qo,
@@ -63,6 +65,8 @@ void get_mla_metadata_v1(

TORCH_CHECK((kv_granularity & (kv_granularity - 1)) == 0,
__func__, ": kv_granularity Must be power of 2!");
TORCH_CHECK((page_size & (page_size - 1)) == 0,
__func__, ": page_size Must be power of 2!");
TORCH_CHECK(seqlens_qo_indptr.stride(0) == 1,
__func__, ": seqlens_qo_indptr should be continuous!");
TORCH_CHECK(seqlens_qo_indptr.scalar_type() == at::ScalarType::Int,
@@ -71,6 +75,10 @@ void get_mla_metadata_v1(
__func__, ": seqlens_kv_indptr should be continuous!");
TORCH_CHECK(seqlens_kv_indptr.scalar_type() == at::ScalarType::Int,
__func__, ": seqlens_kv_indptr's element type should be int!");
TORCH_CHECK(kv_last_page_lens.stride(0) == 1,
__func__, ": kv_last_page_lens should be continuous!");
TORCH_CHECK(kv_last_page_lens.scalar_type() == at::ScalarType::Int,
__func__, ": kv_last_page_lens's element type should be int!");

at::ScalarType q_dtype = dtype_q.has_value() ? dtype_q.value() : at::ScalarType::BFloat16;
at::ScalarType kv_dtype = dtype_kv.has_value() ? dtype_kv.value() : at::ScalarType::BFloat16;
@@ -80,9 +88,11 @@ void get_mla_metadata_v1(
get_mla_metadata_v1_2_device(
seqlens_qo_indptr,
seqlens_kv_indptr,
kv_last_page_lens,
num_heads_per_head_k,
num_heads_k,
is_causal,
page_size,
kv_granularity,
max_seqlen_qo,
uni_seqlen_qo,
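
The new TORCH_CHECKs require page_size (like kv_granularity) to be a power of two, via the standard bit trick (x & (x - 1)) == 0, and require kv_last_page_lens to be a contiguous int32 tensor. A tiny Python illustration of the power-of-two predicate; the helper name is made up:

def is_power_of_two(x: int) -> bool:
    # For x > 0, clearing the lowest set bit with x & (x - 1) yields zero
    # exactly when x has a single set bit, i.e. when x is a power of two.
    return x > 0 and (x & (x - 1)) == 0

assert all(is_power_of_two(p) for p in (1, 2, 64, 4096))
assert not any(is_power_of_two(p) for p in (0, 3, 12, 100))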