From fe0f981b059dc3f6c472ae51c156647f241708b9 Mon Sep 17 00:00:00 2001 From: minmengdie Date: Wed, 28 Jan 2026 01:14:40 +0000 Subject: [PATCH 1/4] mla ps support paged 64 and 3buffer layout for ds3.2 --- aiter/mla.py | 12 +- aiter/ops/attention.py | 8 +- csrc/include/attention_asm_mla.h | 2 + csrc/include/mla.h | 6 +- csrc/include/rocm_ops.hpp | 4 + csrc/kernels/mla/metadata.cu | 12 +- csrc/kernels/mla/metadata/v1_2_device.cuh | 122 +++--- csrc/kernels/mla/metadata/v1_comm.cuh | 324 ++++++++-------- csrc/py_itfs_cu/asm_mla.cu | 16 +- hsa/gfx942/mla/mla.co | Bin 0 -> 33728 bytes hsa/gfx942/mla/mla_asm.csv | 1 + hsa/gfx942/mla/mla_page64.co | Bin 0 -> 33984 bytes op_tests/test_mla.py | 4 + op_tests/test_mla_persistent.py | 439 ++++++++++++++++++++-- op_tests/test_mla_sparse.py | 8 +- 15 files changed, 713 insertions(+), 245 deletions(-) create mode 100755 hsa/gfx942/mla/mla.co create mode 100755 hsa/gfx942/mla/mla_page64.co diff --git a/aiter/mla.py b/aiter/mla.py index 4a85c320c0..6a38fd8cf2 100644 --- a/aiter/mla.py +++ b/aiter/mla.py @@ -150,6 +150,8 @@ def mla_decode_fwd( kv_indices, kv_last_page_lens, max_seqlen_q, + page_size=1, + nhead_kv=1, sm_scale=None, # 1.0 / (qk_head_dim**0.5) logit_cap=0.0, num_kv_splits=None, # for experts only!!! 
@@ -168,7 +170,11 @@ def mla_decode_fwd( ): device = q.device assert logit_cap <= 0, f"{logit_cap=} is not support yet" - num_page, page_size, nhead_kv, qk_head_dim = kv_buffer.shape + if kv_buffer.dtype != torch.uint8: + _, _, _, qk_head_dim = kv_buffer.shape + else: + _, _, qk_head_dim = q.shape + if sm_scale is None: sm_scale = 1.0 / (qk_head_dim**0.5) @@ -227,6 +233,8 @@ def mla_decode_fwd( None, None, max_seqlen_q, + page_size, + nhead_kv, sm_scale, logits, attn_lse, @@ -319,6 +327,8 @@ def mla_decode_fwd( work_indptr, work_info_set, max_seqlen_q, + page_size, + nhead_kv, sm_scale, logits, attn_lse, diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index 5f255c88bb..4219dfec4d 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -566,6 +566,8 @@ def mla_decode_stage1_asm_fwd( work_indptr: Optional[torch.Tensor], work_info_set: Optional[torch.Tensor], max_seqlen_q: int, + page_size: int, + nhead_kv: int, softmax_scale: float, # [batch_size, num_kv_splits, num_heads, v_head_dim] splitData: torch.Tensor, @@ -854,6 +856,7 @@ def get_mla_metadata_info_v1( def get_mla_metadata_v1( seqlens_qo_indptr: torch.Tensor, seqlens_kv_indptr: torch.Tensor, + kv_last_page_lens: torch.Tensor, num_heads_per_head_k: int, num_heads_k: int, is_causal: bool, @@ -863,6 +866,7 @@ def get_mla_metadata_v1( reduce_indptr: torch.Tensor, reduce_final_map: torch.Tensor, reduce_partial_map: torch.Tensor, + page_size: int = 1, kv_granularity: int = 16, max_seqlen_qo: int = -1, uni_seqlen_qo: int = -1, @@ -876,11 +880,13 @@ def get_mla_metadata_v1( """ Inputs: cumulated seqlens of q/o: (batch_size + 1), dtype torch.int32. - cumulated seqlens of k/v: (batch_size + 1), dtype torch.int32. + cumulated seqlens or page indices of k/v: (batch_size + 1), dtype torch.int32. + Length of last page of k/v: (batch_size), dtype torch.int32. num_heads_per_head_k: Equals to num_heads_q // num_heads_k. num_heads_k: num_heads_k. is_causal: Whether causal mask is enabled. 
Options: Detailed settings for spliting. All of them are optional. + page_size: default=1. The size of a page. kv_granularity: default=16. The granularity on kv sequence length when cutting batch. max_seqlen_qo: default=-1. Used to check lds usage and save time. value less than 1 means unknown. uni_seqlen_qo: default=-1. Sequence length of qo is uniform across batches. value less than 1 means the diff --git a/csrc/include/attention_asm_mla.h b/csrc/include/attention_asm_mla.h index 47ef19afd3..93d8b6662e 100644 --- a/csrc/include/attention_asm_mla.h +++ b/csrc/include/attention_asm_mla.h @@ -15,6 +15,8 @@ void mla_decode_stage1_asm_fwd( std::optional& work_indptr, // metadata std::optional& work_info_set, // [batch_size+1] int max_seqlen_q, + int page_size, + int nhead_kv, float softmax_scale, // following are output torch::Tensor& splitData, //[batch_size, num_kv_splits, num_heads, v_head_dim] diff --git a/csrc/include/mla.h b/csrc/include/mla.h index 8191e863a7..147d7e2214 100644 --- a/csrc/include/mla.h +++ b/csrc/include/mla.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -37,6 +37,7 @@ static_assert(kSizeMlaPartialTileInfoInDw == 2); void get_mla_metadata_v1(const torch::Tensor& seqlens_qo_indptr, // [batch size + 1] const torch::Tensor& seqlens_kv_indptr, // [batch size + 1] + const torch::Tensor& kv_last_page_lens, // [batch size] const int32_t num_heads_per_head_k, const int32_t num_heads_k, const bool is_causal, @@ -46,13 +47,14 @@ void get_mla_metadata_v1(const torch::Tensor& seqlens_qo_indptr, // [batch size torch::Tensor& reduce_indptr, torch::Tensor& reduce_final_map, torch::Tensor& reduce_partial_map, + const int32_t page_size, const int32_t kv_granularity, const int32_t max_seqlen_qo, const int32_t uni_seqlen_qo, const bool fast_mode, const int32_t topk, const int32_t max_split_per_batch, - const bool intra_batch_mode, + const bool intra_batch_mode, const std::optional dtype_q, const std::optional dtype_kv); diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 4936c903b6..22c4d22e99 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -57,6 +57,8 @@ namespace py = pybind11; py::arg("work_indptr"), \ py::arg("work_info_set"), \ py::arg("max_seqlen_q"), \ + py::arg("page_size"), \ + py::arg("nhead_kv"), \ py::arg("softmax_scale"), \ py::arg("splitData"), \ py::arg("splitLse"), \ @@ -1654,6 +1656,7 @@ namespace py = pybind11; "get_mla_metadata_v1", \ py::arg("seqlens_qo_indptr"), \ py::arg("seqlens_kv_indptr"), \ + py::arg("kv_last_page_lens"), \ py::arg("num_heads_per_head_k"), \ py::arg("num_heads_k"), \ py::arg("is_causal"), \ @@ -1663,6 +1666,7 @@ namespace py = pybind11; py::arg("reduce_indptr"), \ py::arg("reduce_final_map"), \ py::arg("reduce_partial_map"), \ + py::arg("page_size") = 1, \ py::arg("kv_granularity") = 16, \ py::arg("max_seqlen_qo") = -1, \ py::arg("uni_seqlen_qo") = -1, \ diff --git a/csrc/kernels/mla/metadata.cu b/csrc/kernels/mla/metadata.cu index 992046e96c..32ef11722d 100644 --- a/csrc/kernels/mla/metadata.cu +++ b/csrc/kernels/mla/metadata.cu 
@@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. #include #include "metadata/v1_0_device.cuh" @@ -40,6 +40,7 @@ void get_mla_metadata_v1( const torch::Tensor& seqlens_qo_indptr, // [batch size + 1] const torch::Tensor& seqlens_kv_indptr, // [batch size + 1] + const torch::Tensor& kv_last_page_lens, // [batch size] const int32_t num_heads_per_head_k, const int32_t num_heads_k, const bool is_causal, @@ -49,6 +50,7 @@ void get_mla_metadata_v1( torch::Tensor& reduce_indptr, torch::Tensor& reduce_final_map, torch::Tensor& reduce_partial_map, + const int32_t page_size, const int32_t kv_granularity, const int32_t max_seqlen_qo, const int32_t uni_seqlen_qo, @@ -63,6 +65,8 @@ void get_mla_metadata_v1( TORCH_CHECK((kv_granularity & (kv_granularity - 1)) == 0, __func__, ": kv_granularity Must be power of 2!"); + TORCH_CHECK(page_size > 0 && (page_size & (page_size - 1)) == 0, + __func__, ": page_size Must be a positive power of 2!"); // page_size == 0 passes the bit test alone but is used as a divisor later TORCH_CHECK(seqlens_qo_indptr.stride(0) == 1, __func__, ": seqlens_qo_indptr should be continuous!"); TORCH_CHECK(seqlens_qo_indptr.scalar_type() == at::ScalarType::Int, @@ -71,6 +75,10 @@ void get_mla_metadata_v1( __func__, ": seqlens_kv_indptr should be continuous!"); TORCH_CHECK(seqlens_kv_indptr.scalar_type() == at::ScalarType::Int, __func__, ": seqlens_kv_indptr's element type should be int!"); + TORCH_CHECK(kv_last_page_lens.stride(0) == 1, + __func__, ": kv_last_page_lens should be continuous!"); + TORCH_CHECK(kv_last_page_lens.scalar_type() == at::ScalarType::Int, + __func__, ": kv_last_page_lens's element type should be int!"); at::ScalarType q_dtype = dtype_q.has_value() ? dtype_q.value() : at::ScalarType::BFloat16; at::ScalarType kv_dtype = dtype_kv.has_value() ? 
dtype_kv.value() : at::ScalarType::BFloat16; @@ -80,9 +88,11 @@ void get_mla_metadata_v1( get_mla_metadata_v1_2_device( seqlens_qo_indptr, seqlens_kv_indptr, + kv_last_page_lens, num_heads_per_head_k, num_heads_k, is_causal, + page_size, kv_granularity, max_seqlen_qo, uni_seqlen_qo, diff --git a/csrc/kernels/mla/metadata/v1_2_device.cuh b/csrc/kernels/mla/metadata/v1_2_device.cuh index 2e43b76713..8142c4c348 100644 --- a/csrc/kernels/mla/metadata/v1_2_device.cuh +++ b/csrc/kernels/mla/metadata/v1_2_device.cuh @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. #include "v1_comm.cuh" @@ -16,10 +16,9 @@ struct MlaMetadataV12Traits // <= -1: read from seqlens_qo_indptr // == 0: read from MlaMetadataV1KernelParameter::uni_seqlen_qo // >= 1: read from MlaMetadataV12Traits::kUniSeqlenQo - static constexpr int32_t kUniSeqlenQo = kUniSeqlenQo_; - static constexpr int32_t kFixedOverheadNumBlocks = 16; - static constexpr int32_t kIsSparse = kIsSparse_; - static constexpr int32_t kLdsBatchInfo = kLdsBatchInfo_; + static constexpr int32_t kUniSeqlenQo = kUniSeqlenQo_; + static constexpr int32_t kIsSparse = kIsSparse_; + static constexpr int32_t kLdsBatchInfo = kLdsBatchInfo_; }; template @@ -29,7 +28,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ using QoState = QoState; const int32_t ori_seqlen_qo = [&]() { - if constexpr (Traits::kIsSparse) + if constexpr(Traits::kIsSparse) { return params.p_seqlens_qo_indptr[1] - params.p_seqlens_qo_indptr[0]; } @@ -40,7 +39,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ }(); const int32_t num_batches = [&]() { - if constexpr (Traits::kIsSparse) + if constexpr(Traits::kIsSparse) { return params.num_batches * ori_seqlen_qo; } @@ -77,9 +76,8 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ int32_t sum_blocks = 0; for(int32_t bid = lane_idx; bid < 
num_batches; bid += ck_tile::get_warp_size()) { - const int32_t bid_ori = Traits::kIsSparse - ? (bid / ori_seqlen_qo / params.qk_batch_ratio) - : (bid / params.qk_batch_ratio); + const int32_t bid_ori = Traits::kIsSparse ? (bid / ori_seqlen_qo / params.qk_batch_ratio) + : (bid / params.qk_batch_ratio); const int32_t kv_end = params.p_seqlens_kv_indptr[bid_ori + 1]; const int32_t seqlen_kv = Traits::kIsSparse ? min(kv_end - params.p_seqlens_kv_indptr[bid_ori], params.topk) @@ -93,7 +91,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ const int32_t num_blocks = integer_divide_ceil_power2( seqlen_kv, params.kv_granularity, params.kv_granularity_log2); const int32_t num_qo_tiles = get_num_qo_tiles(bid); - sum_blocks += (num_blocks + Traits::kFixedOverheadNumBlocks) * num_qo_tiles; + sum_blocks += (num_blocks + params.k_fixed_over_head_num_blocks) * num_qo_tiles; if constexpr(QoState::is_unique() == false) { @@ -118,8 +116,8 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ // expected payload handled by each cu part. 
const int32_t payload = ck_tile::integer_divide_ceil(sum_blocks, params.num_splits) + - Traits::kFixedOverheadNumBlocks; - + params.k_fixed_over_head_num_blocks; + const int32_t page_size = params.page_size; int32_t curr_batch = 0; // batch ID of the batch which is under review int32_t curr_kv_block = 0; // #blocks handled by previous cu part(s) int32_t curr_n_split_idx = 0; // #cu parts used to handle current batch @@ -151,7 +149,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ const int32_t remain_kv_blocks = num_kv_blocks - curr_kv_block; // If current cu part is able to handle this batch of seqences - if(remain_payload >= (remain_kv_blocks + Traits::kFixedOverheadNumBlocks)) + if(remain_payload >= (remain_kv_blocks + params.k_fixed_over_head_num_blocks)) { const int32_t num_splits = curr_n_split_idx + 1; @@ -165,19 +163,35 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ work_info.qo_end = ck_tile::min(work_info.qo_start + qo_tile_size, qo_state.get_end(curr_batch)); work_info.kv_start = curr_kv_begin + (curr_kv_block * params.kv_granularity); - int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx); - if constexpr(!Traits::kIsSparse) + if(page_size == 1) { - if (params.qk_batch_ratio != 1) + int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx); + if constexpr(!Traits::kIsSparse) { - batch_tail = num_qo_tiles - (work_info.qo_start / params.qk_batch_ratio) % ori_seqlen_qo - 1; + if(params.qk_batch_ratio != 1) + { + batch_tail = + num_qo_tiles - + (work_info.qo_start / params.qk_batch_ratio) % ori_seqlen_qo - + 1; + } } + work_info.kv_end = ck_tile::min( + work_info.kv_start + (remain_kv_blocks * params.kv_granularity), + curr_kv_end - batch_tail); + work_info.kv_offset = curr_kv_end - work_info.kv_end; + } + else + { + work_info.kv_end = ck_tile::min( + work_info.kv_start + (remain_kv_blocks * params.kv_granularity), + curr_kv_end); + work_info.kv_offset = + (curr_kv_end - work_info.kv_end == 0) + ? 
0 + : ((curr_kv_end - work_info.kv_end - 1) * page_size + + params.p_kv_last_page_lens[curr_batch]); } - work_info.kv_end = ck_tile::min( - work_info.kv_start + (remain_kv_blocks * params.kv_granularity), - curr_kv_end - batch_tail); - work_info.kv_offset = curr_kv_end - work_info.kv_end; - // split related info if(curr_n_split_idx > 0) { @@ -220,7 +234,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ tot_qo_tiles += 1; num_works += 1; - remain_payload -= (remain_kv_blocks + Traits::kFixedOverheadNumBlocks); + remain_payload -= (remain_kv_blocks + params.k_fixed_over_head_num_blocks); // update state curr_qo_tile_idx = @@ -242,11 +256,11 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ } else { - const int32_t bid_ori = Traits::kIsSparse - ? (curr_batch / ori_seqlen_qo / - params.qk_batch_ratio) - : (curr_batch / params.qk_batch_ratio); - curr_kv_seqlen = params.p_seqlens_kv_indptr[bid_ori + 1] - + const int32_t bid_ori = + Traits::kIsSparse + ? (curr_batch / ori_seqlen_qo / params.qk_batch_ratio) + : (curr_batch / params.qk_batch_ratio); + curr_kv_seqlen = params.p_seqlens_kv_indptr[bid_ori + 1] - params.p_seqlens_kv_indptr[bid_ori]; curr_kv_seqlen = Traits::kIsSparse ? 
min(curr_kv_seqlen, params.topk) @@ -268,9 +282,10 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ } else { - if(remain_payload > Traits::kFixedOverheadNumBlocks) + if(remain_payload > params.k_fixed_over_head_num_blocks) { - const int32_t consuming_blks = remain_payload - Traits::kFixedOverheadNumBlocks; + const int32_t consuming_blks = + remain_payload - params.k_fixed_over_head_num_blocks; auto fill_work_info = [&]() { MlaWorkInfo work_info{}; @@ -281,18 +296,35 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ qo_state.get_end(curr_batch)); work_info.kv_start = curr_kv_begin + (curr_kv_block * params.kv_granularity); - int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx); - if constexpr(!Traits::kIsSparse) + if(page_size == 1) { - if (params.qk_batch_ratio != 1) + int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx); + if constexpr(!Traits::kIsSparse) { - batch_tail = num_qo_tiles - (work_info.qo_start / params.qk_batch_ratio) % ori_seqlen_qo - 1; + if(params.qk_batch_ratio != 1) + { + batch_tail = num_qo_tiles - + (work_info.qo_start / params.qk_batch_ratio) % + ori_seqlen_qo - + 1; + } } + work_info.kv_end = ck_tile::min( + work_info.kv_start + (consuming_blks * params.kv_granularity), + curr_kv_end - batch_tail); + work_info.kv_offset = curr_kv_end - work_info.kv_end; + } + else + { + work_info.kv_end = ck_tile::min( + work_info.kv_start + (consuming_blks * params.kv_granularity), + curr_kv_end); + work_info.kv_offset = + (curr_kv_end - work_info.kv_end == 0) + ? 
0 + : ((curr_kv_end - work_info.kv_end - 1) * page_size + + params.p_kv_last_page_lens[curr_batch]); } - work_info.kv_end = ck_tile::min( - work_info.kv_start + (consuming_blks * params.kv_granularity), - curr_kv_end - batch_tail); - work_info.kv_offset = curr_kv_end - work_info.kv_end; work_info.partial_qo_loc = partial_idx; p_work_info_set[num_works] = work_info; }; @@ -354,9 +386,11 @@ void dispatch_mla_metadata_v1_2_device(const MlaMetadataV1KernelParameter& param void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [batch size + 1] const torch::Tensor& seqlens_kv_indptr, // [batch size + 1] + const torch::Tensor& kv_last_page_lens, // [batch size] const int32_t num_heads_per_head_k, const int32_t num_heads_k, const bool is_causal, + const int32_t page_size, const int32_t kv_granularity, const int32_t max_seqlen_qo, const int32_t ori_uni_seqlen_qo, @@ -372,6 +406,7 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba torch::Tensor& reduce_partial_map) { constexpr int32_t kPackedQoLenPerWg = 128; + // constexpr int32_t kPageSize = page_size; const hipStream_t stream = at::hip::getCurrentHIPStream(); hipDevice_t dev; @@ -394,8 +429,8 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba const bool kv_is_fp8 = (kv_dtype == at::ScalarType::Float8_e4m3fnuz || kv_dtype == at::ScalarType::Float8_e4m3fn); - const bool natively_supported = (num_heads == 16) || - ((num_heads == 128) && q_is_fp8 && kv_is_fp8); + const bool natively_supported = + (num_heads == 16) || ((num_heads == 128) && q_is_fp8 && kv_is_fp8); if((natively_supported == false) && (num_heads % 16 == 0)) { @@ -422,18 +457,21 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba params.p_reduce_partial_map = reduce_partial_map.data_ptr(); params.p_seqlens_qo_indptr = seqlens_qo_indptr.data_ptr(); params.p_seqlens_kv_indptr = seqlens_kv_indptr.data_ptr(); + params.p_kv_last_page_lens = 
kv_last_page_lens.data_ptr(); params.num_batches = num_batches; params.num_heads = num_heads_k * num_heads_per_head_k; params.num_cu = num_clusters; params.num_splits = num_splits; params.reduce_indptr_size = reduce_indptr.size(0); + params.page_size = page_size; params.kv_granularity = kv_granularity; params.kv_granularity_log2 = __builtin_ctz(kv_granularity); params.uni_seqlen_qo = uni_seqlen_qo; params.ori_seqlen_qo = ori_uni_seqlen_qo; params.is_causal = is_causal; - params.topk = topk; + params.topk = (topk < 0) ? topk : (topk + page_size - 1) / page_size; // keep a negative (non-sparse) topk unchanged; ceil-divide only real top-k values by page_size params.qk_batch_ratio = qk_batch_ratio; + params.k_fixed_over_head_num_blocks = max(1, (16 + page_size - 1) / page_size); // launch kernel MLA_METADATA_DISPATCHER( diff --git a/csrc/kernels/mla/metadata/v1_comm.cuh b/csrc/kernels/mla/metadata/v1_comm.cuh index da8e73534e..cdfb45664b 100644 --- a/csrc/kernels/mla/metadata/v1_comm.cuh +++ b/csrc/kernels/mla/metadata/v1_comm.cuh @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -12,16 +12,12 @@ #include "pa.h" -CK_TILE_HOST_DEVICE int32_t cal_cost( - const int32_t qo_len, - const int32_t kv_len) +CK_TILE_HOST_DEVICE int32_t cal_cost(const int32_t qo_len, const int32_t kv_len) { return 2 * qo_len + kv_len; } -CK_TILE_HOST_DEVICE int32_t cal_kv_len( - const int32_t cost, - const int32_t qo_len) +CK_TILE_HOST_DEVICE int32_t cal_kv_len(const int32_t cost, const int32_t qo_len) { return cost - 2 * qo_len; } @@ -32,56 +28,53 @@ struct BatchInfo int32_t qo_len; int32_t kv_len; - int32_t get_cost() const - { - return cal_cost(qo_len, kv_len); - } + int32_t get_cost() const { return cal_cost(qo_len, kv_len); } - bool operator > (const BatchInfo& rhs) const - { - return get_cost() > rhs.get_cost(); - } + bool operator>(const BatchInfo& rhs) const { return get_cost() > rhs.get_cost(); } }; struct MlaMetadataV1KernelParameter { // Outputs uint64_t* p_work_metadata_ptrs; - int32_t* p_work_indptr; - int32_t* p_work_info_set_raw; - int32_t* p_reduce_indptr; - int32_t* p_reduce_final_map; - int32_t* p_reduce_partial_map; + int32_t* p_work_indptr; + int32_t* p_work_info_set_raw; + int32_t* p_reduce_indptr; + int32_t* p_reduce_final_map; + int32_t* p_reduce_partial_map; // Inputs const int32_t* p_seqlens_qo_indptr; const int32_t* p_seqlens_kv_indptr; - int32_t num_batches; - int32_t fixed_num_batches; - int32_t num_heads; - int32_t num_cu; - int32_t reduce_indptr_size; - int32_t kv_granularity; - int32_t kv_granularity_log2; - int32_t uni_seqlen_qo; - int32_t ori_seqlen_qo; - int32_t topk; - int32_t qk_batch_ratio; - int32_t num_splits; - bool is_causal; + const int32_t* p_kv_last_page_lens; + int32_t num_batches; + int32_t fixed_num_batches; + int32_t num_heads; + int32_t num_cu; + int32_t reduce_indptr_size; + int32_t page_size; + int32_t kv_granularity; + int32_t kv_granularity_log2; + int32_t uni_seqlen_qo; + int32_t ori_seqlen_qo; + int32_t topk; + int32_t qk_batch_ratio; + int32_t num_splits; + bool is_causal; + int32_t 
k_fixed_over_head_num_blocks; }; -struct PaMetadataV1KernelParameter: MlaMetadataV1KernelParameter +struct PaMetadataV1KernelParameter : MlaMetadataV1KernelParameter { // Inputs const int32_t* p_pages_kv_indptr; const int32_t* p_context_lens; - int32_t block_size; - int32_t blocks_per_unit; - int32_t num_heads_k; - int32_t gqa_ratio; - int32_t qhead_granularity; - int32_t qlen_granularity; + int32_t block_size; + int32_t blocks_per_unit; + int32_t num_heads_k; + int32_t gqa_ratio; + int32_t qhead_granularity; + int32_t qlen_granularity; }; template @@ -89,7 +82,7 @@ CK_TILE_DEVICE T warp_sum(const T* p_data, const int32_t size) { T sum = T(0); - for (int32_t idx = ck_tile::get_lane_id(); idx < size; idx += ck_tile::get_warp_size()) + for(int32_t idx = ck_tile::get_lane_id(); idx < size; idx += ck_tile::get_warp_size()) { sum += p_data[idx]; } @@ -102,9 +95,9 @@ CK_TILE_DEVICE T warp_sum(const T* p_data, const int32_t size) template CK_TILE_DEVICE T warp_prefix_sum(T value, const int32_t size) { - // Always assume that size is power of 2 - #pragma unroll - for (int32_t offset = 1; offset <= (ck_tile::get_warp_size() >> 1) ; offset *= 2) +// Always assume that size is power of 2 +#pragma unroll + for(int32_t offset = 1; offset <= (ck_tile::get_warp_size() >> 1); offset *= 2) { const T remote = ck_tile::warp_shuffle_up(value, offset); value += (ck_tile::get_lane_id() >= offset) ? remote : 0; @@ -113,73 +106,73 @@ CK_TILE_DEVICE T warp_prefix_sum(T value, const int32_t size) } // Warp level customized bitonic sort for sorting batch idx based on cost. High cost first. 
-CK_TILE_DEVICE void warp_sort( - int32_t* p_batch_idx, - int32_t* p_workspace, - const int32_t* p_qo_lens, - const int32_t* p_kv_lens, - const int32_t num_batches) +CK_TILE_DEVICE void warp_sort(int32_t* p_batch_idx, + int32_t* p_workspace, + const int32_t* p_qo_lens, + const int32_t* p_kv_lens, + const int32_t num_batches) { const int32_t lane_idx = ck_tile::get_lane_id(); - const int32_t num_batches_padded = - ck_tile::integer_least_multiple(ck_tile::next_power_of_two(num_batches), ck_tile::get_warp_size()); + const int32_t num_batches_padded = ck_tile::integer_least_multiple( + ck_tile::next_power_of_two(num_batches), ck_tile::get_warp_size()); const int32_t warp_loops = num_batches_padded / ck_tile::get_warp_size(); - int32_t* p_costs = p_workspace; - int32_t* p_indices = p_costs + num_batches_padded; + int32_t* p_costs = p_workspace; + int32_t* p_indices = p_costs + num_batches_padded; auto check_and_swap = [&](const int32_t idx0, const int32_t idx1, const bool dir) { const int32_t cost0 = p_costs[idx0]; const int32_t cost1 = p_costs[idx1]; - if ((cost0 > cost1) == dir) + if((cost0 > cost1) == dir) { int32_t temp_idx = p_indices[idx0]; - p_indices[idx0] = p_indices[idx1]; - p_indices[idx1] = temp_idx; - p_costs[idx1] = cost0; - p_costs[idx0] = cost1; + p_indices[idx0] = p_indices[idx1]; + p_indices[idx1] = temp_idx; + p_costs[idx1] = cost0; + p_costs[idx0] = cost1; } }; // Initialize smem // Pre-calculate cost for each batch - for (int32_t bid = lane_idx; bid < num_batches; bid += ck_tile::get_warp_size()) + for(int32_t bid = lane_idx; bid < num_batches; bid += ck_tile::get_warp_size()) { - p_costs[bid] = cal_cost(p_qo_lens[bid], p_kv_lens[bid]); + p_costs[bid] = cal_cost(p_qo_lens[bid], p_kv_lens[bid]); p_indices[bid] = bid; } - for (int32_t bid = lane_idx + num_batches; bid < num_batches_padded; bid += ck_tile::get_warp_size()) + for(int32_t bid = lane_idx + num_batches; bid < num_batches_padded; + bid += ck_tile::get_warp_size()) { - p_costs[bid] = 0; + 
p_costs[bid] = 0; p_indices[bid] = bid; } - for (int32_t size = 2; size < num_batches_padded; size <<= 1) + for(int32_t size = 2; size < num_batches_padded; size <<= 1) { const int32_t max_stride = size >> 1; - for (int32_t loop_idx = 0; loop_idx < warp_loops; ++loop_idx) + for(int32_t loop_idx = 0; loop_idx < warp_loops; ++loop_idx) { const int32_t thr_idx = lane_idx + loop_idx * ck_tile::get_warp_size(); - if (thr_idx * 2 < num_batches_padded) + if(thr_idx * 2 < num_batches_padded) { const bool dir = ((thr_idx & max_stride) == 0); - for (int32_t stride = max_stride; stride > 0; stride >>= 1) + for(int32_t stride = max_stride; stride > 0; stride >>= 1) { const int32_t stride_m1 = stride - 1; - const int32_t idx = 2 * thr_idx - (thr_idx & stride_m1); + const int32_t idx = 2 * thr_idx - (thr_idx & stride_m1); check_and_swap(idx, idx + stride, dir); } } } } - for (int32_t stride = num_batches_padded >> 1; stride > 0; stride >>= 1) + for(int32_t stride = num_batches_padded >> 1; stride > 0; stride >>= 1) { const int32_t stride_m1 = stride - 1; - for (int32_t loop_idx = 0; loop_idx < warp_loops; ++loop_idx) + for(int32_t loop_idx = 0; loop_idx < warp_loops; ++loop_idx) { const int32_t thr_idx = lane_idx + loop_idx * ck_tile::get_warp_size(); - if (thr_idx * 2 < num_batches_padded) + if(thr_idx * 2 < num_batches_padded) { const int32_t idx = 2 * thr_idx - (thr_idx & stride_m1); check_and_swap(idx, idx + stride, false); @@ -188,7 +181,7 @@ CK_TILE_DEVICE void warp_sort( } // Output results - for (int32_t bid = lane_idx; bid < num_batches; bid += ck_tile::get_warp_size()) + for(int32_t bid = lane_idx; bid < num_batches; bid += ck_tile::get_warp_size()) { p_batch_idx[bid] = p_indices[bid]; } @@ -201,14 +194,12 @@ CK_TILE_DEVICE T integer_divide_ceil_power2(T x, T y, T y_log2) } template -std::vector flatten( - const std::vector>& vec, - const int size_after_flatten) +std::vector flatten(const std::vector>& vec, const int size_after_flatten) { std::vector result; 
result.reserve(size_after_flatten); - for (const auto& inner_vec : vec) + for(const auto& inner_vec : vec) { result.insert(result.end(), inner_vec.begin(), inner_vec.end()); } @@ -216,21 +207,21 @@ std::vector flatten( return result; } -CK_TILE_HOST_DEVICE int32_t cal_packed_causal_kv_len( - const int32_t qo_len, - const int32_t kv_len, - const int32_t qo_tile_idx, - const int32_t packed_qo_tile_len, - const int32_t num_qo_tiles, - const int32_t num_heads, - const bool is_causal) +CK_TILE_HOST_DEVICE int32_t cal_packed_causal_kv_len(const int32_t qo_len, + const int32_t kv_len, + const int32_t qo_tile_idx, + const int32_t packed_qo_tile_len, + const int32_t num_qo_tiles, + const int32_t num_heads, + const bool is_causal) { int result = kv_len; - if (is_causal && (qo_tile_idx < num_qo_tiles)) + if(is_causal && (qo_tile_idx < num_qo_tiles)) { const int kv_len_init = kv_len - qo_len; - const int kv_len_slop = ck_tile::integer_divide_ceil((qo_tile_idx + 1) * packed_qo_tile_len, num_heads); + const int kv_len_slop = + ck_tile::integer_divide_ceil((qo_tile_idx + 1) * packed_qo_tile_len, num_heads); result = ck_tile::min(kv_len_init + kv_len_slop, kv_len); } @@ -240,31 +231,27 @@ CK_TILE_HOST_DEVICE int32_t cal_packed_causal_kv_len( template class QoState { -public: - CK_TILE_DEVICE explicit QoState( - const int32_t uni_seqlen_qo, - const int32_t ori_seqlen_qo, - const int32_t* p_lds_seqlens_qo, - const int32_t* p_seqlens_qo_indptr) : - uni_seqlen_qo_(uni_seqlen_qo), - ori_seqlen_qo_(ori_seqlen_qo), - p_lds_seqlens_qo_(p_lds_seqlens_qo), - p_seqlens_qo_indptr_(p_seqlens_qo_indptr) - { } - - CK_TILE_HOST_DEVICE static constexpr bool is_unique() + public: + CK_TILE_DEVICE explicit QoState(const int32_t uni_seqlen_qo, + const int32_t ori_seqlen_qo, + const int32_t* p_lds_seqlens_qo, + const int32_t* p_seqlens_qo_indptr) + : uni_seqlen_qo_(uni_seqlen_qo), + ori_seqlen_qo_(ori_seqlen_qo), + p_lds_seqlens_qo_(p_lds_seqlens_qo), + p_seqlens_qo_indptr_(p_seqlens_qo_indptr) { - 
return Traits::kUniSeqlenQo >= 0; } - CK_TILE_DEVICE int32_t get_seqlen( - const int32_t batch_idx) + CK_TILE_HOST_DEVICE static constexpr bool is_unique() { return Traits::kUniSeqlenQo >= 0; } + + CK_TILE_DEVICE int32_t get_seqlen(const int32_t batch_idx) { - if constexpr (Traits::kUniSeqlenQo == 0) + if constexpr(Traits::kUniSeqlenQo == 0) { return uni_seqlen_qo_; } - else if constexpr (Traits::kUniSeqlenQo <= -1) + else if constexpr(Traits::kUniSeqlenQo <= -1) { const int32_t bid = Traits::kIsSparse ? (batch_idx / ori_seqlen_qo_) : batch_idx; return p_lds_seqlens_qo_[bid]; @@ -275,14 +262,13 @@ public: } } - CK_TILE_DEVICE int32_t get_begin( - const int32_t batch_idx) + CK_TILE_DEVICE int32_t get_begin(const int32_t batch_idx) { - if constexpr (Traits::kUniSeqlenQo == 0) + if constexpr(Traits::kUniSeqlenQo == 0) { return uni_seqlen_qo_ * batch_idx; } - else if constexpr (Traits::kUniSeqlenQo <= -1) + else if constexpr(Traits::kUniSeqlenQo <= -1) { const int32_t bid = Traits::kIsSparse ? (batch_idx / ori_seqlen_qo_) : batch_idx; return p_seqlens_qo_indptr_[bid]; @@ -293,14 +279,13 @@ public: } } - CK_TILE_DEVICE int32_t get_end( - const int32_t batch_idx) + CK_TILE_DEVICE int32_t get_end(const int32_t batch_idx) { - if constexpr (Traits::kUniSeqlenQo == 0) + if constexpr(Traits::kUniSeqlenQo == 0) { return uni_seqlen_qo_ * (batch_idx + 1); } - else if constexpr (Traits::kUniSeqlenQo <= -1) + else if constexpr(Traits::kUniSeqlenQo <= -1) { const int32_t bid = Traits::kIsSparse ? 
(batch_idx / ori_seqlen_qo_) : batch_idx; return p_seqlens_qo_indptr_[bid + 1]; @@ -311,76 +296,75 @@ public: } } - CK_TILE_DEVICE int32_t get_q_head_range( - const int32_t q_head_start, const int32_t q_head_end) { + CK_TILE_DEVICE int32_t get_q_head_range(const int32_t q_head_start, const int32_t q_head_end) + { int32_t q_head_range = (q_head_end << 16) | (q_head_start & 0xFFFF); return q_head_range; } -private: + private: const int32_t uni_seqlen_qo_; const int32_t ori_seqlen_qo_; const int32_t* const p_lds_seqlens_qo_; const int32_t* const p_seqlens_qo_indptr_; }; -#define MLA_UNI_SEQLEN_QO_CASE(C_UNI_SEQLEN_QO, ...) \ - case C_UNI_SEQLEN_QO: \ - { \ - constexpr int32_t kUniSeqlenQo = C_UNI_SEQLEN_QO; \ - __VA_ARGS__; \ - break; \ +#define MLA_UNI_SEQLEN_QO_CASE(C_UNI_SEQLEN_QO, ...) \ + case C_UNI_SEQLEN_QO: { \ + constexpr int32_t kUniSeqlenQo = C_UNI_SEQLEN_QO; \ + __VA_ARGS__; \ + break; \ } -#define MLA_UNI_SEQLEN_DISPATCHER(UNI_SEQLEN_QO, ...) \ - switch (UNI_SEQLEN_QO) \ - { \ - MLA_UNI_SEQLEN_QO_CASE(1, __VA_ARGS__); \ - MLA_UNI_SEQLEN_QO_CASE(2, __VA_ARGS__); \ - MLA_UNI_SEQLEN_QO_CASE(3, __VA_ARGS__); \ - MLA_UNI_SEQLEN_QO_CASE(4, __VA_ARGS__); \ - default: \ - { \ - if ((UNI_SEQLEN_QO) > 0) \ - { \ - constexpr int32_t kUniSeqlenQo = 0; \ - __VA_ARGS__; \ - } \ - else \ - { \ - constexpr int32_t kUniSeqlenQo = -1; \ - __VA_ARGS__; \ - } \ - break; \ - } \ +#define MLA_UNI_SEQLEN_DISPATCHER(UNI_SEQLEN_QO, ...) \ + switch(UNI_SEQLEN_QO) \ + { \ + MLA_UNI_SEQLEN_QO_CASE(1, __VA_ARGS__); \ + MLA_UNI_SEQLEN_QO_CASE(2, __VA_ARGS__); \ + MLA_UNI_SEQLEN_QO_CASE(3, __VA_ARGS__); \ + MLA_UNI_SEQLEN_QO_CASE(4, __VA_ARGS__); \ + default: { \ + if((UNI_SEQLEN_QO) > 0) \ + { \ + constexpr int32_t kUniSeqlenQo = 0; \ + __VA_ARGS__; \ + } \ + else \ + { \ + constexpr int32_t kUniSeqlenQo = -1; \ + __VA_ARGS__; \ + } \ + break; \ + } \ } -#define MLA_METADATA_DISPATCHER(MAX_PACKED_SEQLEN_QO, PACKED_QO_LEN_PER_WG, UNI_SEQLEN_QO, TOPK, ...) 
\ - if (((MAX_PACKED_SEQLEN_QO) > 0) && ((MAX_PACKED_SEQLEN_QO) <= PACKED_QO_LEN_PER_WG)) \ - { \ - constexpr bool kQoSplits = false; \ - if ((TOPK) < 0) \ - { \ - constexpr bool kIsSparse = false; \ - MLA_UNI_SEQLEN_DISPATCHER((UNI_SEQLEN_QO), __VA_ARGS__); \ - } \ - else \ - { \ - constexpr bool kIsSparse = true; \ - MLA_UNI_SEQLEN_DISPATCHER((UNI_SEQLEN_QO), __VA_ARGS__); \ - } \ - } \ - else \ - { \ - constexpr bool kQoSplits = true; \ - if ((TOPK) < 0) \ - { \ - constexpr bool kIsSparse = false; \ - MLA_UNI_SEQLEN_DISPATCHER((UNI_SEQLEN_QO), __VA_ARGS__); \ - } \ - else \ - { \ - constexpr bool kIsSparse = true; \ - MLA_UNI_SEQLEN_DISPATCHER((UNI_SEQLEN_QO), __VA_ARGS__); \ - } \ +#define MLA_METADATA_DISPATCHER( \ + MAX_PACKED_SEQLEN_QO, PACKED_QO_LEN_PER_WG, UNI_SEQLEN_QO, TOPK, ...) \ + if(((MAX_PACKED_SEQLEN_QO) > 0) && ((MAX_PACKED_SEQLEN_QO) <= PACKED_QO_LEN_PER_WG)) \ + { \ + constexpr bool kQoSplits = false; \ + if((TOPK) < 0) \ + { \ + constexpr bool kIsSparse = false; \ + MLA_UNI_SEQLEN_DISPATCHER((UNI_SEQLEN_QO), __VA_ARGS__); \ + } \ + else \ + { \ + constexpr bool kIsSparse = true; \ + MLA_UNI_SEQLEN_DISPATCHER((UNI_SEQLEN_QO), __VA_ARGS__); \ + } \ + } \ + else \ + { \ + constexpr bool kQoSplits = true; \ + if((TOPK) < 0) \ + { \ + constexpr bool kIsSparse = false; \ + MLA_UNI_SEQLEN_DISPATCHER((UNI_SEQLEN_QO), __VA_ARGS__); \ + } \ + else \ + { \ + constexpr bool kIsSparse = true; \ + MLA_UNI_SEQLEN_DISPATCHER((UNI_SEQLEN_QO), __VA_ARGS__); \ + } \ } diff --git a/csrc/py_itfs_cu/asm_mla.cu b/csrc/py_itfs_cu/asm_mla.cu index 2ba53b3f2c..6a539e495d 100644 --- a/csrc/py_itfs_cu/asm_mla.cu +++ b/csrc/py_itfs_cu/asm_mla.cu @@ -73,7 +73,6 @@ std::string get_heuristic_kernel_mla(std::string q_type, continue; if (cfg.causal != causal || cfg.qSeqLen != qseqlen) continue; - return el.first; } @@ -92,7 +91,7 @@ std::string get_heuristic_kernel_mla(std::string q_type, void mla_decode_stage1_asm_fwd( torch::Tensor& Q, // [num_seqs, num_heads, head_size] - 
torch::Tensor& KV, // [num_page, page_size, num_kv_heads, head_size] + torch::Tensor& KV, // [num_page, page_size, num_kv_heads, head_size] or [num_page, page_size*(nhead_kv*(kv_lora_rank+scale_dim+qk_rope_head_dim))] torch::Tensor& qo_indptr, // [batch_size+1] torch::Tensor& kv_indptr, // [batch_size+1] torch::Tensor& kv_page_indices, // [num_page_used] @@ -102,6 +101,8 @@ void mla_decode_stage1_asm_fwd( std::optional& work_indptr, // metadata std::optional& work_info_set, // [batch_size+1] int max_seqlen_q, + int page_size, + int nhead_kv, float softmax_scale, // following are output torch::Tensor& splitData, //[batch_size, num_kv_splits, num_heads, v_head_dim] @@ -114,8 +115,7 @@ void mla_decode_stage1_asm_fwd( int batch = qo_indptr.size(0) - 1; int num_heads = Q.size(1); int head_size = Q.size(2); - int page_size = KV.size(1); - int num_kv_heads = KV.size(2); + int num_kv_heads = nhead_kv; int kv_split = splitData.size(1); const int gqa_ratio = num_heads / num_kv_heads; @@ -197,7 +197,9 @@ void mla_decode_stage1_asm_fwd( TORCH_CHECK(Q.is_contiguous(), __func__, ":only support Q.is_contiguous() for now"); TORCH_CHECK(num_kv_heads == 1, __func__, ":only support num_kv_heads==1 for now"); - TORCH_CHECK(head_size == KV.size(3), __func__, ":only support head_size == KV.size(3) for now"); + if (KV.dtype() != at::ScalarType::Byte && KV.dtype() != at::ScalarType::Char) { + TORCH_CHECK(head_size == KV.size(3), __func__, ":only support head_size == KV.size(3) for now"); + } if(Q.dtype() == at::ScalarType::Float8_e4m3fnuz || Q.dtype() == at::ScalarType::Float8_e4m3fn) { @@ -225,6 +227,8 @@ void mla_decode_stage1_asm_fwd( kv_type = "bf16"; else if(KV.dtype() == at::ScalarType::Float8_e4m3fnuz || KV.dtype() == at::ScalarType::Float8_e4m3fn) kv_type = "fp8"; + else if(KV.dtype() == at::ScalarType::Byte || KV.dtype() == at::ScalarType::Char) + kv_type = "byte"; else TORCH_CHECK(false, __func__, ": unsupport KV dtype:", KV.scalar_type()); @@ -263,7 +267,7 @@ void 
mla_decode_stage1_asm_fwd( config_max_seqlen_q = 8; } } - }else if (q_type == "bf16" && kv_type == "fp8"){ + }else if ((q_type == "bf16" && kv_type == "fp8") || (q_type == "bf16" && kv_type == "byte")){ if(persistent){ if(max_seqlen_q <= 4){ config_max_seqlen_q = 4; diff --git a/hsa/gfx942/mla/mla.co b/hsa/gfx942/mla/mla.co new file mode 100755 index 0000000000000000000000000000000000000000..e878cf163a29d7d8a7e875d692af01ece8aaf2d4 GIT binary patch literal 33728 zcmeHw4SW>Ux%bKJZW3k_Ji!111YJU)fXfmBgbgBYfQ9fW--1=bk`DreBqjlY;${*8 zVyX}$T12!F5s{)*FKV=ytQIXQC=lyKtu40I+uN(Xy|=f0Ti@1w|IeH`%Vt?j0KM_v zH<`g9+#$#n-v9Pl&##|@xnCl*0tjx3#=VtNv-<>6} zSkTngBZ|+%fZ3VEc(WeSa+@aA7J8VfL@sn{2NM{Re5&A?i8)0=m&n4}k*zr!&GOxC zu}r9}$>Voo;g5rvM?t0ASD_>FZ3zok1DVHy=a;>j8 z*O!-9P*HKPyug>Adt3R+$^w++=a+-attj*573^P8T)Ny>{Ftl6x2E8c8yW>mih8!qxdL|4lLi0Yjh!d zs@zak@aW1>8JB9}T)8zKWozx4s6}biVH2M)yjWUNyx}XII73?Kh_bPEEgTZ1(bsL^ zB57fIl#R7(p*u>Wui3)A6BkGmJyEvSu8BEO8hza+PMRrAoD*eh?V7kaN~5pa#2KM4 z%d#k2YuCi0D2*;?qB$PytH>)}AstgDIrhc2Omy%W#5;OR?f-^K-pFSmoH%kALEC3LWN(m59d zE}?xJ7aYCZ4vt-J2gfg_gUBGSV-eKexuMzr_(`$3iwi`LpNZ1yf)mdr8ae-*k!Pc9 zuHD)7`0t{$>adX=jOh2yn-dy>{upI@?OXZRD6KkdWd}d+$)PyqL=20vy>_iU5fdM! 
z+1GAnD6&5h!=r4k17>PbnqAP$7hNmLOV^g=RursQQ&3WwTexz4L4Gbx#&6Tf|MSL) zAI|tf6@0}jSCm+vJCCO(Ct@5;t~I{(xrN2P%G}#Z%U7G78)fi1odEuLO|dUmoKEHz zt}V&?iL0!9d zMG?rDfm>OuqXp{|Bv6VT5NyAmKbyaP-o{R$RzhXvsapf*)tmG!wY6lZ_urb)n>aJX0W>EldJ)<9dbWpcXN<^znPAzAp6k7 zdrjP7;tx!G(7DBB|bM2&Q4OaGK^9WayUhN&feZ<(>(Pzoo{7&yZY9{kW@N)U&)<+q1^!=q;?N*_{I)doy3W|&H9C8}Ij45&j+`Co zTWguqQRC!k@oe;|c;wE!-05ghT>jokJq8bVI*X_*w?_}o_?)xxYOFqu^CDy2F4pJF zhE0mo-97=jcqkhVR}n%TbW<8mHVw&cXO_1B_-9&oztYdr;|+$&2d4u@!8fe zPnz4~DI(p016g;c>cP~3x{iZ0I55pXVW%r?2XhXtai%+h>FEg`3<=r@SJ7s0L1%hG zJaU?(*SIF6rLw;_>z?tjb?(`#m{n&ZVGPFDd8F%k2aKZ{U6e9Aoepz z`rm+pW}4aG6-V@V_$xFHFnn{(5efJ4+uMOfkwG9)W4Pz###FPG&L6(>(%5>jK>Yf#aFgDED{-PNwlmejfB)GmgPC{~vrF z)#CHe=NSHA0fAOF*U;lV0uywBJRmTc%`?h1kH8dNpd%nKjmvXEUy)J)HN%04RRl%Tgm4UjF=GT?M5C?ceu(*Ov(uTjg}EF-mn(sFLx?uqSww9d zn48?r{FP!JRMO&tsRpN~iaC?&22(L-Vm=5C9AqHvPf$5t1qL?0O;{CFPF4l6ruG@a zlPyDvx(yvtL~{wv?a7C2x`*&3o30ZEYTs-;{q0a0&<0#z^^&1)0+L=^6)^N|K+@~0UN`ieK+^YA9XIqwAn6CHP8<3mAnAvz z&KP{Z}Tx3sdK0$;nc zl^p}VaZf9I1Nhc`t?X^!I}f$8cYvoJX=U#MfBr-(`#JERpKN6x06%=Lm3;{O=*3oc z7WnZit?XmqZ(eJ~Ah`R?v(E!!fJcA|@KvA%cnoL-z74bi-v!13KLEx7&oWl6d@vyS z!w*+iE58`<`9~kFrqAT_Pi;#bg#Th&>LL8hwsaZc=eDKB@fwe|pQ4X$z|Vjvqj=5N zqeoF*b)KHLJa22B9@~QRm==oRe9PS1yRk8+y4{L+OId*9bP;4Rt;A@z_3$WF-pkX?}7 zkUfxxA&ro|kbRK-kOPo|kVBBCAWuV{g&c-F4{3tD1UUj}h6Et5LXJXSha7{v2{{gV z8*&nI8uBjWJ;)iz2asPtehE1X`8DLXkWV4MgZv)y8RQR;KSKT!@?Vg@K>iB(0`hmr zmk?H^D+NR7e7(JA^}eK$0Q7AR44EBn5I6!~sc#q(KHlhCqfvoDdfz9g+bV z4atPKA>$z9Arl}T$VA9w$P`EpWIALfWHw|DWL_1;*0oMo0*xt~o#}%qw%&$wE^S0S z;(w>3+UXkHh;hhj=&r#uCOI&68G3wt8nqX7)Px$;mF2b(YLOVHTnV%u!1!gud4>(g zT#c%x)vyFH?%}*I7JMx62?^-0dzw9LTpDDKOYu18=*aO*$A0p2<{0)9a}0aM9K(KQ zj$tRvF|5@b!`?T?uwU8qjZDl@=Z%eswcwU|iot+8S&V_ZJOcMH#n2l)0{1bCp&#%F zJVY@R+C5D%6!0*!8QA0zc!b3odcY&_D2p@nV;+IWS-hd2^a#Yr@!1)ivh;HFFi`WP z8Q6DOrhzFGqaDPUqz!+M9WwAV`w=i@%zNxf15Yz{PxJe+2N|P!eMX3SeMX3S{l|)W z{l}u-eWD(-n%nGUZgYsa&2bm8&1X8e&F9Q*;{0)Hwc0XN-O>_QV{J*c)wWa}y{m=B zy;SybYkiE;H`2GnH`wY@+1{lwrm0R_J+3GE|3 
zBdWi6o<7vBw(1q5-8P_T_i+EZ6ta)%G!8%?5uYpZdk2X=_6?;z>Z^Wc^sy5t z`q)35_&rr8489R4_yZ%+U&L`i{T&=5`a6_G?G9A68to1NMZ2f26E;zPiRwR{jea7I zC*q&=Qv1VI?;Gu#fTI23X>}>>9ZS^y`B~_fzJG`zNN<`8-Tb^se`#S|iigS(TabQa ziRkyMjy5Py6_T}os;LHygd!yYsCA^l`^T}ot}LHy}0u*Zyz zNPqWEv7Ptsu1o32ST=v2ttE|ZLHw5K$4niZr)TL|%4mPw=+Q;~Zl71mB1F zj0}SxMEs21Mt+KprM|57RO-|^D`DDgc!R^IZG&ZqQ#b+?|_woqrtqarw z0{dKS97cVT9lb`ZGVIYtt+JjMo89x}hjAFg*?k7?)N>FoQLOfWq3>c38MsHE2A*O! zoGYPzBc00t_vzEYQ!Mv{p&wvR8hA*b0sa^wYkW?0E-;wAd9q)Bh4t<4?A3jEniXOD zsU9&8uwIrFzs>1VZE0z0kL)ymDtqfxs{Q8Lfxv;lG+-KV5O5GMHG8suVD=P$T6T_q zQ1*0xj{y!f)spTXI5J&HgA9Ub0~|_FnSkSi;(HGv;`@nyith)roRg_ugjh<1Z!2!2plyMbmq%KI&tKVPR~ zE^-d<0lRxxGW^3l&$Ma_70pSwMA2M?wS{h zT16XAxIxh-5ZQuxn@G4t(Iyk#sc2IO?^d)N!g~~LI^lhaHk0rHMVn3dkfO~Y zd_>Xa5k9793kaW3v_*tZD%xVgXB6!w!sit2X2KU0Z7Jc)infgK6-8T4_?n{S6TYEn zg@h**t%&fHqOBtQxuO*lzOQH{gdZtd8R5r@R!;bdqE!<9tD>zV{9Mu26aI&yZ6y4& zqHQAl8%{0>|Dk9)VT?tqA+%YvTEcD?Z7X4-McYQ0WYOvfds?)+2>V#H?S%a;T0P+a zi`GCm$fE5e9BR?-Cme3kb`g%WXuAo=ShPKaSr+YK!s{$rBVo2h+e_%RX!{7KS+xCx zvn<*H!nqcWzNf=57S?+YCa`CW@6XfhIRg*7z&ChbN|4{P4Cp((Zzjm^S0?oPy&oj# zGx2@HIn+4lyS%?oU@sc&o7l?+9+?1sxA%(#eHQ8=Qkn>Tk5^69XG6!i))eRud;2Ep zH$Xqj*mUTP-eHOC6{CHCy=LIi+2Hqj$0g1K?(wPJ)9C+CKX5xl!Vm@03{dw;P zi5r1U-rpzcIU?3-yh+4be(*1Ozeo%eimCsLo(AYgz0Gcj_}uUn~{A2$Y#Vs1ZJ(@g{~% zd11;_Ys;O=`SGDX)3E*NfWX74up<;7;=F=%86Sq{E{K=$VR#OMc#048Sw_E20f9$| zkBkq)a~)J386WDijXnYafk&xMWPBK&AE9y?ABN{ih^P2azrpD5SU})$;v?h3@O%oD zM~e?B$BFZlla%8GzDxCFd>Ed8A^xJ{L!6&cy9>sLIB%o!i;fR*p4XxH5a)r^zl;yV z^GC#o;zOJ#l73!%i1S9`+l>!#UP<-ZjSq3&N%gwoLy?2~w#0{_TwX+c7|QKM#D}3= zUqpN;bAMg&;ko!QGQZLlA0j^LiVr&!AIk5Uj1T4aOU8#X$KGCiD0A==G4@h?D0A^L zK9o86Z)1EIng5gVVPqb#2Kz?wAv5!rGCq`fOc@`_e5Q;KWnNRphcdq@<3pL}l<}d= zcgpxs<~{df``>hY*j`>TJU)!UHIRS*qNas7k|bNr`T0cTkm&je&B;cb7?HarUdD?T zlTV~}6dy**D|R%WNPSRz7%i`ODfvX|kK)51) zJQpu^3<>0&WiEXc6|99jQ*Z$-Dd?LjzbUy;k!4w}x%P-3K@M7|b z)E~u%(ejEiK5Q?aNPSRz7%i_T-~H38~*pjt|ZJ;f3PE@O+{XADVf@^WsA@kJxT}Xyy^yjStN{ zVpn|F_MW~@%O}cv0VCqWQ2sC?K9qTl$oNp^H@f0Oas4{l_2#a8Vpn|F?)7H#UflNL 
zL-~EXU_Q~jcelOxu)TcZw<$jC$|rXGKF9EU;^$w8d}2DD)#+{Bf_oL+Titk0L{+w& zpI3~MN2K_XViAfL3F&$>VMp#iq=b(8@hf>*pd4Qsa~|}!;yJDsz>o1 z;l*4JruxyY0~3ES*MB4Pc;R_Fiqoh+%HI)&$6=(0$6us}$62IDi>JQr_vXoa2GfJ? ztwG$f=Ei-4@_xa#_X@sofqMntxWK)FZ=8&NuizUequncbp?d^75|dG#$b0%&WGodK zLq%Sbjkp%ul~X&HQ|r3s+I7vfqxa;6a%bi}urfB1xwQ7KwNAu1YUIq!d-vouR+($- z=rvY(58op6_uq9do*&zyd+%z1bS+h0LzUM~w}Yp9>g2Ujd5u(F8+{n%bkE#gAYBWU zd8>F_e^iz8uRli1Cw*0%5G`L6ejggmVUck{wEWOl#R<{!KjHV4(Ke#R36c38s@D}K ze0%QU?}`(Q`*!zrp=3)BJv`>>CU^1dE{{d9pT0fGIg5BaWw zUU-(M@4!^0H?SX&o-IlN_5t<>;@M+AJ=@f4fWn3l4pdk#W;@lr8qW&@uK@B%E`QRb z3_r{~#Xorf?+s(8l0JEa#UR$k;+%EQu&%s9`}jVix=lgg|7DzDa6zPU!_TWVE)$5xf!xlQHYuT%Nmcd2~Gc9q{#ukw2v zRDR!1mH*&=l|QgclpChgJT<^D2L_N##FzN#!peQF%+V%3len{HL#~{I#Ph|Jmy*f8&_S z-+EK!CyuN9owrqf>ZHnBPpkar@2dQt-&6VfXH@>-2P*&Q7b^eNFIE2WS(X3h*DC+y zw<`aaPgVY}zf<|Y{a)ptf2Q((|AWf^|Lw0T|G!_T{2zZ; zdGJe>$HdsV#bW0+o1MqS*?G5ac5b)Zd19iSUvY(w%&z^SPySJV9>0{^p z`q_E^{&s%#)pkB$fSnHwb$DDb=TSX_1D{ZcD9{Qnq=o*uboewYUk6Y+4+nac0OyCo!@YSozI=Z`7?Vf3nu%Jz;M(GY?}kX5&3kix=zqk|j09d!uI1d!s&Rz5~yx*E)x!;k{3e zpwpSa@NmM6xU7aOJnNV|d`MBUWthK=iDznz_lTi<(u|o6l=}!e)r1^;eZ+I=&V*ie z)p+lejyi5tEu!brsoxrBd;)D}ANudrE&g1ccsvu2XWlRLyq!MN-Ov-yviHQZ?87Xh z{Vq$UKf~hoXIjSLdy{jrFP^2OXYu+*JX6{+~&EK!~k(XRisg6|*J zuAkYC?p?RxS$9IT_I%0wXoxQ*f9xqOzi|a^h9CDc-Y~A9g|^4PAAw= zZ~Egp*pUr8ys%?hC)mNe3xCivhICB^L(?#iK$nj9FF;yaH_%#!D%z+&XV8`N4utSbtvtY+u*s&0HEa?P0{tprmcx$1 zPOw9cU*Cltxv(P-cC6?GJLLFvGwfIfJMv*iQ770T$FEyphYxlXz>bxjV8=H$ekGaX z7h(ALMLLaNgkj^?D)^xUekg|@)^&m(V_+dNz&;UQ&-wA$@ z|9K0aR8yLAi8=hZb+E{^hfB0qoaqM4!k=pl>0N7m9sFqUZt$LX!4C&N z68sqOS>Tr?2!0&+3E(G!p8{Uzg1;90_24Igp9;P)S@7e*d%#Zyp94Oi3I00p+2Fn4 zr-474B6yKw`0{;`V=&OAGXpaM0yC+O&JAu{K;Q(HWatw;0;jaiF(`OXFphg+q zcI-z3_Tzr+hxGUN!S4WnFZdsTm;SyBd_DM`;CF$S{=OUhJ>c&H{{VRD@7>@Z2EP~l ze(=)Y4}$+8_(#D%4qp0u5BNs#`@kOnFa7-x_(#A$2L1`~-}SH-Y}FDzVSZRzvr7UW?m z7hByDi*zY192Yfdu?xVpPqElc++&ujrk-ZvpT7*2(&}Qf9OJIBY%+CI4Z+NnCHOh1 zaSr@uv18`U$&QS47ycYWMiiHptSBhYE%Q|t(GO4MR+Q#e7U74f%FEVNho$4pLQyS_Er{6GaI^kzhnW>k 
zaGU(LYufUYOy22i%g;9Xlo4(DB_{vb$hLgG$rs<;mN!0R*zcf=jbY11V?B>VXTDy z;Ry*WajY?d|H$Nr;CHpdm*1IuAcFss$sa2YuZ=gV`BD5B30$H(dc|mdR$n25n=4yI+MpkGwd$L>|IeCm+%@`G`kdc`YV%-zv z>Zs9UL#tCltT!vPn&o!O3Zuu4maAhjyrI=GAvVh$+VI$knXrR>?cXMRML|J!adUmMplu^>ikt0XT4Ucq>k(Q0jPM;L{`*M*w{oY*e z4%O zXp!?0e%g!uB^WG~?dALRXn}`AVV#Wwgk0X}`p)QDzA%m-C}U zxqa$`7TSmKA8Qh{g2O4 z#+qy|aS>SQyYQ+--!nDSutca=L8Sfi_@QA}+HSTmpvW&Yx&N|EZddB$-e~R5nC*QL J4I)(b{{R^JSgim6 literal 0 HcmV?d00001 diff --git a/hsa/gfx942/mla/mla_asm.csv b/hsa/gfx942/mla/mla_asm.csv index 9c02b8174e..4515be0b87 100644 --- a/hsa/gfx942/mla/mla_asm.csv +++ b/hsa/gfx942/mla/mla_asm.csv @@ -5,6 +5,7 @@ bf16,bf16,16,0,1,0,0,_ZN5aiter39mla_dec_stage1_bf16_a16w16_subQ16_mqa16E,mla_dec bf16,bf16,16,0,4,0,0,_ZN5aiter39mla_a16w16_qh16_m16x4_n16x1_coex0_mask1E,mla_a16w16_qh16_m16x4_n16x1_coex0_mask1.co bf16,bf16,16,0,8,0,0,_ZN5aiter39mla_a16w16_qh16_m32x4_n16x1_coex0_mask1E,mla_a16w16_qh16_m32x4_n16x1_coex0_mask1.co bf16,fp8,16,1,4,0,0,_ZN5aiter41mla_a16w8_qh16_m16x4_n16x1_coex0_mask1_psE,mla_a16w8_qh16_m16x4_n16x1_coex0_mask1_ps.co +bf16,byte,16,1,4,0,0,mla_kernel_func,mla.co fp8,fp8,16,1,1,0,0,_ZN5aiter36mla_a8w8_qh16_qseqlen1_gqaratio16_psE,mla_a8w8_qh16_qseqlen1_gqaratio16_ps.co fp8,fp8,16,1,2,0,0,_ZN5aiter36mla_a8w8_qh16_qseqlen2_gqaratio16_psE,mla_a8w8_qh16_qseqlen2_gqaratio16_ps.co fp8,fp8,16,1,4,0,0,_ZN5aiter36mla_a8w8_qh64_qseqlen4_gqaratio16_psE,mla_a8w8_qh64_qseqlen4_gqaratio16_ps.co diff --git a/hsa/gfx942/mla/mla_page64.co b/hsa/gfx942/mla/mla_page64.co new file mode 100755 index 0000000000000000000000000000000000000000..3beef5071c2e0e21eb9d001aebb588e7f291bca8 GIT binary patch literal 33984 zcmeHw4SW>U-R{Zk=7ZTJ9uUG10xlsS!m@+}!UidBAdBHsK&9wa!a@Q8LlToffzo6M zfzs3hrpSww`bHWlQnctrjTDpBqD2J-qP~b9MXT0&t=G37-qvf~`~2t3*-Vzj1kn6^ z-Vz;~vetS|R(?)#Mm)Kq7;XLb{tgL<^Ri(P12hU?@oEnygcg`g$M z-HZkO_e92=z?_-6*U!G0v6_edD+-D#{O&sn%S#H2EAGDQ0k^+=X~ml?ebd8k|KiG( z{^ET9;>Cp(75mEz{RR2oD_>Swh@66ga&Y+-W&Xv5dzTiMF7g*Y;4bm6D7^oMU%E?| 
zEU750WDmG2maQ#}`;~i@zj$R~{vFFo3LX+g^B1jLvZSzlqp0$WGJk=r6c?e=zC~;D zjV5GEm21ihf3Ykx<5FFmBe%vQY^_5VwFs3u?c#$*FP4@RuldSOoGv|dM%Y+~9uA98 z>Ff4zp7bz1!p1uE&=aB3*X-e*NjFIsy%Dz7p^JGDDt+B9PM#rMoE>3n9lAI_LZz?U z#p$6h%fbj->(Irb2$e4AqB$NkRxB=FDm$i3G+>wCeHukNYQQ?6R9O*{Bk-u_rQK%QTT|yJ>dSUzJHn8(@ z8`ymb4eXgy5$b}5OQ_$j3mPxCfqj?Tz=2C>V9%tsE;x7z_1ks9q04Qc>2e!5atRIW zncUU|flH|0t_zM{ZUe_Iw}Imq(?EET*SQF4&z#WgfBdA_+{Fc=$InElb-{_}5}mxU z&B=2SHrL_odi=u(wL0x&CnNejb7zN!pid)guVXL&5usM6z3k-Yog0c%PDHT?+w0KF z6H##ys(tNlh9dhDQ9Q!-I^m`kq1pxAeA2zNymVz*ensKZ6@?|0`Ae3qE-c8W$@nEY z`G4Fv@xvKksDQtC+0qjGW9RWyepqCA2%cjmCi}b`wPbdTdTEXHq{xj;vkR=mmSDd=Z)s>~My84q3d^V)XXpyJu zA|>jwu>A^8V?R;6z3X06w~*~v@j4`~v&FMoY*gwLQNgH9!5}hXV*=*@bxd*ADRyVj zZXX&%xuBv5WX!-#EXLV_b%oV5+t`L?JF9KBMJHBO2kvSa5FB8%l^7QjKvQfB+NAAJ zo>J?n9~2y9T^}^4W_BOND7Otrdc;-1tBiaqfAc)na&_{?mYW+q^^DzxP1dz| z1`gEqtFEfqnjq@EY}VU)rDAO7Wneto+3HXN)D15KdsF%fC7=PB5)&X`wiqLy*<%7u zN=L^8(h0SgfR`{PCNP_@Z%kkzAy)!r054Dn>c9pd>!}0|0=&QgrITU;pAxEQA0G?t zN7j0e?y2!?@dy3JzAf}D)cv}@=2LSYsLVF@baOCgdOdw8!JKLJIy$2^$5TI@)is}F z^{_3l9k6#{yPAXS$IW!a1lb!V{?x>+CjOg=e=zZ92KET{FnqWrc#DBOgFOvQ2qqZV zC)me8>1)4WKO=oj@EQYA|A(Y9p}v8rZy@R$i24SizJaK3AnoIN_S@!HRPnj#>}2x| zvrg&j*@0$z&B(2_lc&~iV83s+Z&IkSZzyIP4<@C zgSQ)g+?IWte!G5qjT@gUbvL!|S^uP-?aSVhJs}VEJX<{Ln{-be`renZHIFqld*+nW3XN`-e#j!D`;-IPdxXan1xU2g1>@{?x%T+{q!Ct*M?Zbi~j@>lI zJiU4qF+TfToEoD~<-BP2Ecm6ks`|2xr>D6Trm;;1Dok(b8(mhVMVD9EvF_nq&+gf? 
zX3lKv%QW(*EtP%Ita~TG*E#0~GP}-tn%K+4D^2Vza9wbn zPGW3u0_u<+`(LEqVC-j*_A~Flj?b6cx0N@Du`^BX^S1x~`y8vYmLV8382j>#tFXIX zfp%SBTtHwvi!$=LS72{lpcW9A!W1Ll=@po!3mg^@=wdb_Kiw;Ej4sd<5IBL^jeM_H zU@nbQ^7EkYnsE%C`TyYaST8;geYVjbEFjR%<{0TXufTX+AP)%a&E^{Unpa?oF3=ee zn8t22@`rf^x^#i*0fA%KO-8=QD{z7?&>Ilg?(>MPF^`%1jXKOL(5?%N3kZy7dyIVU z71&!Bs09S3uwNMYPOrc;UEr{QKo@H?^3%Nn$LIn*0f7_P14h2rE70(<> zJ+7sXL3XHFvBks&*<;Nt${rnLo0_lP-*jCJ=14E*ycm~D?dx+_1(lSlV6ZwER0bNT zD`{1_G8E>7rB>xCT2)fCTGcnYs%ix4MZ4Up*Oj661=lZHuiDD&Lu>5zv|udeSj=f2 zyS=D}#k8ekVv7D9bKC}(D{UilJA-a_JnMmwZltS->Nqi9d0Yj{#N4Q)#RgLiPEQr{ zD3uMSVjjgD5ga_kK-$Nka;6##{60%q9aP?|4q{F1H;nhT4J+z#&9EYxUueGXeJEP@ z5+07$b;3Zj9w2-!8gnb~IPwWkBcJdL@(JHYelOsA(K>DOp9hMo55?HoKM#}vqk*fd z4;$%qKuXtE2aNP)Af>leKWC)311Y_``nZv908)Bi^=Tu007&UW)n|-!6Oht@>bH&b zQ6Qy{SHEYZPXeRaWAdkWa{gI4wo@V6UV*>8c*{-~85 z13rIuD|;UJ;ytbGCE&|HZ)Gn7Pu<_jUID)PU@LnS_#Y3ova`T99&2T90N;ADm7N2= z^K>hF2l$6)TiFM|Kb~l1e+2&NR4WG1U1uJB91sOO0#ty{0ByiyKs)dyU^MU*U<~jq zFcx?Y7zg}-v4x1!7OY>OFUVi8X@MTo5=(OkJ}-*f=yRahZ6RX$9&GHX9(N+*dkZ~YzeFgwj5RrD}j~4%3+nTRj}2twXk)tDwqzdfz`q`!8XI{ zV0Xc`z_!8aVcTIlU^`*EV7p=W!5UzDV2!Z7uzj%oumiA%VFzK4!VbY6hc&?t!;Zk3 zVFB1Pu%oc&V8>uDz>dRSf}Mn&hP?uN4R!{07WO9WZP+>3yRi3QAHY6@eFXa$_6h7$ z*k53OgM9}32kdj$7qBm3tXfwTm>m`ai-W1KcvvEg!+ODb!;)YcEE$#p8whj4QekPZ zp|D}F;V>7>4NHe*z{bF`U>?|b*aX-_m=`t)mJ6E#%Y#jW&4A5<&4$gbrue$n<&LK@ zWP>YxD8<*C5u0nH;&2||a;|r|hc;l`u^VakP#S}r7^{qQTwEH}7iHA=8kCj!(WBHN zF)q2|X+41PDjMe-(KzO6R5h)J#fxzc=NK{IV~CHBM|+8B)^TVE#-WQnrs;U^{g^p! 
z{l*-(o;Js=-3v>-2PmdOy@M1}0S~cg1Dm`8kFXdc9q~yk0{6N<~EbeZ4NWHIsPKH`A8?X`Ixy)oKIfAUTwKu z-PjUaV{dUs*S1t2y{m=BwN&;_>$WH*IlOO)uaB-vA>7`hE`{cy$BcQP3>eLb-;qH4 z>gwMZ{5qiEclNAH3IA-U{H{Ktz1{t&UTyW$M!n5IQSZJ1bt&W*m2DV=Ho`tn;`a;@ zZ8Tm(ZEUOlozcd2plD<7NaA-_pD_3apy2n7Mtfn$1GTq*tZ45*Hr3l#-D=c302K8e z{*Lg8G$<;6FbC~~9Y@4J>ZAIHs$VziHvvWcLsRQgIy#1^{^K*zE`9$HH&D814$|i5 zOzFcn*QIzVAMpgGkK7{K{i67d0Qo>^j;!kgce`b6{=~sRvKDXC?0)NbS>0-yb8|Oy1q_K^N>$0ja zqX*~e*?P7zrYd&Kn4%lEHtAxFcg1_4?O?3NxSn3L!Ie=ol)ZQ|X{^Em;87391v}N> zw~G{#?;!6W*_Aa|s_%w3`SYRJ3`74=LJw!bcSC zR>H>=?KZ+E6>S0GQ;N2b@M%R`MEI;*u%}I{BkW_-?jr1G)3y)}uxZ-}2idfG!XY+oJK;4pZ3p2{#Q6Rk zWRDqm$PK>UcQ{^t&oYqS?t39#e!sGi-r+kNug}2u4d+nfk>2V1D4soO)Nf)>8F*wO z_+7rwoH~W$k10ub>4(Whzc%nYl$ltjI>1TZ7 z69Xc>yB_JIzG;d2G$X%Z2hz{^ZcPk`bmK0hkNK7->eG$tj9<|OH(z*nf8j47@94DlBo zQ{tSC>Rm9V#JL{jUvx}~bHh%>lsI>!_GL_I)t875#gsU=q_m7Ft-2HOp_mfqo|Nt| zro_1^mFqC3#JMY#J1?fhxome#DRjGUYfKr^|H5L*kRBKoQ-<`xu$WTnh21e_trF;#a=*{PUPe$Qk~DZgJbrj)vTM=_<;?I&S;rkGOd_cErGI(}zkN~!mEWlR~a zGs>7UTyLzwK2l7{Or2E5lu|F1F{RW^WlSmcQyEiA9aYAZQcsmJrPNhrOeyu%J=p#? 
z9aDCse_CV8C|rYi_4Aq*;>bos*PN$Y(!5PEWhDLbQglmdk7CM5`ekQzOKO8+%1HX< zrRbK_9>tWA^vjFUEvX%fDZ};43+a|tJ(Bdv&ghsFV_J1e(k-nzCdHUik0gIcpX^9G z505deIwk3rRvlBum{#4=s#j91N!yh%rd7AJ>Xj60Qn^Smrd6*b{gTQ>iZQKvCFz$` zj$%xyS5lhv%kCKS+ZAJmbjoo3QtFgpv8L2d&eJbVoibd%?2a)b$C%x^Ww&k_Nw<{W zFE_?Z(ksJtOc`T#q+3dTaw6Ixy)sqA=zqZ~Lx@FP7&ettzUZ$8bl71;;%8Su0sXdA*Bk7kirtC;5hO8I@epl)g2SKLud*^zG9 zQB3(&y5%<&Q+Df?oxg9>s#~^xE_BOuJd4!Vz7h9cdNz6R9E_@LJWs!jq+43Ak<)xm z*Te}sb6+Ku3%}+aNr$vv^QJmEIg(PkEOz|0f$}t?bhSk z^tkS8yxrG$JA2u)0Huj)Zse1wmS*qrlGf)_ZP}*x>Bd>>@{6^pJ7*C)2)K< zbYEZ%kgnOvYqjzk?H2HKzn{D&E3d`MYq0krpYG?|1Eg!NQtOJNdwG@fuX9GyqP{9l zh@>@H@0FuDD?Cn!q%D0_oDfMnvflef+lUk=glk1qt~*Zn_B?mc9VZy~6gPHjEEm>T zy5ofIIN@UB1gW{mI6-PFosARZJ!3LXkoSglHcpTl%zu|S0Z;z6u-EZS7{|%nTR^;r z%s7X7-F)^-U|(IJ77*Bvy=CMldjMDBvDru0yh7k@{SQ3jqmAD?yLIbY=^2zS1o|7}G;N~g*{=Il-8avhV{rg!I zVjb>tR`v9yR?&Jgp2O-}J>T9R*spq_o!@+u%D+2L<+sdN`Tw|8<+tCa^85uV_b*iW z;zcSiEKvEAvQp*WTcz?ft5yE}wJQI?I+a&f zseHYz@(ndA-&m{iA8u0lk2b6P$8{?I$z3Ynx<%!8Z&UeC>s5Zwc9s9z4wc`#Q{_L~ zrShNeR{6i*r}Fz7RQ`)SDu19+<-gpk@(1^+{8#%`{?GxH|N3E-KXOpz|MO9mKXypv zPdu*jC!19Mo5L!9>WIo)npOUEK;^%EM&-{QRr&9pQ~C49RQ}=%DnD^t3^O z-g;W)ufC%4|9DO1ub)x*8)sGi)|)E-{o5*k=bXy_@UF_=e^2Fq{6OV@`cUQn>m!x_ z`D2y;?@v_zpP#DyuYXbb-~Xoa|MQv3|M?G<|KHD5{;w}o9{f_}QBe+VvpINlw1daS zI(UyB4(@O`ctV1MUvY(l_w4E5S6=Dhefl_f-@Xpsub+eW@9*FP1~~XtS2_5gK@L86 zu!9d7;^0?b?cmp3gO48V;F*~YK6b2wj~nOU+1U<$?X?d6o$omK zb=Ns~PL6|5p6uX0pMzh2y@OAk>fqC-JNV3*4t~Q84nAj&gWq_guHR#Nl~=JkNijX9D#Zp87s`PQMSH(;sou7}T5R&T(yf*)Y1*Wav1_sU1(`Oy*c;Ku^^ zu>^jUbiE%s-pOJ3F&};`gdaulqpa)w*e3hqR`{_9ek_L{SHH;m2n9v90U<*e3g9E&Ql~A9e7fzU%$q ziK0K~`A51YgQ01-SD;(R`y`Nd1q7xWWqKG~hF4&go?xWM1O$3$e(z%&R>kPf**U}$A0+nXxIA@8ov&~k0bEoDEv6y^?roLugBp>GyHfC ze!SH6euT!aCioG6AIIRw$*%V!G=3e1AJ4##7vRU~uJ_{`8^3y*;}@ZI{Gv3CUxe22 z>lO6Fo9Ku4&<`JXeLsZ8uh-zm+wkK9`0+{C`w<$y&cKgz@Z&@H@oCrl5gNbF!jE_1 z$4BtvFJ13PX#Dyc{P+TX*t69U@z-{Jd@kdS&)~_P~Jyso&f__LtKe)QSA420- z68snlKZe2&ch~z78oxC7;e;Q<;75Ab`w<$ylHo@x{1^^DGP>T6(D*e5eoTNLx$t8~ 
z*ZUEQzp~)RMEEfUe$486KSJY|2Yz_rM;`o`-SvKi#;@`4V-ozB20!L@y&vD$_;sZ@ zei2&7FG|z+MQ9nn9*(LJTAt3>ccH({pm+1f#p!yF9u_^Xu9>tr(({Bq2R$#2G1Bp3 zq4)KMmew0u8fkg~ywiXDccB5s$Lo4Rf@otrX@XnOhG%Gv$CDPD?WrFREieW3QYsje zI$#Y`@!nrgy%%juolo2g6&EkwkMD|)i^2=kp!J#BpcC(V3u%8wJyZK@ukpE{@wuV# zr9&(E6fWRy&qjQ5B9}qZ^^)%9xyaK1R*BBJM^BBj%cqL3@F!_mLe1^$S z3gfd(eo7lZVmtO@H}<0u`*Fb1-#ftH2YxU3hb{fR6MO^sec%sT`g<4nJ>d6)f7H_7 zhrk~Je-!+2OMgEOz8U;;;9s)zcN6#k_+#KtTKfAi_-DYs0RHqh*xyD>fqtUzp0&S? zn8M_({Y^aeHcw-I}oytThy!G64n{df=i@v)`9UjzR(_z%E;V(ITQ;Lm~o5d5c> z{yq!-UGN`)|BI!+{|5dG@b=M&!86;A@6W)02|fmVVrJX%{SWYL6k=5HJhSci{v5ml zUIpJP^P4-qd!yfmqo4YO^*8ZulehLa@zmc`-rC>9)A3EbwZD6hM9hc%NW*@(GTX+J zB=7^l4+ZbeY#UEB@J{f2V1Fl}f2hCv0Tlw${Y^aeH}RJK zPPVzAi=Gd*J7z@8>C-v}L+wAl49dJnT$I4*3`;t+uA zpJItNakrU|57AiCO#J(o!BSeCZ{}m{HI{W|+DwLE#C~E?HJwSTU-ieDSE_;#Dg~f>~a;xY9^C z(+djxnf`($W3$F(Wfi9T{fpBJ(lbYo%UoPokU4H_cHwCJp!4EIi?0c}eclSDS1hq? z!jIc$(eF4B--s8hS{J-G6H6>RCOgi|6vbj$S=hEln&sIDv!DqclWz>;C!2hpt9|)d zCO=_Rd;S)a=Naw!0+TPgy*+Py$ngIVU2wE4YmN0-_C=ww8>a{iGhhXI;@Xj^eGTI9S$&3_gQmdg5a{*fq$ zS(z{E8;#&!KNz(w^W}Uc(TO}-WWMxY;#J7AS@PxlC{b>o+MtE@!Ma?H1aY!``8$jf zz1G554^Pu@PW9#d%~Fw|1uu}MC3l&?(El(8#v+62(mv24$BDe}i*%U)qmO7+_8;A@ zMax7Oak74s7{FPRSwBy96cxr2ws|(w{Z+Ytf>0X3BId u0m*?d>AyUF7J!q!n>7q5sun4^|1wW*SEh@w5~)7@$__35unJ*G_WuB4?3dI4 literal 0 HcmV?d00001 diff --git a/op_tests/test_mla.py b/op_tests/test_mla.py index 7322b30bec..2fc509483a 100644 --- a/op_tests/test_mla.py +++ b/op_tests/test_mla.py @@ -384,6 +384,8 @@ def test_absorb_decode_bf16(): kv_indices, kv_last_page_lens, max_seqlen_qo, + page_size, + nhead_kv, sm_scale, num_kv_splits=split_per_batch, ) @@ -429,6 +431,8 @@ def test_absorb_decode_fp8(): kv_indices, kv_last_page_lens, max_seqlen_qo, + page_size, + nhead_kv, sm_scale, q_scale=q_scale, kv_scale=kv_scale, diff --git a/op_tests/test_mla_persistent.py b/op_tests/test_mla_persistent.py index fdba21f97c..f0c426b638 100644 --- a/op_tests/test_mla_persistent.py +++ b/op_tests/test_mla_persistent.py @@ -24,6 +24,150 @@ def 
check_support(dtype, kv_dtype, nhead): return True +def init_3buffer_kv_cache( + num_page: int, + page_size: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + scale_dim: int, +) -> tuple: + """ + Initialize KV cache for 3BUFFER layout with FP8 quantization. + + Generates random KV cache data and applies per-channel quantization to the nope buffer. + + Args: + num_page: Number of pages + page_size: Size of each page (block size) + kv_lora_rank: Rank of KV LoRA (nope dimension) + qk_rope_head_dim: Dimension of RoPE (rope dimension) + scale_dim: Number of scale factors per nope buffer + + Returns: + tuple containing: + - kv_buffer: Concatenated buffer (BF16), shape (num_page, page_size, 1, kv_lora_rank + qk_rope_head_dim) + - kv_nope_buffer_fp8: Quantized nope buffer (FP8), shape (num_page, page_size, 1, kv_lora_rank) + - kv_nope_scale_factors_fp32: Scale factors (FP32), shape (num_page, page_size, 1, scale_dim) + - kv_rope_buffer_bf16: Rope buffer (BF16), shape (num_page, page_size, 1, qk_rope_head_dim) + - kv_nope_buffer_fp32: Original nope buffer (FP32), shape (num_page, page_size, 1, kv_lora_rank) + """ + assert ( + kv_lora_rank % scale_dim == 0 + ), f"kv_lora_rank ({kv_lora_rank}) must be divisible by scale_dim ({scale_dim})" + + kv_nope_buffer_fp32 = torch.randn( + (num_page, page_size, 1, kv_lora_rank), dtype=torch.float32 + ) + kv_rope_buffer_bf16 = torch.randn( + (num_page, page_size, 1, qk_rope_head_dim), + dtype=torch.bfloat16, + ) + + # Create full KV buffer (for golden reference without quantization) + kv_buffer = torch.cat( + [kv_nope_buffer_fp32.to(torch.bfloat16), kv_rope_buffer_bf16], dim=-1 + ) + + # Generate random scale factors + # scale_values = [1.0, 2.0, 4.0, 8.0] + scale_values = [1.0, 1.0, 1.0, 1.0] + scale_indices = torch.randint( + 0, len(scale_values), size=(num_page, page_size, 1, scale_dim) + ) + kv_nope_scale_factors_fp32 = torch.tensor( + [scale_values[idx] for idx in scale_indices.flatten()], dtype=torch.float32 + 
).reshape(num_page, page_size, 1, scale_dim) + + # Apply per-channel scaling and quantize to FP8 + kv_nope_scaled_buffer = kv_nope_buffer_fp32.reshape( + num_page, page_size, 1, scale_dim, kv_lora_rank // scale_dim + ) / kv_nope_scale_factors_fp32.reshape(num_page, page_size, 1, scale_dim, 1) + + kv_nope_buffer_fp8 = kv_nope_scaled_buffer.reshape( + num_page, page_size, 1, kv_lora_rank + ).to(dtypes.fp8) + + return ( + kv_buffer, + kv_nope_buffer_fp8, + kv_nope_scale_factors_fp32, + kv_rope_buffer_bf16, + kv_nope_buffer_fp32, + ) + + +def split_3buffer_kv_cache( + kv_buffer_bytes: torch.Tensor, + page_size: int, + nhead_kv: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + scale_dim: int, +) -> tuple: + """ + Split concatenated KV cache buffer back into 3 separate buffers. + + This is the inverse operation of concatenating after flattening last 3 dimensions. + + Args: + kv_buffer_bytes: Concatenated buffer (uint8), shape (num_page, page_size*656) + where 656 = 512(nope) + 16(scale) + 128(rope) + page_size: Size of each page (block size) + nhead_kv: Number of heads in the KV cache + kv_lora_rank: Rank of KV LoRA (nope dimension) + qk_rope_head_dim: Dimension of RoPE (rope dimension) + scale_dim: Number of scale factors per nope buffer + + Returns: + tuple containing: + - kv_nope_buffer_fp8: Quantized nope buffer (FP8), shape (num_page, page_size, 1, kv_lora_rank) + - kv_nope_scale_factors_fp32: Scale factors (FP32), shape (num_page, page_size, 1, scale_dim) + - kv_rope_buffer_bf16: Rope buffer (BF16), shape (num_page, page_size, 1, qk_rope_head_dim) + """ + num_page = kv_buffer_bytes.shape[0] + + nope_total_bytes = page_size * nhead_kv * kv_lora_rank * 1 # FP8: 1 byte/elem + scale_total_bytes = page_size * nhead_kv * scale_dim * 4 # FP32: 4 bytes/elem + rope_total_bytes = page_size * nhead_kv * qk_rope_head_dim * 2 # BF16: 2 bytes/elem + + nope_flat = kv_buffer_bytes[:, 0:nope_total_bytes] + scale_flat = kv_buffer_bytes[ + :, nope_total_bytes : nope_total_bytes + 
scale_total_bytes + ] + rope_flat = kv_buffer_bytes[ + :, + nope_total_bytes + + scale_total_bytes : nope_total_bytes + + scale_total_bytes + + rope_total_bytes, + ] + + nope_bytes = nope_flat.reshape(num_page, page_size, nhead_kv, kv_lora_rank * 1) + scale_bytes = scale_flat.reshape(num_page, page_size, nhead_kv, scale_dim * 4) + rope_bytes = rope_flat.reshape(num_page, page_size, nhead_kv, qk_rope_head_dim * 2) + + # Convert bytes back to original dtypes + kv_nope_buffer_fp8 = ( + nope_bytes.contiguous() + .view(dtypes.fp8) + .reshape(num_page, page_size, nhead_kv, kv_lora_rank) + ) + + kv_nope_scale_factors_fp32 = ( + scale_bytes.contiguous() + .view(torch.float32) + .reshape(num_page, page_size, nhead_kv, scale_dim) + ) + + kv_rope_buffer_bf16 = ( + rope_bytes.contiguous() + .view(torch.bfloat16) + .reshape(num_page, page_size, nhead_kv, qk_rope_head_dim) + ) + + return kv_nope_buffer_fp8, kv_nope_scale_factors_fp32, kv_rope_buffer_bf16 + + def cal_diff( x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False ) -> None: @@ -50,13 +194,12 @@ def ref_masked_attention( q_scale=None, kv_scale=None, ): - if is_fp8_q and q_scale is not None: scale *= q_scale if is_fp8_kvc and kv_scale is not None: scale *= kv_scale - attn_weights = torch.einsum("qhd,khd->hqk", query.float(), key.float()) * scale + if is_causal: s_q = query.shape[0] s_k = key.shape[0] @@ -67,32 +210,82 @@ def ref_masked_attention( attn_weights += attn_bias lse = attn_weights.logsumexp(dim=-1) - m = attn_weights.max(-1).values - attn_weights_exp = torch.exp(attn_weights - m.unsqueeze(-1)) - l = attn_weights_exp.sum(-1) - if is_fp8_q: attn_weights_fp8 = attn_weights_exp.to(dtypes.fp8) attn_weights_exp = attn_weights_fp8.to(torch.float) out = torch.einsum("hqk,khd->qhd", attn_weights_exp.float(), value.float()) - out = out / l.transpose(0, 1).unsqueeze(-1) - if is_fp8_kvc and kv_scale is not None: out *= kv_scale return out.to(dtype), lse +def torch_mla_extend_3buffer( + q, # [total_q, nheads, 
headdim_q] + kvc_cache, # [num_page, page_size*(nhead_kv*(kv_lora_rank+scale_dim+qk_rope_head_dim))] + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + page_size, + nhead_kv, + sm_scale, + kv_lora_rank, + qk_rope_head_dim, + dtype, + is_causal=True, + q_scale=None, + kv_scale=None, + scale_dim=4, +): + num_page = kvc_cache.shape[0] + (kv_nope_buffer_fp8, kv_nope_scale_factors_fp32, kv_rope_buffer_bf16) = ( + split_3buffer_kv_cache( + kvc_cache, page_size, nhead_kv, kv_lora_rank, qk_rope_head_dim, scale_dim + ) + ) + + kv_nope_buffer_fp32 = kv_nope_buffer_fp8.to(torch.float32).reshape( + num_page, page_size, nhead_kv, scale_dim, -1 + ) * kv_nope_scale_factors_fp32.reshape(num_page, page_size, nhead_kv, scale_dim, 1) + kvc_cache_bf16 = torch.cat( + [ + kv_nope_buffer_fp32.reshape(num_page, page_size, nhead_kv, kv_lora_rank).to( + torch.bfloat16 + ), + kv_rope_buffer_bf16, + ], + dim=-1, + ) + + return torch_mla_extend( + q, + kvc_cache_bf16, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale, + kv_lora_rank, + qk_rope_head_dim, + dtype, + is_causal, + q_scale, + kv_scale, + ) + + def torch_mla_extend( q, # [total_q, nheads, headdim_q] - kvc_cache, # [num_page * page_size, nhead_kv, qk_head_dim] + kvc_cache, # [num_page, page_size, nhead_kv, qk_head_dim] qo_indptr, kv_indptr, kv_indices, + kv_last_page_lens, sm_scale, kv_lora_rank, qk_rope_head_dim, @@ -101,6 +294,7 @@ def torch_mla_extend( q_scale=None, kv_scale=None, ): + num_page, page_size, nhead_kv, _ = kvc_cache.shape is_fp8_q = q.dtype == dtypes.fp8 is_fp8_kvc = kvc_cache.dtype == dtypes.fp8 @@ -118,7 +312,9 @@ def torch_mla_extend( os = [] lses = [] for i in range(bs): - kvc = kvs[i] + cur_num_page = kvs[i].shape[0] + real_kv_seq_len = (cur_num_page - 1) * page_size + kv_last_page_lens.tolist()[i] + kvc = kvs[i].flatten(0, 1)[:real_kv_seq_len,] q = qs[i] k = kvc v, _ = torch.split(kvc, [kv_lora_rank, qk_rope_head_dim], dim=-1) @@ -157,6 +353,8 @@ def test_mla( decode_qlen, 
max_split_per_batch, non_persistent_mode, + paged_layout, + scale_dim, ): ret = {} @@ -170,6 +368,7 @@ def test_mla( kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int) seq_lens_qo = torch.empty(batch_size, dtype=torch.int) seq_lens_kv = torch.empty(batch_size, dtype=torch.int) + kv_block_nums = torch.empty(batch_size, dtype=torch.int) kv_last_page_lens = torch.ones(batch_size, dtype=torch.int) if varlen: for i in range(batch_size): @@ -178,22 +377,49 @@ def test_mla( seq_lens_qo[i] = max( min(random.normalvariate(ctx_lens, ctx_lens / 2), ctx_lens), 1 ) + kv_block_nums[i] = (seq_lens_kv[i] + page_size - 1) // page_size + if seq_lens_kv[i] % page_size == 0: + kv_last_page_lens[i] = page_size + else: + kv_last_page_lens[i] = seq_lens_kv[i] % page_size else: seq_lens_kv.fill_(ctx_lens) seq_lens_qo.fill_(ctx_lens) - - kv_indptr[1 : batch_size + 1] = torch.cumsum(seq_lens_kv, dim=0) - kv_indices = torch.randint(0, num_page, (kv_indptr[-1].item(),), dtype=torch.int) + kv_block_nums.fill_((ctx_lens + page_size - 1) // page_size) + if ctx_lens % page_size == 0: + kv_last_page_lens.fill_(page_size) + else: + kv_last_page_lens.fill_(ctx_lens % page_size) + + kv_indptr[1 : batch_size + 1] = torch.cumsum(kv_block_nums, dim=0) + num_page = kv_indptr[-1].item() + kv_indices = torch.randperm(num_page, dtype=torch.int) qo_indptr[1 : batch_size + 1] = torch.cumsum(seq_lens_qo, dim=0) max_seqlen_qo = seq_lens_qo.max().item() max_seqlen_kv = seq_lens_kv.max().item() total_qo = qo_indptr[-1].item() - total_kv = kv_indptr[-1].item() + total_kv = seq_lens_kv.sum().item() + kv_buffer = torch.randn( - (num_page * page_size, 1, kv_lora_rank + qk_rope_head_dim), + (num_page, page_size, 1, kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, ) + kv_nope_scale_factors_fp32 = None + kv_nope_buffer_fp8 = None + kv_rope_buffer_bf16 = None + + if paged_layout == "3BUFFER": + ( + kv_buffer, + kv_nope_buffer_fp8, + kv_nope_scale_factors_fp32, + kv_rope_buffer_bf16, + _, + ) = 
init_3buffer_kv_cache( + num_page, page_size, kv_lora_rank, qk_rope_head_dim, scale_dim + ) + # for none absorb (mha) qk_head_dim = kv_lora_rank + qk_rope_head_dim sm_scale = 1.0 / (qk_head_dim**0.5) @@ -222,6 +448,7 @@ def test_mla( qo_indptr, kv_indptr, kv_indices, + kv_last_page_lens, sm_scale, kv_lora_rank, qk_rope_head_dim, @@ -282,6 +509,7 @@ def test_mla( meta = aiter.get_mla_metadata_v1( qo_indptr, kv_indptr, + kv_last_page_lens, nhead // nhead_kv, nhead_kv, True, @@ -291,7 +519,8 @@ def test_mla( reduce_indptr, reduce_final_map, reduce_partial_map, - kv_granularity=max(page_size, 16), + page_size=page_size, + kv_granularity=max(1, 16 // page_size), max_seqlen_qo=int(max_seqlen_qo), uni_seqlen_qo=decode_qlen, fast_mode=True if not non_persistent_mode else False, @@ -301,10 +530,66 @@ def test_mla( dtype_kv=kvtype, ) + def test_absorb_decode_bf16_fp8(): + out_asm = torch.empty((total_q, nhead, v_head_dim), dtype=out_dtype).fill_(-1) + kv_buffer_fp8 = kv_buffer.to(kvtype) + kv_scale = torch.ones([1], dtype=torch.float, device="cuda") + + out_ref_fp8, lse_ref_fp8 = torch_mla_extend( + q, + kv_buffer_fp8, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale, + kv_lora_rank, + qk_rope_head_dim, + dtype=out_dtype, + is_causal=True, + q_scale=None, + kv_scale=kv_scale, + ) + + (attn_logits, attn_lse), us_asm_decode = run_perftest( + aiter.mla.mla_decode_fwd, + q, + kv_buffer_fp8.view(num_page, page_size, nhead_kv, qk_head_dim), + out_asm, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + max_seqlen_qo, + page_size, + nhead_kv, + sm_scale, + num_kv_splits=max_split_per_batch, + work_meta_data=work_meta_data, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + intra_batch_mode=non_persistent_mode, + kv_scale=kv_scale, + ) + + err = checkAllclose( + out_ref, + out_asm, + msg=f"mla_decode-absorb [golden vs aiter_asm]: 
{us_asm_decode:>8.2f} us......", + ) + err_fp8 = checkAllclose( + out_ref_fp8, + out_asm, + msg=f"mla_decode-absorb_fp8 [golden fp8 vs aiter_asm]: {us_asm_decode:>8.2f} us......", + ) + return err, us_asm_decode + def test_absorb_decode_bf16(): kv_last_page_lens = torch.ones(batch_size, dtype=torch.int) out_asm = torch.empty((total_q, nhead, v_head_dim), dtype=out_dtype).fill_(-1) - (attn_logits, attn_lse), us_asm_decode = run_perftest( aiter.mla.mla_decode_fwd, q, @@ -315,6 +600,8 @@ def test_absorb_decode_bf16(): kv_indices, kv_last_page_lens, max_seqlen_qo, + page_size, + nhead_kv, sm_scale, num_kv_splits=max_split_per_batch, work_meta_data=work_meta_data, @@ -354,6 +641,7 @@ def test_absorb_decode_fp8(): qo_indptr, kv_indptr, kv_indices, + kv_last_page_lens, sm_scale, kv_lora_rank, qk_rope_head_dim, @@ -373,6 +661,8 @@ def test_absorb_decode_fp8(): kv_indices, kv_last_page_lens, max_seqlen_qo, + page_size, + nhead_kv, sm_scale, num_kv_splits=max_split_per_batch, q_scale=q_scale, @@ -405,12 +695,100 @@ def test_absorb_decode_fp8(): cal_diff(out_ref, out_asm, "out", True) return err, us_asm_decode + def test_absorb_decode_3buffer(): + num_works = work_indptr[-1].item() + for i in range(num_works): + print( + f"work_info_set[{i}, 0]: {work_info_set[i, 0]}, [{i}, 1]: {work_info_set[i, 1]}, [{i}, 2]: {work_info_set[i, 2]}, [{i}, 3]: {work_info_set[i, 3]}, [{i}, 4]: {work_info_set[i, 4]}, [{i}, 5]: {work_info_set[i, 5]}, [{i}, 6]: {work_info_set[i, 6]}" + ) + + out_asm = torch.empty((total_q, nhead, v_head_dim), dtype=out_dtype).fill_(-1) + + # convert to bytes + nope_bytes = kv_nope_buffer_fp8.view(torch.uint8) + scale_bytes = kv_nope_scale_factors_fp32.view(torch.uint8) + rope_bytes = kv_rope_buffer_bf16.view(torch.uint8) + kv_buffer_bytes = torch.cat( + [nope_bytes.flatten(1), scale_bytes.flatten(1), rope_bytes.flatten(1)], + dim=-1, + ) + + out_ref_fp8, lse_ref_fp8 = torch_mla_extend_3buffer( + q, + kv_buffer_bytes, + qo_indptr, + kv_indptr, + kv_indices, + 
kv_last_page_lens, + page_size, + nhead_kv, + sm_scale, + kv_lora_rank, + qk_rope_head_dim, + dtype=out_dtype, + is_causal=True, + scale_dim=scale_dim, + ) + + err_ref_fp8 = checkAllclose( + out_ref, + out_ref_fp8, + msg="mla_decode-absorb_fp8 [golden fp8 vs golden]:......", + ) + # print(f"kv_buffer_bytes shape: {kv_buffer_bytes.shape}, kv_buffer_bytes stride: {kv_buffer_bytes.stride()}, kv_buffer_bytes: {kv_buffer_bytes[0:1,]}") + # print(f"q shape: {q.shape}, q stride: {q.stride()}, q: {q[0:1,]}") + # print(f"qo_indptr: {qo_indptr}, qo_indptr stride: {qo_indptr.stride()}, qo_indptr: {qo_indptr[0:1,]}") + # print(f"kv_indptr: {kv_indptr}, kv_indptr stride: {kv_indptr.stride()}, kv_indptr: {kv_indptr[0:1,]}") + # print(f"kv_indices: {kv_indices}, kv_indices stride: {kv_indices.stride()}, kv_indices: {kv_indices[0:1,]}") + + (attn_logits, attn_lse), us_asm_decode = run_perftest( + aiter.mla.mla_decode_fwd, + q, + kv_buffer_bytes, + out_asm, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + max_seqlen_qo, + page_size, + nhead_kv, + sm_scale, + num_kv_splits=max_split_per_batch, + work_meta_data=work_meta_data, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + intra_batch_mode=non_persistent_mode, + ) + + err = checkAllclose( + out_ref, + out_asm, + msg=f"mla_decode-absorb_fp8 [golden vs aiter_asm]: {us_asm_decode:>8.2f} us......", + ) + err_fp8 = checkAllclose( + out_ref_fp8, + out_asm, + msg=f"mla_decode-absorb_fp8 [golden fp8 vs aiter_asm]: {us_asm_decode:>8.2f} us......", + ) + cal_diff(out_ref, out_asm, "out", True) + return err, us_asm_decode + err = None us_asm_decode = 1e12 - if dtype == torch.bfloat16: + + if paged_layout == "3BUFFER" and not non_persistent_mode: + err, us_asm_decode = test_absorb_decode_3buffer() + elif dtype == torch.bfloat16 and kvtype == dtypes.fp8: + err, us_asm_decode = test_absorb_decode_bf16_fp8() + elif 
dtype == torch.bfloat16: err, us_asm_decode = test_absorb_decode_bf16() elif kvtype == dtypes.fp8: err, us_asm_decode = test_absorb_decode_fp8() + ret["decode:err"] = err ret["decode:asm_576"] = us_asm_decode @@ -437,7 +815,6 @@ def test_absorb_decode_fp8(): list_dtype = ["bf16", "fp8"] l_kv_dtype = ["bf16", "fp8"] list_nhead = [(16, 1), (16, 2), (16, 4), (48, 1), (128, 2)] - parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter, description="config input of test", @@ -454,7 +831,7 @@ def test_absorb_decode_fp8(): "-qn", "--qk_nope_head_dim", type=int, - default=128, + default=512, help="""qk nope head dim. e.g.: -qn 512""", ) @@ -552,7 +929,25 @@ def test_absorb_decode_fp8(): help="""variable kv seqlens per batch. Default: False. --varlen # True""", ) - +parser.add_argument( + "-pl", + "--paged_layout", + type=str, + choices=["LEGACY", "3BUFFER"], + default="LEGACY", + help="""kv paged layout for persistent mode. + LEGACY: kv buffer is common buffer with nope and rope parts. + 3BUFFER: kv buffer is 3-buffer with nope, kv_scale and rope parts. + e.g.: -pl 3BUFFER""", +) +parser.add_argument( + "-sd", + "--scale_dim", + type=int, + default=4, + help="""scale dim. + e.g.: -sd 4""", +) import pandas as pd args = parser.parse_args() @@ -582,6 +977,8 @@ def test_absorb_decode_fp8(): decode_qlen=decode_qlen, max_split_per_batch=max_split_per_batch, non_persistent_mode=args.non_persistent_mode, + paged_layout=args.paged_layout, + scale_dim=args.scale_dim, ) df.append(ret) df = pd.DataFrame(df) diff --git a/op_tests/test_mla_sparse.py b/op_tests/test_mla_sparse.py index 491340faec..c18fd21300 100644 --- a/op_tests/test_mla_sparse.py +++ b/op_tests/test_mla_sparse.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
import torch import aiter @@ -446,6 +446,7 @@ def test_mla( meta = aiter.get_mla_metadata_v1( qo_indptr, kv_indptr, + kv_last_page_lens, nhead // nhead_kv, nhead_kv, True, @@ -455,6 +456,7 @@ def test_mla( reduce_indptr, reduce_final_map, reduce_partial_map, + page_size=page_size, kv_granularity=max(page_size, 16), max_seqlen_qo=1, uni_seqlen_qo=1, @@ -509,6 +511,8 @@ def test_sparse_mla_bf16(): converted_indices.view(-1), kv_last_page_lens, 1, + page_size, + nhead_kv, sm_scale, num_kv_splits=max_split_per_batch, work_meta_data=work_meta_data, @@ -570,6 +574,8 @@ def test_sparse_mla_fp8(): converted_indices.view(-1), kv_last_page_lens, 1, + page_size, + nhead_kv, sm_scale, num_kv_splits=max_split_per_batch, q_scale=q_scale, From 1c7550dc25a2074a95b9debd9d9571c98981dca1 Mon Sep 17 00:00:00 2001 From: minmengdie Date: Wed, 28 Jan 2026 02:58:47 +0000 Subject: [PATCH 2/4] fix the github-actions --- csrc/kernels/mla/metadata/v1_2_device.cuh | 2 +- op_tests/test_mla_persistent.py | 26 +++++++++++------------ op_tests/test_mla_sparse.py | 23 ++++++++++---------- 3 files changed, 25 insertions(+), 26 deletions(-) diff --git a/csrc/kernels/mla/metadata/v1_2_device.cuh b/csrc/kernels/mla/metadata/v1_2_device.cuh index 8142c4c348..9f5a80321f 100644 --- a/csrc/kernels/mla/metadata/v1_2_device.cuh +++ b/csrc/kernels/mla/metadata/v1_2_device.cuh @@ -469,7 +469,7 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba params.uni_seqlen_qo = uni_seqlen_qo; params.ori_seqlen_qo = ori_uni_seqlen_qo; params.is_causal = is_causal; - params.topk = (topk + page_size - 1) / page_size; + params.topk = (topk < 0) ? 
topk : (topk + page_size - 1) / page_size; params.qk_batch_ratio = qk_batch_ratio; params.k_fixed_over_head_num_blocks = max(1, (16 + page_size - 1) / page_size); diff --git a/op_tests/test_mla_persistent.py b/op_tests/test_mla_persistent.py index f0c426b638..1cd04937b7 100644 --- a/op_tests/test_mla_persistent.py +++ b/op_tests/test_mla_persistent.py @@ -8,6 +8,7 @@ import random import itertools import argparse +import pandas as pd torch.set_default_device("cuda") torch.set_printoptions(sci_mode=False) @@ -172,9 +173,9 @@ def cal_diff( x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False ) -> None: x, y = x.double(), y.double() - RMSE = ((x - y) * (x - y)).mean().sqrt().item() + # RMSE = ((x - y) * (x - y)).mean().sqrt().item() cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) - amax_diff = (x - y).abs().max().item() + # amax_diff = (x - y).abs().max().item() # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") if use_fp8: assert cos_diff < 3e-2 @@ -212,7 +213,7 @@ def ref_masked_attention( lse = attn_weights.logsumexp(dim=-1) m = attn_weights.max(-1).values attn_weights_exp = torch.exp(attn_weights - m.unsqueeze(-1)) - l = attn_weights_exp.sum(-1) + l = attn_weights_exp.sum(-1) # noqa: E741 if is_fp8_q: attn_weights_fp8 = attn_weights_exp.to(dtypes.fp8) attn_weights_exp = attn_weights_fp8.to(torch.float) @@ -243,7 +244,7 @@ def torch_mla_extend_3buffer( scale_dim=4, ): num_page = kvc_cache.shape[0] - (kv_nope_buffer_fp8, kv_nope_scale_factors_fp32, kv_rope_buffer_bf16) = ( + kv_nope_buffer_fp8, kv_nope_scale_factors_fp32, kv_rope_buffer_bf16 = ( split_3buffer_kv_cache( kvc_cache, page_size, nhead_kv, kv_lora_rank, qk_rope_head_dim, scale_dim ) @@ -396,8 +397,8 @@ def test_mla( kv_indices = torch.randperm(num_page, dtype=torch.int) qo_indptr[1 : batch_size + 1] = torch.cumsum(seq_lens_qo, dim=0) max_seqlen_qo = seq_lens_qo.max().item() - max_seqlen_kv = seq_lens_kv.max().item() - total_qo = qo_indptr[-1].item() + # 
max_seqlen_kv = seq_lens_kv.max().item() + # total_qo = qo_indptr[-1].item() total_kv = seq_lens_kv.sum().item() kv_buffer = torch.randn( @@ -424,7 +425,7 @@ def test_mla( qk_head_dim = kv_lora_rank + qk_rope_head_dim sm_scale = 1.0 / (qk_head_dim**0.5) - us_asm = None + # us_asm = None # if batch_size * ctx_lens * nhead < 32 * 8192 * 16: # us_asm = test_absorb_prefill() torch.cuda.empty_cache() @@ -506,7 +507,7 @@ def test_mla( reduce_partial_map_size, dtype=reduce_partial_map_type, device="cuda" ) - meta = aiter.get_mla_metadata_v1( + aiter.get_mla_metadata_v1( qo_indptr, kv_indptr, kv_last_page_lens, @@ -580,7 +581,7 @@ def test_absorb_decode_bf16_fp8(): out_asm, msg=f"mla_decode-absorb [golden vs aiter_asm]: {us_asm_decode:>8.2f} us......", ) - err_fp8 = checkAllclose( + checkAllclose( out_ref_fp8, out_asm, msg=f"mla_decode-absorb_fp8 [golden fp8 vs aiter_asm]: {us_asm_decode:>8.2f} us......", @@ -686,7 +687,7 @@ def test_absorb_decode_fp8(): out_asm, msg=f"mla_decode-absorb_fp8 [golden vs aiter_asm]: {us_asm_decode:>8.2f} us......", ) - err_fp8 = checkAllclose( + checkAllclose( out_ref_fp8, out_asm, msg=f"mla_decode-absorb_fp8 [golden fp8 vs aiter_asm]: {us_asm_decode:>8.2f} us......", @@ -730,7 +731,7 @@ def test_absorb_decode_3buffer(): scale_dim=scale_dim, ) - err_ref_fp8 = checkAllclose( + checkAllclose( out_ref, out_ref_fp8, msg="mla_decode-absorb_fp8 [golden fp8 vs golden]:......", @@ -769,7 +770,7 @@ def test_absorb_decode_3buffer(): out_asm, msg=f"mla_decode-absorb_fp8 [golden vs aiter_asm]: {us_asm_decode:>8.2f} us......", ) - err_fp8 = checkAllclose( + checkAllclose( out_ref_fp8, out_asm, msg=f"mla_decode-absorb_fp8 [golden fp8 vs aiter_asm]: {us_asm_decode:>8.2f} us......", @@ -948,7 +949,6 @@ def test_absorb_decode_3buffer(): help="""scale dim. 
e.g.: -sd 4""", ) -import pandas as pd args = parser.parse_args() list_dtype = [dtypes.d_dtypes[key] for key in args.dtype] diff --git a/op_tests/test_mla_sparse.py b/op_tests/test_mla_sparse.py index c18fd21300..26c1685231 100644 --- a/op_tests/test_mla_sparse.py +++ b/op_tests/test_mla_sparse.py @@ -10,6 +10,7 @@ import argparse import triton import triton.language as tl +import pandas as pd torch.set_default_device("cuda") torch.set_printoptions(sci_mode=False) @@ -30,9 +31,9 @@ def cal_diff( x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False ) -> None: x, y = x.double(), y.double() - RMSE = ((x - y) * (x - y)).mean().sqrt().item() + # RMSE = ((x - y) * (x - y)).mean().sqrt().item() cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) - amax_diff = (x - y).abs().max().item() + # amax_diff = (x - y).abs().max().item() # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") if use_fp8: assert cos_diff < 3e-2 @@ -74,7 +75,7 @@ def ref_masked_attention( attn_weights_exp = torch.exp(attn_weights - m.unsqueeze(-1)) - l = attn_weights_exp.sum(-1) + l = attn_weights_exp.sum(-1) # noqa: E741 if is_fp8_q: attn_weights_fp8 = attn_weights_exp.to(dtype) @@ -278,11 +279,11 @@ def triton_convert_req_index_to_global_index( f"BLOCK_N ({BLOCK_N})" ) - num_batches = kv_indptr.shape[0] - 1 + # num_batches = kv_indptr.shape[0] - 1 num_tokens = token_indices.shape[0] # num_requests, max_num_blocks_per_req = block_table.shape - max_num_blocks_per_req = 65536 * 32 + # max_num_blocks_per_req = 65536 * 32 tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N # Ensure contiguous tensors on the same device @@ -362,8 +363,8 @@ def test_mla( kv_indices = torch.randint(0, num_page, (kv_indptr[-1].item(),), dtype=torch.int) qo_indptr[1 : batch_size + 1] = torch.cumsum(seq_lens_qo, dim=0) max_seqlen_qo = seq_lens_qo.max().item() - max_seqlen_kv = seq_lens_kv.max().item() - total_qo = qo_indptr[-1].item() + # max_seqlen_kv = seq_lens_kv.max().item() + # total_qo 
= qo_indptr[-1].item() kv_buffer = torch.randn( (num_page * page_size, 1, kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, @@ -373,7 +374,7 @@ def test_mla( qk_head_dim = kv_lora_rank + qk_rope_head_dim sm_scale = 1.0 / (qk_head_dim**0.5) - us_asm = None + # us_asm = None # if batch_size * ctx_lens * nhead < 32 * 8192 * 16: # us_asm = test_absorb_prefill() torch.cuda.empty_cache() @@ -443,7 +444,7 @@ def test_mla( reduce_partial_map_size, dtype=reduce_partial_map_type, device="cuda" ) - meta = aiter.get_mla_metadata_v1( + aiter.get_mla_metadata_v1( qo_indptr, kv_indptr, kv_last_page_lens, @@ -598,7 +599,7 @@ def test_sparse_mla_fp8(): out_asm, msg=f"mla_decode-absorb_fp8 [golden vs aiter_asm]: {us_asm_decode:>8.2f} us......", ) - err_fp8 = checkAllclose( + checkAllclose( out_ref_fp8, out_asm, msg=f"mla_decode-absorb_fp8 [golden fp8 vs aiter_asm]: {us_asm_decode:>8.2f} us......", @@ -748,8 +749,6 @@ def test_sparse_mla_fp8(): --varlen # True""", ) -import pandas as pd - args = parser.parse_args() list_dtype = [dtypes.d_dtypes[key] for key in args.dtype] l_kv_dtype = [dtypes.d_dtypes[key] for key in args.kv_dtype] From 1fe8d524e023abed0f816fd755a45f3d167d9fbe Mon Sep 17 00:00:00 2001 From: minmengdie Date: Wed, 28 Jan 2026 08:05:49 +0000 Subject: [PATCH 3/4] upload kernel --- csrc/kernels/mla/metadata/v1_2_device.cuh | 1 - hsa/gfx942/mla/mla.co | Bin 33728 -> 0 bytes ...16_m16x4_n16x1_coex0_mask1_ps_page64_ds32.co | Bin 0 -> 36200 bytes hsa/gfx942/mla/mla_asm.csv | 2 +- hsa/gfx942/mla/mla_page64.co | Bin 33984 -> 0 bytes op_tests/test_mla_persistent.py | 10 ---------- 6 files changed, 1 insertion(+), 12 deletions(-) delete mode 100755 hsa/gfx942/mla/mla.co create mode 100755 hsa/gfx942/mla/mla_a16w8_qh16_m16x4_n16x1_coex0_mask1_ps_page64_ds32.co delete mode 100755 hsa/gfx942/mla/mla_page64.co diff --git a/csrc/kernels/mla/metadata/v1_2_device.cuh b/csrc/kernels/mla/metadata/v1_2_device.cuh index 9f5a80321f..b6d2d61784 100644 --- 
a/csrc/kernels/mla/metadata/v1_2_device.cuh +++ b/csrc/kernels/mla/metadata/v1_2_device.cuh @@ -406,7 +406,6 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba torch::Tensor& reduce_partial_map) { constexpr int32_t kPackedQoLenPerWg = 128; - // constexpr int32_t kPageSize = page_size; const hipStream_t stream = at::hip::getCurrentHIPStream(); hipDevice_t dev; diff --git a/hsa/gfx942/mla/mla.co b/hsa/gfx942/mla/mla.co deleted file mode 100755 index e878cf163a29d7d8a7e875d692af01ece8aaf2d4..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 33728 zcmeHw4SW>Ux%bKJZW3k_Ji!111YJU)fXfmBgbgBYfQ9fW--1=bk`DreBqjlY;${*8 zVyX}$T12!F5s{)*FKV=ytQIXQC=lyKtu40I+uN(Xy|=f0Ti@1w|IeH`%Vt?j0KM_v zH<`g9+#$#n-v9Pl&##|@xnCl*0tjx3#=VtNv-<>6} zSkTngBZ|+%fZ3VEc(WeSa+@aA7J8VfL@sn{2NM{Re5&A?i8)0=m&n4}k*zr!&GOxC zu}r9}$>Voo;g5rvM?t0ASD_>FZ3zok1DVHy=a;>j8 z*O!-9P*HKPyug>Adt3R+$^w++=a+-attj*573^P8T)Ny>{Ftl6x2E8c8yW>mih8!qxdL|4lLi0Yjh!d zs@zak@aW1>8JB9}T)8zKWozx4s6}biVH2M)yjWUNyx}XII73?Kh_bPEEgTZ1(bsL^ zB57fIl#R7(p*u>Wui3)A6BkGmJyEvSu8BEO8hza+PMRrAoD*eh?V7kaN~5pa#2KM4 z%d#k2YuCi0D2*;?qB$PytH>)}AstgDIrhc2Omy%W#5;OR?f-^K-pFSmoH%kALEC3LWN(m59d zE}?xJ7aYCZ4vt-J2gfg_gUBGSV-eKexuMzr_(`$3iwi`LpNZ1yf)mdr8ae-*k!Pc9 zuHD)7`0t{$>adX=jOh2yn-dy>{upI@?OXZRD6KkdWd}d+$)PyqL=20vy>_iU5fdM! 
z+1GAnD6&5h!=r4k17>PbnqAP$7hNmLOV^g=RursQQ&3WwTexz4L4Gbx#&6Tf|MSL) zAI|tf6@0}jSCm+vJCCO(Ct@5;t~I{(xrN2P%G}#Z%U7G78)fi1odEuLO|dUmoKEHz zt}V&?iL0!9d zMG?rDfm>OuqXp{|Bv6VT5NyAmKbyaP-o{R$RzhXvsapf*)tmG!wY6lZ_urb)n>aJX0W>EldJ)<9dbWpcXN<^znPAzAp6k7 zdrjP7;tx!G(7DBB|bM2&Q4OaGK^9WayUhN&feZ<(>(Pzoo{7&yZY9{kW@N)U&)<+q1^!=q;?N*_{I)doy3W|&H9C8}Ij45&j+`Co zTWguqQRC!k@oe;|c;wE!-05ghT>jokJq8bVI*X_*w?_}o_?)xxYOFqu^CDy2F4pJF zhE0mo-97=jcqkhVR}n%TbW<8mHVw&cXO_1B_-9&oztYdr;|+$&2d4u@!8fe zPnz4~DI(p016g;c>cP~3x{iZ0I55pXVW%r?2XhXtai%+h>FEg`3<=r@SJ7s0L1%hG zJaU?(*SIF6rLw;_>z?tjb?(`#m{n&ZVGPFDd8F%k2aKZ{U6e9Aoepz z`rm+pW}4aG6-V@V_$xFHFnn{(5efJ4+uMOfkwG9)W4Pz###FPG&L6(>(%5>jK>Yf#aFgDED{-PNwlmejfB)GmgPC{~vrF z)#CHe=NSHA0fAOF*U;lV0uywBJRmTc%`?h1kH8dNpd%nKjmvXEUy)J)HN%04RRl%Tgm4UjF=GT?M5C?ceu(*Ov(uTjg}EF-mn(sFLx?uqSww9d zn48?r{FP!JRMO&tsRpN~iaC?&22(L-Vm=5C9AqHvPf$5t1qL?0O;{CFPF4l6ruG@a zlPyDvx(yvtL~{wv?a7C2x`*&3o30ZEYTs-;{q0a0&<0#z^^&1)0+L=^6)^N|K+@~0UN`ieK+^YA9XIqwAn6CHP8<3mAnAvz z&KP{Z}Tx3sdK0$;nc zl^p}VaZf9I1Nhc`t?X^!I}f$8cYvoJX=U#MfBr-(`#JERpKN6x06%=Lm3;{O=*3oc z7WnZit?XmqZ(eJ~Ah`R?v(E!!fJcA|@KvA%cnoL-z74bi-v!13KLEx7&oWl6d@vyS z!w*+iE58`<`9~kFrqAT_Pi;#bg#Th&>LL8hwsaZc=eDKB@fwe|pQ4X$z|Vjvqj=5N zqeoF*b)KHLJa22B9@~QRm==oRe9PS1yRk8+y4{L+OId*9bP;4Rt;A@z_3$WF-pkX?}7 zkUfxxA&ro|kbRK-kOPo|kVBBCAWuV{g&c-F4{3tD1UUj}h6Et5LXJXSha7{v2{{gV z8*&nI8uBjWJ;)iz2asPtehE1X`8DLXkWV4MgZv)y8RQR;KSKT!@?Vg@K>iB(0`hmr zmk?H^D+NR7e7(JA^}eK$0Q7AR44EBn5I6!~sc#q(KHlhCqfvoDdfz9g+bV z4atPKA>$z9Arl}T$VA9w$P`EpWIALfWHw|DWL_1;*0oMo0*xt~o#}%qw%&$wE^S0S z;(w>3+UXkHh;hhj=&r#uCOI&68G3wt8nqX7)Px$;mF2b(YLOVHTnV%u!1!gud4>(g zT#c%x)vyFH?%}*I7JMx62?^-0dzw9LTpDDKOYu18=*aO*$A0p2<{0)9a}0aM9K(KQ zj$tRvF|5@b!`?T?uwU8qjZDl@=Z%eswcwU|iot+8S&V_ZJOcMH#n2l)0{1bCp&#%F zJVY@R+C5D%6!0*!8QA0zc!b3odcY&_D2p@nV;+IWS-hd2^a#Yr@!1)ivh;HFFi`WP z8Q6DOrhzFGqaDPUqz!+M9WwAV`w=i@%zNxf15Yz{PxJe+2N|P!eMX3SeMX3S{l|)W z{l}u-eWD(-n%nGUZgYsa&2bm8&1X8e&F9Q*;{0)Hwc0XN-O>_QV{J*c)wWa}y{m=B zy;SybYkiE;H`2GnH`wY@+1{lwrm0R_J+3GE|3 
zBdWi6o<7vBw(1q5-8P_T_i+EZ6ta)%G!8%?5uYpZdk2X=_6?;z>Z^Wc^sy5t z`q)35_&rr8489R4_yZ%+U&L`i{T&=5`a6_G?G9A68to1NMZ2f26E;zPiRwR{jea7I zC*q&=Qv1VI?;Gu#fTI23X>}>>9ZS^y`B~_fzJG`zNN<`8-Tb^se`#S|iigS(TabQa ziRkyMjy5Py6_T}os;LHygd!yYsCA^l`^T}ot}LHy}0u*Zyz zNPqWEv7Ptsu1o32ST=v2ttE|ZLHw5K$4niZr)TL|%4mPw=+Q;~Zl71mB1F zj0}SxMEs21Mt+KprM|57RO-|^D`DDgc!R^IZG&ZqQ#b+?|_woqrtqarw z0{dKS97cVT9lb`ZGVIYtt+JjMo89x}hjAFg*?k7?)N>FoQLOfWq3>c38MsHE2A*O! zoGYPzBc00t_vzEYQ!Mv{p&wvR8hA*b0sa^wYkW?0E-;wAd9q)Bh4t<4?A3jEniXOD zsU9&8uwIrFzs>1VZE0z0kL)ymDtqfxs{Q8Lfxv;lG+-KV5O5GMHG8suVD=P$T6T_q zQ1*0xj{y!f)spTXI5J&HgA9Ub0~|_FnSkSi;(HGv;`@nyith)roRg_ugjh<1Z!2!2plyMbmq%KI&tKVPR~ zE^-d<0lRxxGW^3l&$Ma_70pSwMA2M?wS{h zT16XAxIxh-5ZQuxn@G4t(Iyk#sc2IO?^d)N!g~~LI^lhaHk0rHMVn3dkfO~Y zd_>Xa5k9793kaW3v_*tZD%xVgXB6!w!sit2X2KU0Z7Jc)infgK6-8T4_?n{S6TYEn zg@h**t%&fHqOBtQxuO*lzOQH{gdZtd8R5r@R!;bdqE!<9tD>zV{9Mu26aI&yZ6y4& zqHQAl8%{0>|Dk9)VT?tqA+%YvTEcD?Z7X4-McYQ0WYOvfds?)+2>V#H?S%a;T0P+a zi`GCm$fE5e9BR?-Cme3kb`g%WXuAo=ShPKaSr+YK!s{$rBVo2h+e_%RX!{7KS+xCx zvn<*H!nqcWzNf=57S?+YCa`CW@6XfhIRg*7z&ChbN|4{P4Cp((Zzjm^S0?oPy&oj# zGx2@HIn+4lyS%?oU@sc&o7l?+9+?1sxA%(#eHQ8=Qkn>Tk5^69XG6!i))eRud;2Ep zH$Xqj*mUTP-eHOC6{CHCy=LIi+2Hqj$0g1K?(wPJ)9C+CKX5xl!Vm@03{dw;P zi5r1U-rpzcIU?3-yh+4be(*1Ozeo%eimCsLo(AYgz0Gcj_}uUn~{A2$Y#Vs1ZJ(@g{~% zd11;_Ys;O=`SGDX)3E*NfWX74up<;7;=F=%86Sq{E{K=$VR#OMc#048Sw_E20f9$| zkBkq)a~)J386WDijXnYafk&xMWPBK&AE9y?ABN{ih^P2azrpD5SU})$;v?h3@O%oD zM~e?B$BFZlla%8GzDxCFd>Ed8A^xJ{L!6&cy9>sLIB%o!i;fR*p4XxH5a)r^zl;yV z^GC#o;zOJ#l73!%i1S9`+l>!#UP<-ZjSq3&N%gwoLy?2~w#0{_TwX+c7|QKM#D}3= zUqpN;bAMg&;ko!QGQZLlA0j^LiVr&!AIk5Uj1T4aOU8#X$KGCiD0A==G4@h?D0A^L zK9o86Z)1EIng5gVVPqb#2Kz?wAv5!rGCq`fOc@`_e5Q;KWnNRphcdq@<3pL}l<}d= zcgpxs<~{df``>hY*j`>TJU)!UHIRS*qNas7k|bNr`T0cTkm&je&B;cb7?HarUdD?T zlTV~}6dy**D|R%WNPSRz7%i`ODfvX|kK)51) zJQpu^3<>0&WiEXc6|99jQ*Z$-Dd?LjzbUy;k!4w}x%P-3K@M7|b z)E~u%(ejEiK5Q?aNPSRz7%i_T-~H38~*pjt|ZJ;f3PE@O+{XADVf@^WsA@kJxT}Xyy^yjStN{ zVpn|F_MW~@%O}cv0VCqWQ2sC?K9qTl$oNp^H@f0Oas4{l_2#a8Vpn|F?)7H#UflNL 
zL-~EXU_Q~jcelOxu)TcZw<$jC$|rXGKF9EU;^$w8d}2DD)#+{Bf_oL+Titk0L{+w& zpI3~MN2K_XViAfL3F&$>VMp#iq=b(8@hf>*pd4Qsa~|}!;yJDsz>o1 z;l*4JruxyY0~3ES*MB4Pc;R_Fiqoh+%HI)&$6=(0$6us}$62IDi>JQr_vXoa2GfJ? ztwG$f=Ei-4@_xa#_X@sofqMntxWK)FZ=8&NuizUequncbp?d^75|dG#$b0%&WGodK zLq%Sbjkp%ul~X&HQ|r3s+I7vfqxa;6a%bi}urfB1xwQ7KwNAu1YUIq!d-vouR+($- z=rvY(58op6_uq9do*&zyd+%z1bS+h0LzUM~w}Yp9>g2Ujd5u(F8+{n%bkE#gAYBWU zd8>F_e^iz8uRli1Cw*0%5G`L6ejggmVUck{wEWOl#R<{!KjHV4(Ke#R36c38s@D}K ze0%QU?}`(Q`*!zrp=3)BJv`>>CU^1dE{{d9pT0fGIg5BaWw zUU-(M@4!^0H?SX&o-IlN_5t<>;@M+AJ=@f4fWn3l4pdk#W;@lr8qW&@uK@B%E`QRb z3_r{~#Xorf?+s(8l0JEa#UR$k;+%EQu&%s9`}jVix=lgg|7DzDa6zPU!_TWVE)$5xf!xlQHYuT%Nmcd2~Gc9q{#ukw2v zRDR!1mH*&=l|QgclpChgJT<^D2L_N##FzN#!peQF%+V%3len{HL#~{I#Ph|Jmy*f8&_S z-+EK!CyuN9owrqf>ZHnBPpkar@2dQt-&6VfXH@>-2P*&Q7b^eNFIE2WS(X3h*DC+y zw<`aaPgVY}zf<|Y{a)ptf2Q((|AWf^|Lw0T|G!_T{2zZ; zdGJe>$HdsV#bW0+o1MqS*?G5ac5b)Zd19iSUvY(w%&z^SPySJV9>0{^p z`q_E^{&s%#)pkB$fSnHwb$DDb=TSX_1D{ZcD9{Qnq=o*uboewYUk6Y+4+nac0OyCo!@YSozI=Z`7?Vf3nu%Jz;M(GY?}kX5&3kix=zqk|j09d!uI1d!s&Rz5~yx*E)x!;k{3e zpwpSa@NmM6xU7aOJnNV|d`MBUWthK=iDznz_lTi<(u|o6l=}!e)r1^;eZ+I=&V*ie z)p+lejyi5tEu!brsoxrBd;)D}ANudrE&g1ccsvu2XWlRLyq!MN-Ov-yviHQZ?87Xh z{Vq$UKf~hoXIjSLdy{jrFP^2OXYu+*JX6{+~&EK!~k(XRisg6|*J zuAkYC?p?RxS$9IT_I%0wXoxQ*f9xqOzi|a^h9CDc-Y~A9g|^4PAAw= zZ~Egp*pUr8ys%?hC)mNe3xCivhICB^L(?#iK$nj9FF;yaH_%#!D%z+&XV8`N4utSbtvtY+u*s&0HEa?P0{tprmcx$1 zPOw9cU*Cltxv(P-cC6?GJLLFvGwfIfJMv*iQ770T$FEyphYxlXz>bxjV8=H$ekGaX z7h(ALMLLaNgkj^?D)^xUekg|@)^&m(V_+dNz&;UQ&-wA$@ z|9K0aR8yLAi8=hZb+E{^hfB0qoaqM4!k=pl>0N7m9sFqUZt$LX!4C&N z68sqOS>Tr?2!0&+3E(G!p8{Uzg1;90_24Igp9;P)S@7e*d%#Zyp94Oi3I00p+2Fn4 zr-474B6yKw`0{;`V=&OAGXpaM0yC+O&JAu{K;Q(HWatw;0;jaiF(`OXFphg+q zcI-z3_Tzr+hxGUN!S4WnFZdsTm;SyBd_DM`;CF$S{=OUhJ>c&H{{VRD@7>@Z2EP~l ze(=)Y4}$+8_(#D%4qp0u5BNs#`@kOnFa7-x_(#A$2L1`~-}SH-Y}FDzVSZRzvr7UW?m z7hByDi*zY192Yfdu?xVpPqElc++&ujrk-ZvpT7*2(&}Qf9OJIBY%+CI4Z+NnCHOh1 zaSr@uv18`U$&QS47ycYWMiiHptSBhYE%Q|t(GO4MR+Q#e7U74f%FEVNho$4pLQyS_Er{6GaI^kzhnW>k 
zaGU(LYufUYOy22i%g;9Xlo4(DB_{vb$hLgG$rs<;mN!0R*zcf=jbY11V?B>VXTDy z;Ry*WajY?d|H$Nr;CHpdm*1IuAcFss$sa2YuZ=gV`BD5B30$H(dc|mdR$n25n=4yI+MpkGwd$L>|IeCm+%@`G`kdc`YV%-zv z>Zs9UL#tCltT!vPn&o!O3Zuu4maAhjyrI=GAvVh$+VI$knXrR>?cXMRML|J!adUmMplu^>ikt0XT4Ucq>k(Q0jPM;L{`*M*w{oY*e z4%O zXp!?0e%g!uB^WG~?dALRXn}`AVV#Wwgk0X}`p)QDzA%m-C}U zxqa$`7TSmKA8Qh{g2O4 z#+qy|aS>SQyYQ+--!nDSutca=L8Sfi_@QA}+HSTmpvW&Yx&N|EZddB$-e~R5nC*QL J4I)(b{{R^JSgim6 diff --git a/hsa/gfx942/mla/mla_a16w8_qh16_m16x4_n16x1_coex0_mask1_ps_page64_ds32.co b/hsa/gfx942/mla/mla_a16w8_qh16_m16x4_n16x1_coex0_mask1_ps_page64_ds32.co new file mode 100755 index 0000000000000000000000000000000000000000..7016877af0fcb2aa07469462a8f6548317f3bf39 GIT binary patch literal 36200 zcmeHw4SZD9weQZEd~qg;;ea@j00V{)EQBOXhDk)30KxDvdBBjG2 z1Vlu@h!!JSL_kEe*o!t=nhc1D2uc7EwBAem>b=_6-q(Bkaa&*Oy#G3T?_n|wCJ<=6 z_htAk{%8O9-fOS3GH35~*6cI0Jb%(SB_hHZ$@t5}&a(({OyCs9Up~qYCPfkFV6pf= zfyJ?C(A3w-t1m*IP$n@}v?E%MDX7-aLlh+@@-p><35-epCeAZ~1(m!kkp=Z5TMJf- z`WIWH=*Al3I4}DtD7L4c>@Ux6`WH3Dbbjp}Px^cs`q6lo`=YT7f0DuU3cORL@K3p`cjp0esE*6m1lSIw*b1tH6GLy7gvERtgdv=_3WHiR#D_G+mT-GUg&vY%1h}LB_-9K z8nz?7dj3*R>?`Su-DQhBg$w4F7r)Az78WfkDe+XT;$2>>bQjA`v0*yxDtfrkP$4zd zJY4B{asKeED>ZSN94#@-Xlnf)wXb! 
zv@j#gU~O9H2-E2*TexlX9nwT+n9#Rt8L<#Nz%lrVMc4y#Jj?Dy4ogA3|uU8 z!i?6YiKSsWUD8A`A8fClTQ*NRrjjd&l-EAZB5f%Mo=|E!*ndd}`!BnK{@WK8l`W7C zIHq5LhF}jItLvzOH62xeRfU#<%laekD00`#Ee&{K;}xiA<%LZhRj|3E3btH~a)@26^p^ptem5tp3MN@Zl~m5IuG_Os`8WJXdJsj0;Ad4>MewtLw2Z z!}Mymk?oA=w@sfKn1X&EX1uno{7aZ#?Y6R=fA6t@IOTW*3o~AuRvwRt4b$zan;D4g zk4Gei8Lu5?TEcX@q?vD}&#S6fR9RTE{1Unx5Xy}@c%5zl&*c@~JvrNr+r+Bu;S0;$ zh3-uILyp1+OEc|-3p4FYY=!0cKeKRdg=a}d;X-%yg3Q9o>cUF*JdfR0SX@0kEB|Nd zl~wZ>yK6iFt1g($_N7;O9>k4ktI?mtA3Y-CmGtU)l~vHZsJv#*OX<}QFD$AkJ2$q& z_M|T;M)(;}eOzHbPJhU~*i%wf0hf|zobo<;v8Sqfenok`yvgUc;;sItD$RtuahcB$ zwPS$Yk%91OoNx-B+DqvzICdkm68FHilD+fKzxiJ#9Fq`7FPReHCH&uEA9P#;ycjU& z$`0h&@!n36?+K@C6v$DgI93Oz$UIGiKVr4tkBXRR-+91VrdZc1kyd|XWU3$S{EEVn zF#}h#XlpZ$6;|71Vk?^>SzVJUD#5$lx30O5zmK6SAvW3vRct+U(l}J7)H&+=`}-U9 z4y#qy`}eP%dc9(_+X$q5Oh11=qn_H&oW+{IOj^}EbE~7Cv2_@7ZL_0qUtRCluXaN` z@Ar}Dx1on(jPntY(pMYIijUmz5wItfuTy*$AXB1!1k4m|)U(KFpOwl{(Y_2qHQMJS zjE?qACA=ZpH-|7$@l^txKpm(9w*pyr#kUvW1p27lE86#a0NHO};m{qmj`f56gN<`a zbEN5m^})5Lh3(X4BRkOK&zo3Je++-#gnAuLsLONIPh@MGPOy5&ddSm|h9*CItcl*i ze)gHbCk1X0_(Oru8raR>%@}vKf3|_${oM_W_s1J}z5jXxZ}i`2Alg4d88oyv5bX^_ zdjrwlK(sfI>|ep&Y0}w=rYVzK^c8GR(~!FH_2b8kua6uuw0>B6qrM`4c8|iYacuUX8&sm={vH<-Y>m4f^bw>ewn4h(w zfHgKbrk&JpV{bQ2&z-JMUz4LdavE0-rFI);=wmkCQJ*!`+IpD8!LW@P$ zZ@jTqXNO_uFPm1dUo{of-M*nXqoE2_$jqc{MItFVb0a1kP)G&z_|#ii-4=MMbNPy7=&QQzzqG z6q`4RtvWfurm$qT+CYWr%}G&}-sGq%ZzPT#iHUkn_wKdRrsDi3(7Dt%#rPe@d!v>W zq&Nx+N~vtc8(Vh_(ES5$(sjI9{Wql;$j{53vVqxBYi$mz-{FXJVv0x}VJ}??&Tn(X z#Ug=BlfjtEFd6e;=E#)VDQL_$SltB+#}n_<+aIXMo|IcE4c z+SnaNy~D|Iq|VXlJyzDd+Ho5K90%kMWf#8 ze+=S3tY41U*wR|2UMQ<->5s6Ut+d%vRx!KPZ@0&>Zj2>m z+e)dA6>F2jRyLongG$OuTVzVDEjop-kG7N;tm}T8JrN-lOEM|Ov5w0!DJrrQv2Q7> zV{9YplVYMv{n62-43#nd0ft_E0AF`$JpTZ!yI52FHw`q9eBxKedBNa33ka9{l?h(| za>nBN*b;l2hLrY>N-ymllTq3m@y6o^8S6b19F8BO<6)@7@fteL%qU&qcIzwM?$xw@ zQGC(M7o|H14@BuY;lU`~NBB+@)^6Y-)Dymodcwn~Cwvd}iFcg8;RxZ>^Oi58bh6cF zI;hdKj3)mIL8k^GO&^TuH}pOPqDXv|91Qo zJ0Nr%5IPRDL&w`f$Ipb0gF?qaq2r*`F-_P(NOlmC9l-rU$Ipe1cZ802gpPOGp`%IY 
zXcjsS2_1)ojzdz%bYTY}*+EEl0DVHo0iokvq2pblj&H15H5OX|PD3 z4;Tf+dzM854*_F!cBS5_8E)_TH%DuS$?}yBR6hVq1C6H3c1CTOEIiwO&1*w57 zhAe?Bg)D=3Av&ZMQU_TLSp!)MSqE7U*$AnJY=S%u*$iobY=Jxv*$UYP*$&wW*#+4R z*#mhEvKR6b$UexMkVeScko}M*h!1iQatLx5as+Y|at!hzZ^kbj5#9`c`%KSKTt`3vN0$loB}K-h9!Q6Q0!Xh(&2`RNK zk*_BTU9jV}%Z{FNl$&vFuUo4BW&b4BYJGxP>W3d8?D-c4jilyPO>N zQ0$9-dnxt>?qg8~Haa=(XVFI4=j3>Z#TexyPL9V|tWiGUXm7!&vXx35r}cbHc-$J9nPpNy(&UVdm@Gp*05 z?3brDMkq*eu5AwWFV)pYH6+WqmCBpCH6))D^ZrRO@1GR&{#Rm6HxO%wfmk~TkHSB+ zrhfpWHJx-l9p8|AO6WQzbe(F4u46*i@gQBByTeX8hmiLKjni;_Lo(r(8yk{Oi}6m2 z@lMO}L>%yyhyx5H+fn{djQ3H{c+dA~NG2bW-mU$`*dgyfs@pb@kG*|xL-NOB?2pCR zAGc%d6JqR^ps{xjZAd;Nbe$2p&PZJ%j`&K%5eCwEpnOv3Iu)dASEjH%@whe>T_sZSw%NYr3x?*+=~lQ!hWj=T_oaDj%BPkbF+~ zn~3kJeC%QVUik0_uwBIa#GhDk zq0h-xI2QK+RCoGO-sjBY#PfTCqA+>Q8PlVhQ`jop*W8Bt8Qni!&(U*~k=~e*BTHv& zXw>=oZHsf@o(k(Q)@(=VN?UGeD*NC>?^_k-1CM@qKe3j6t_I94ADl=c-flPeLBvnt zd@}J4hry>2KaKN!h|kS6_`$@_XfX1|^giv&PPL>?X^13Do3S*Kwe&t4)tBd$Sxebj zgc7}-9Fug8$v%#KXbqu$y+k{z58SR#0RJKQmyPl+_Nsw<^oih4;EoV}r2F7hcJxHQ{tD~c-_|QJ zJ1r6^-Iksa&#>Mmt2fGKSEJI>)Sh{1Uihme)m&CL2sj9s222AE1`Y z%bV&QoHyOub6|>^YI1l7KL`UPSo;c5* zrd)5j^Yv57_bHYPLbqbcBJ?O08{vG#VkcavSR8~8DwbTrMT%u4;lqmMHo_k$7AN5f z#WI?3m0}r7_^4tTNBFp6DIk18u}mO*O0i5Ld`7WMCVWn@OeK6lu}ml2p;+!9d|9!~ zB79Y`+(r1hV!4~}4aIUV;aiI3e!`zCmN|q66iX4|yNacl@O{NnLU>%UloGZmmInww zRxD+N=M+mh;pd8_lJJ*`rHb$i#Zp7~m10>;_&de2gz!HU%TmHWDVAk~e^o4A!oMpP zoiM^=sU?guS?UP8nJlXb<4u+|gxyV+wS?E3Eb9nwG+EXY_AyyD681M)>Iny$ESm@i zn=DTg4mDXe6K0w$4TRYy%ND{MljV8BTTPa&gn1^*Ho|<9Wjo<`lVvBN%VgO_IL&0) zO*qqJ*+V$nWTERg9dp`7*V}RIb>n*5%ib_>Ul#a!*U>n6UD{FJ+5*EfXBI8@8)rCE%>)xYJ#tf$GhjfJlwyFxFK>Sy zc<{gpbYjNrwG?x(uo*!*hb=tyU{*j%4%jP3JrJVfn+WAm$#+asM6n_rC_AL){@xjx0P z{fLj_F{%rW&96p2kaS)uHmA>6aKC+mK4Sr%rnVHDUyU3h#f6lsAf1;Tn5e>Mn}Sk;5TAY;2Ca4%J;0 znwH2Go&!NU)4yD*!ejX)bbNRVcdt-C?`PDx%HV@5v%h)_L|GgULNU=E+ zd3za~%ly5J&1D{6#^y4gFJp6=*O#%m%z|? 
z>sssq(hd12s`lbMI^r4n9Fo>pBR;=GUj4hAM5wijj_Yz)4$7k zbh4e|a>Mp8`Sfd^N2gd?t6zP{L@YQd#p=QNbIPX^Z^Y_BxpO0zF8WZcPWMT)r-s^6KD{eeUm6+6n}_DpW!^j_W|#T7 zi}L9rZyuUY?~2tg#Ohsnbi~?SdGxEAN0--Q7Und{pNHnvWvt#-9$n_kM?(+g&qMR- zCq|1Rgz$#xl=hs&o^Y<{is=rp#B&BNu>DK`JE=F!P^8Jma8r&DZxP4nm!&x$;^ zjLk#y>Q^J5PUjhj%|(9vvSV|RGym4u9Qkn?>xyG@ku$$kY<@NJ=tgWV^5+-D<|20< zHZ~Xe^NV70kvk6?nIeRFpF3zW4B7aWtJFT}As}s_G6+*e!q^*76)Hd8ckA*lVO_CZkH_Srfqlwye{$=+ z$?socZ}R(>*qi+RiST=q-#-y`I?imi` zIK`fR8Ar*SXIuLb#$o<6a+_jrvD}9sbE56thamSD|0DMy)WUY!6TBKo`wZm%0=ciC z9z5;!mHPoXbFYitu?GuoB#aMhk-$M!W**w;UqHZ-$yc;-{X4d3TL{5A?b55~@kXFoY6=^T@N z9Q&}(jrzV$j{S6wRv*Xyq(go;QExmy*>_N?k_7Arq~|HEz&^nKKzt_#p3^*fqW3_B zr4bHNSZ@~9lDGojV*%_A)W+DoJ;vmEVP=c=#WVC1E-gL2c!5PA9@6qtygl+$y>#3c z&vhm(FOBR6?6bTwQoHY7RV%z-)!cJbZElgOd5TqSUWuyBFIBY#52)J0GF7W6SG5N# zRjs;8)fUyL+Cz&~?cpV=_Wh-*_Jd`rw%n^~D|A&`S*vQR>QwEK)vEUB8dZC2t*Sk~ zPSrN7SG6ZLs@jwFs`k_-Rr}%7s`kugReQEU)t=j;YCn2j)n3@DYAY_it3~cfVD&|MTyv_8-4jwLko)s{QGY zs`g)hR<*zWMb-ZA*Q)mSzp0x48&!*lFl#20S&ND?YcVlqty?#u1*b_cv=d-DK7V4m4|n2AQ?NgU#Ca zzGv2k4mE4ThMBd@OtUt8xLM22HftkBn6;c7vv$iZX6@En&Dy9@W-Twztc@9C*7Ebs z+U>WSwejQ4+Qf-w&E+y{Q>K`;Y17Qwj2UKa=1jA8=bdJ4_H46u&pqZl^~mNwMAkL? z*#}L`$=Jln_+H4n?$Y(yvulm-iJVB^6ZvV=Bl!N5I@^#Ge9xoRZ?nZQJor2@CZ|3J z&tLb-9#Y!NleU@goIN1 zz7*11Ym1AcakfK$@A$Y2ZQ}4`26C5#*2C zV8>|KF%EW2=ma}hH*Uu$*f9on+zvY?c7h#xJh#IMJI2C}0@yLB6YSV1{E-Jc@?pn# z*x~8~I}#JPKj`^r`fLZ!I;J~0X6g9e9+WeD9BoFMZU$#}a?I7^jk3eXaU`wpH<*Ta zS>m$uz0!N(`}BbJUb=tCkWyyFWJu*K!RKD!C&LfZ;fGo9!`+?W2RVODfgLko$DOd_ zo=&hs&RPF zV=nBN2RjyYf*o@HngctEVMi(KDC-0}ba?&0mB;^Ve$;wLF)nGj=EPZbnc%a*=YYQz{G2$> z-v)j(_;KJTfY%c_KMMR9@VA4X2!3l%&O5=61z!Mu5_q45^LgO&!H);;0)IZ4^E}7! 
z%{iW9FfdDJ2HJfbbE%DQ;75n>cEOJe z;d2E);Q~Jy=P@1UF$?E$H_k))dkXj&;O_)~4|wVCso?Jbe;4?B!ApNn13we|Z1DGi zm;SyVd=dB(@DG5O{w@SR7yLZ%3&2Z%&jDWyz7%{JcNSDb7Rsy8?VQ_=muMAH4K;CHNZf zOTaG!Fa7->_(k9!2LA)_(%)Y2wcuBSUkhIPdpY=(;2#117 zeiisf!9NcEn)$mY{FV+sT_56a;~G@9-_$grS zUjqLMcF*c7?*RWY_*cPSGk^Djf5_iA0E6G(#FM{?5B4|l;M9zG&F$CFu7;+b1C7hhO> zsTY=3ykeO1$<{6CcMOT&PCsT9^t!<~j$vyC29Y=%Ao#JVt@)vXpA*761Ydk}>-J*= z|HtpO<|hk&^N`m3Y{8G@0|hFJ1-~paKp6+d9~tLWW^2ta741_zt@%}gKksVI(@&a_ ze>U4&^Un%C`IgrFPQmXF;a?a0=@9+<1>ZQXwVw9`f7k7;c|2)>R}4!IvHx?y7f&}1 z0{?v}_{bRn$~gFg;Mx7H`ELZjHH7bGyn$lakwvZB_Z0lHg{}F1vi-c){LO;zS=pM; z7W|TRt$C;5GeX|CCy7^LS;32~+fM`^?)P0E2OsYDU0Hg0MU9807e8E%U#w;6^U4>c zm%6JBtYow7_Q0WI zoID(!6F9VGj}~lZh9jUMZ*)MfEkB^ao-smduxAA{fK2Zj zJvq-co=<44-;7sKqPKMn1V0tSU-Cd;C4NIR_$Ajb$x&Zg6JLRu^;M9LH*@AE^&?M zFM-Y>_%#Rx;-r50d(IM_!HsboS{??D`bR^I*G&R@a4ZO!w~WT32|+f5`qDYjOU_RP zD|uOVCm+#E`Y(|e@txqDll>cC=|73H zP%nKK+_bcgmokK5@hDm$(ti2=p=Cqbj>Euf^CSFFp1-V+U-R{Zk=7ZTJ9uUG10xlsS!m@+}!UidBAdBHsK&9wa!a@Q8LlToffzo6M zfzs3hrpSww`bHWlQnctrjTDpBqD2J-qP~b9MXT0&t=G37-qvf~`~2t3*-Vzj1kn6^ z-Vz;~vetS|R(?)#Mm)Kq7;XLb{tgL<^Ri(P12hU?@oEnygcg`g$M z-HZkO_e92=z?_-6*U!G0v6_edD+-D#{O&sn%S#H2EAGDQ0k^+=X~ml?ebd8k|KiG( z{^ET9;>Cp(75mEz{RR2oD_>Swh@66ga&Y+-W&Xv5dzTiMF7g*Y;4bm6D7^oMU%E?| zEU750WDmG2maQ#}`;~i@zj$R~{vFFo3LX+g^B1jLvZSzlqp0$WGJk=r6c?e=zC~;D zjV5GEm21ihf3Ykx<5FFmBe%vQY^_5VwFs3u?c#$*FP4@RuldSOoGv|dM%Y+~9uA98 z>Ff4zp7bz1!p1uE&=aB3*X-e*NjFIsy%Dz7p^JGDDt+B9PM#rMoE>3n9lAI_LZz?U z#p$6h%fbj->(Irb2$e4AqB$NkRxB=FDm$i3G+>wCeHukNYQQ?6R9O*{Bk-u_rQK%QTT|yJ>dSUzJHn8(@ z8`ymb4eXgy5$b}5OQ_$j3mPxCfqj?Tz=2C>V9%tsE;x7z_1ks9q04Qc>2e!5atRIW zncUU|flH|0t_zM{ZUe_Iw}Imq(?EET*SQF4&z#WgfBdA_+{Fc=$InElb-{_}5}mxU z&B=2SHrL_odi=u(wL0x&CnNejb7zN!pid)guVXL&5usM6z3k-Yog0c%PDHT?+w0KF z6H##ys(tNlh9dhDQ9Q!-I^m`kq1pxAeA2zNymVz*ensKZ6@?|0`Ae3qE-c8W$@nEY z`G4Fv@xvKksDQtC+0qjGW9RWyepqCA2%cjmCi}b`wPbdTdTEXHq{xj;vkR=mmSDd=Z)s>~My84q3d^V)XXpyJu zA|>jwu>A^8V?R;6z3X06w~*~v@j4`~v&FMoY*gwLQNgH9!5}hXV*=*@bxd*ADRyVj 
zZXX&%xuBv5WX!-#EXLV_b%oV5+t`L?JF9KBMJHBO2kvSa5FB8%l^7QjKvQfB+NAAJ zo>J?n9~2y9T^}^4W_BOND7Otrdc;-1tBiaqfAc)na&_{?mYW+q^^DzxP1dz| z1`gEqtFEfqnjq@EY}VU)rDAO7Wneto+3HXN)D15KdsF%fC7=PB5)&X`wiqLy*<%7u zN=L^8(h0SgfR`{PCNP_@Z%kkzAy)!r054Dn>c9pd>!}0|0=&QgrITU;pAxEQA0G?t zN7j0e?y2!?@dy3JzAf}D)cv}@=2LSYsLVF@baOCgdOdw8!JKLJIy$2^$5TI@)is}F z^{_3l9k6#{yPAXS$IW!a1lb!V{?x>+CjOg=e=zZ92KET{FnqWrc#DBOgFOvQ2qqZV zC)me8>1)4WKO=oj@EQYA|A(Y9p}v8rZy@R$i24SizJaK3AnoIN_S@!HRPnj#>}2x| zvrg&j*@0$z&B(2_lc&~iV83s+Z&IkSZzyIP4<@C zgSQ)g+?IWte!G5qjT@gUbvL!|S^uP-?aSVhJs}VEJX<{Ln{-be`renZHIFqld*+nW3XN`-e#j!D`;-IPdxXan1xU2g1>@{?x%T+{q!Ct*M?Zbi~j@>lI zJiU4qF+TfToEoD~<-BP2Ecm6ks`|2xr>D6Trm;;1Dok(b8(mhVMVD9EvF_nq&+gf? zX3lKv%QW(*EtP%Ita~TG*E#0~GP}-tn%K+4D^2Vza9wbn zPGW3u0_u<+`(LEqVC-j*_A~Flj?b6cx0N@Du`^BX^S1x~`y8vYmLV8382j>#tFXIX zfp%SBTtHwvi!$=LS72{lpcW9A!W1Ll=@po!3mg^@=wdb_Kiw;Ej4sd<5IBL^jeM_H zU@nbQ^7EkYnsE%C`TyYaST8;geYVjbEFjR%<{0TXufTX+AP)%a&E^{Unpa?oF3=ee zn8t22@`rf^x^#i*0fA%KO-8=QD{z7?&>Ilg?(>MPF^`%1jXKOL(5?%N3kZy7dyIVU z71&!Bs09S3uwNMYPOrc;UEr{QKo@H?^3%Nn$LIn*0f7_P14h2rE70(<> zJ+7sXL3XHFvBks&*<;Nt${rnLo0_lP-*jCJ=14E*ycm~D?dx+_1(lSlV6ZwER0bNT zD`{1_G8E>7rB>xCT2)fCTGcnYs%ix4MZ4Up*Oj661=lZHuiDD&Lu>5zv|udeSj=f2 zyS=D}#k8ekVv7D9bKC}(D{UilJA-a_JnMmwZltS->Nqi9d0Yj{#N4Q)#RgLiPEQr{ zD3uMSVjjgD5ga_kK-$Nka;6##{60%q9aP?|4q{F1H;nhT4J+z#&9EYxUueGXeJEP@ z5+07$b;3Zj9w2-!8gnb~IPwWkBcJdL@(JHYelOsA(K>DOp9hMo55?HoKM#}vqk*fd z4;$%qKuXtE2aNP)Af>leKWC)311Y_``nZv908)Bi^=Tu007&UW)n|-!6Oht@>bH&b zQ6Qy{SHEYZPXeRaWAdkWa{gI4wo@V6UV*>8c*{-~85 z13rIuD|;UJ;ytbGCE&|HZ)Gn7Pu<_jUID)PU@LnS_#Y3ova`T99&2T90N;ADm7N2= z^K>hF2l$6)TiFM|Kb~l1e+2&NR4WG1U1uJB91sOO0#ty{0ByiyKs)dyU^MU*U<~jq zFcx?Y7zg}-v4x1!7OY>OFUVi8X@MTo5=(OkJ}-*f=yRahZ6RX$9&GHX9(N+*dkZ~YzeFgwj5RrD}j~4%3+nTRj}2twXk)tDwqzdfz`q`!8XI{ zV0Xc`z_!8aVcTIlU^`*EV7p=W!5UzDV2!Z7uzj%oumiA%VFzK4!VbY6hc&?t!;Zk3 zVFB1Pu%oc&V8>uDz>dRSf}Mn&hP?uN4R!{07WO9WZP+>3yRi3QAHY6@eFXa$_6h7$ z*k53OgM9}32kdj$7qBm3tXfwTm>m`ai-W1KcvvEg!+ODb!;)YcEE$#p8whj4QekPZ 
zp|D}F;V>7>4NHe*z{bF`U>?|b*aX-_m=`t)mJ6E#%Y#jW&4A5<&4$gbrue$n<&LK@ zWP>YxD8<*C5u0nH;&2||a;|r|hc;l`u^VakP#S}r7^{qQTwEH}7iHA=8kCj!(WBHN zF)q2|X+41PDjMe-(KzO6R5h)J#fxzc=NK{IV~CHBM|+8B)^TVE#-WQnrs;U^{g^p! z{l*-(o;Js=-3v>-2PmdOy@M1}0S~cg1Dm`8kFXdc9q~yk0{6N<~EbeZ4NWHIsPKH`A8?X`Ixy)oKIfAUTwKu z-PjUaV{dUs*S1t2y{m=BwN&;_>$WH*IlOO)uaB-vA>7`hE`{cy$BcQP3>eLb-;qH4 z>gwMZ{5qiEclNAH3IA-U{H{Ktz1{t&UTyW$M!n5IQSZJ1bt&W*m2DV=Ho`tn;`a;@ zZ8Tm(ZEUOlozcd2plD<7NaA-_pD_3apy2n7Mtfn$1GTq*tZ45*Hr3l#-D=c302K8e z{*Lg8G$<;6FbC~~9Y@4J>ZAIHs$VziHvvWcLsRQgIy#1^{^K*zE`9$HH&D814$|i5 zOzFcn*QIzVAMpgGkK7{K{i67d0Qo>^j;!kgce`b6{=~sRvKDXC?0)NbS>0-yb8|Oy1q_K^N>$0ja zqX*~e*?P7zrYd&Kn4%lEHtAxFcg1_4?O?3NxSn3L!Ie=ol)ZQ|X{^Em;87391v}N> zw~G{#?;!6W*_Aa|s_%w3`SYRJ3`74=LJw!bcSC zR>H>=?KZ+E6>S0GQ;N2b@M%R`MEI;*u%}I{BkW_-?jr1G)3y)}uxZ-}2idfG!XY+oJK;4pZ3p2{#Q6Rk zWRDqm$PK>UcQ{^t&oYqS?t39#e!sGi-r+kNug}2u4d+nfk>2V1D4soO)Nf)>8F*wO z_+7rwoH~W$k10ub>4(Whzc%nYl$ltjI>1TZ7 z69Xc>yB_JIzG;d2G$X%Z2hz{^ZcPk`bmK0hkNK7->eG$tj9<|OH(z*nf8j47@94DlBo zQ{tSC>Rm9V#JL{jUvx}~bHh%>lsI>!_GL_I)t875#gsU=q_m7Ft-2HOp_mfqo|Nt| zro_1^mFqC3#JMY#J1?fhxome#DRjGUYfKr^|H5L*kRBKoQ-<`xu$WTnh21e_trF;#a=*{PUPe$Qk~DZgJbrj)vTM=_<;?I&S;rkGOd_cErGI(}zkN~!mEWlR~a zGs>7UTyLzwK2l7{Or2E5lu|F1F{RW^WlSmcQyEiA9aYAZQcsmJrPNhrOeyu%J=p#? 
z9aDCse_CV8C|rYi_4Aq*;>bos*PN$Y(!5PEWhDLbQglmdk7CM5`ekQzOKO8+%1HX< zrRbK_9>tWA^vjFUEvX%fDZ};43+a|tJ(Bdv&ghsFV_J1e(k-nzCdHUik0gIcpX^9G z505deIwk3rRvlBum{#4=s#j91N!yh%rd7AJ>Xj60Qn^Smrd6*b{gTQ>iZQKvCFz$` zj$%xyS5lhv%kCKS+ZAJmbjoo3QtFgpv8L2d&eJbVoibd%?2a)b$C%x^Ww&k_Nw<{W zFE_?Z(ksJtOc`T#q+3dTaw6Ixy)sqA=zqZ~Lx@FP7&ettzUZ$8bl71;;%8Su0sXdA*Bk7kirtC;5hO8I@epl)g2SKLud*^zG9 zQB3(&y5%<&Q+Df?oxg9>s#~^xE_BOuJd4!Vz7h9cdNz6R9E_@LJWs!jq+43Ak<)xm z*Te}sb6+Ku3%}+aNr$vv^QJmEIg(PkEOz|0f$}t?bhSk z^tkS8yxrG$JA2u)0Huj)Zse1wmS*qrlGf)_ZP}*x>Bd>>@{6^pJ7*C)2)K< zbYEZ%kgnOvYqjzk?H2HKzn{D&E3d`MYq0krpYG?|1Eg!NQtOJNdwG@fuX9GyqP{9l zh@>@H@0FuDD?Cn!q%D0_oDfMnvflef+lUk=glk1qt~*Zn_B?mc9VZy~6gPHjEEm>T zy5ofIIN@UB1gW{mI6-PFosARZJ!3LXkoSglHcpTl%zu|S0Z;z6u-EZS7{|%nTR^;r z%s7X7-F)^-U|(IJ77*Bvy=CMldjMDBvDru0yh7k@{SQ3jqmAD?yLIbY=^2zS1o|7}G;N~g*{=Il-8avhV{rg!I zVjb>tR`v9yR?&Jgp2O-}J>T9R*spq_o!@+u%D+2L<+sdN`Tw|8<+tCa^85uV_b*iW z;zcSiEKvEAvQp*WTcz?ft5yE}wJQI?I+a&f zseHYz@(ndA-&m{iA8u0lk2b6P$8{?I$z3Ynx<%!8Z&UeC>s5Zwc9s9z4wc`#Q{_L~ zrShNeR{6i*r}Fz7RQ`)SDu19+<-gpk@(1^+{8#%`{?GxH|N3E-KXOpz|MO9mKXypv zPdu*jC!19Mo5L!9>WIo)npOUEK;^%EM&-{QRr&9pQ~C49RQ}=%DnD^t3^O z-g;W)ufC%4|9DO1ub)x*8)sGi)|)E-{o5*k=bXy_@UF_=e^2Fq{6OV@`cUQn>m!x_ z`D2y;?@v_zpP#DyuYXbb-~Xoa|MQv3|M?G<|KHD5{;w}o9{f_}QBe+VvpINlw1daS zI(UyB4(@O`ctV1MUvY(l_w4E5S6=Dhefl_f-@Xpsub+eW@9*FP1~~XtS2_5gK@L86 zu!9d7;^0?b?cmp3gO48V;F*~YK6b2wj~nOU+1U<$?X?d6o$omK zb=Ns~PL6|5p6uX0pMzh2y@OAk>fqC-JNV3*4t~Q84nAj&gWq_guHR#Nl~=JkNijX9D#Zp87s`PQMSH(;sou7}T5R&T(yf*)Y1*Wav1_sU1(`Oy*c;Ku^^ zu>^jUbiE%s-pOJ3F&};`gdaulqpa)w*e3hqR`{_9ek_L{SHH;m2n9v90U<*e3g9E&Ql~A9e7fzU%$q ziK0K~`A51YgQ01-SD;(R`y`Nd1q7xWWqKG~hF4&go?xWM1O$3$e(z%&R>kPf**U}$A0+nXxIA@8ov&~k0bEoDEv6y^?roLugBp>GyHfC ze!SH6euT!aCioG6AIIRw$*%V!G=3e1AJ4##7vRU~uJ_{`8^3y*;}@ZI{Gv3CUxe22 z>lO6Fo9Ku4&<`JXeLsZ8uh-zm+wkK9`0+{C`w<$y&cKgz@Z&@H@oCrl5gNbF!jE_1 z$4BtvFJ13PX#Dyc{P+TX*t69U@z-{Jd@kdS&)~_P~Jyso&f__LtKe)QSA420- z68snlKZe2&ch~z78oxC7;e;Q<;75Ab`w<$ylHo@x{1^^DGP>T6(D*e5eoTNLx$t8~ 
z*ZUEQzp~)RMEEfUe$486KSJY|2Yz_rM;`o`-SvKi#;@`4V-ozB20!L@y&vD$_;sZ@ zei2&7FG|z+MQ9nn9*(LJTAt3>ccH({pm+1f#p!yF9u_^Xu9>tr(({Bq2R$#2G1Bp3 zq4)KMmew0u8fkg~ywiXDccB5s$Lo4Rf@otrX@XnOhG%Gv$CDPD?WrFREieW3QYsje zI$#Y`@!nrgy%%juolo2g6&EkwkMD|)i^2=kp!J#BpcC(V3u%8wJyZK@ukpE{@wuV# zr9&(E6fWRy&qjQ5B9}qZ^^)%9xyaK1R*BBJM^BBj%cqL3@F!_mLe1^$S z3gfd(eo7lZVmtO@H}<0u`*Fb1-#ftH2YxU3hb{fR6MO^sec%sT`g<4nJ>d6)f7H_7 zhrk~Je-!+2OMgEOz8U;;;9s)zcN6#k_+#KtTKfAi_-DYs0RHqh*xyD>fqtUzp0&S? zn8M_({Y^aeHcw-I}oytThy!G64n{df=i@v)`9UjzR(_z%E;V(ITQ;Lm~o5d5c> z{yq!-UGN`)|BI!+{|5dG@b=M&!86;A@6W)02|fmVVrJX%{SWYL6k=5HJhSci{v5ml zUIpJP^P4-qd!yfmqo4YO^*8ZulehLa@zmc`-rC>9)A3EbwZD6hM9hc%NW*@(GTX+J zB=7^l4+ZbeY#UEB@J{f2V1Fl}f2hCv0Tlw${Y^aeH}RJK zPPVzAi=Gd*J7z@8>C-v}L+wAl49dJnT$I4*3`;t+uA zpJItNakrU|57AiCO#J(o!BSeCZ{}m{HI{W|+DwLE#C~E?HJwSTU-ieDSE_;#Dg~f>~a;xY9^C z(+djxnf`($W3$F(Wfi9T{fpBJ(lbYo%UoPokU4H_cHwCJp!4EIi?0c}eclSDS1hq? z!jIc$(eF4B--s8hS{J-G6H6>RCOgi|6vbj$S=hEln&sIDv!DqclWz>;C!2hpt9|)d zCO=_Rd;S)a=Naw!0+TPgy*+Py$ngIVU2wE4YmN0-_C=ww8>a{iGhhXI;@Xj^eGTI9S$&3_gQmdg5a{*fq$ zS(z{E8;#&!KNz(w^W}Uc(TO}-WWMxY;#J7AS@PxlC{b>o+MtE@!Ma?H1aY!``8$jf zz1G554^Pu@PW9#d%~Fw|1uu}MC3l&?(El(8#v+62(mv24$BDe}i*%U)qmO7+_8;A@ zMax7Oak74s7{FPRSwBy96cxr2ws|(w{Z+Ytf>0X3BId u0m*?d>AyUF7J!q!n>7q5sun4^|1wW*SEh@w5~)7@$__35unJ*G_WuB4?3dI4 diff --git a/op_tests/test_mla_persistent.py b/op_tests/test_mla_persistent.py index 1cd04937b7..c74773e24c 100644 --- a/op_tests/test_mla_persistent.py +++ b/op_tests/test_mla_persistent.py @@ -697,11 +697,6 @@ def test_absorb_decode_fp8(): return err, us_asm_decode def test_absorb_decode_3buffer(): - num_works = work_indptr[-1].item() - for i in range(num_works): - print( - f"work_info_set[{i}, 0]: {work_info_set[i, 0]}, [{i}, 1]: {work_info_set[i, 1]}, [{i}, 2]: {work_info_set[i, 2]}, [{i}, 3]: {work_info_set[i, 3]}, [{i}, 4]: {work_info_set[i, 4]}, [{i}, 5]: {work_info_set[i, 5]}, [{i}, 6]: {work_info_set[i, 6]}" - ) out_asm = torch.empty((total_q, nhead, v_head_dim), dtype=out_dtype).fill_(-1) @@ 
-736,11 +731,6 @@ def test_absorb_decode_3buffer(): out_ref_fp8, msg="mla_decode-absorb_fp8 [golden fp8 vs golden]:......", ) - # print(f"kv_buffer_bytes shape: {kv_buffer_bytes.shape}, kv_buffer_bytes stride: {kv_buffer_bytes.stride()}, kv_buffer_bytes: {kv_buffer_bytes[0:1,]}") - # print(f"q shape: {q.shape}, q stride: {q.stride()}, q: {q[0:1,]}") - # print(f"qo_indptr: {qo_indptr}, qo_indptr stride: {qo_indptr.stride()}, qo_indptr: {qo_indptr[0:1,]}") - # print(f"kv_indptr: {kv_indptr}, kv_indptr stride: {kv_indptr.stride()}, kv_indptr: {kv_indptr[0:1,]}") - # print(f"kv_indices: {kv_indices}, kv_indices stride: {kv_indices.stride()}, kv_indices: {kv_indices[0:1,]}") (attn_logits, attn_lse), us_asm_decode = run_perftest( aiter.mla.mla_decode_fwd, From e0f10abef4060d174e3bbb81c243aea8718bd940 Mon Sep 17 00:00:00 2001 From: minmengdie Date: Thu, 29 Jan 2026 06:12:22 +0000 Subject: [PATCH 4/4] fix the comments --- aiter/ops/attention.py | 10 +++++----- csrc/kernels/mla/metadata/v1_2_device.cuh | 14 +++++++------- csrc/kernels/mla/metadata/v1_comm.cuh | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index 4219dfec4d..c720c11d03 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -880,14 +880,14 @@ def get_mla_metadata_v1( """ Inputs: cumulated seqlens of q/o: (batch_size + 1), dtype torch.int32. - cumulated seqlens or page indices of k/v: (batch_size + 1), dtype torch.int32. + cumulated page indices of k/v: (batch_size + 1), dtype torch.int32. Length of last page of k/v: (batch_size), dtype torch.int32. num_heads_per_head_k: Equals to num_heads_q // num_heads_k. num_heads_k: num_heads_k. is_causal: Whether causal mask is enabled. Options: Detailed settings for spliting. All of them are optional. page_size: default=1. The size of a page. - kv_granularity: default=16. The granularity on kv sequence length when cutting batch. + kv_granularity: default=16. 
The granularity on kv page nums when cutting batch. max_seqlen_qo: default=-1. Used to check lds usage and save time. value less than 1 means unknown. uni_seqlen_qo: default=-1. Sequence length of qo is uniform across batches. value less than 1 means the length is not fixed. @@ -905,11 +905,11 @@ def get_mla_metadata_v1( [2.2] q_start: (#work), The global index in seq where q/o starts. Use global index here can reduce memory access count in kernel. [2.3] q_end: (#work), The global index in seq where q/o ends (not included). - [2.4] kv_start: (#work), The global index in seq where k/v starts. - [2.5] kv_end: (#work), The global index in seq where k/v ends (not included). Note that + [2.4] kv_start: (#work), The global index in page where k/v starts. + [2.5] kv_end: (#work), The global index in page where k/v ends (not included). Note that this value indicates the end of last qo sequence if there are multiple qo sequences included in the current work and causal mask - is enabled. + is enabled when page_size is 1. [2.6] kv_offset: (#work), Remaining length in seq from kv_end to the end of current batch. [2.7] pad (#work, 1), Pad to 8 DWs. 
[3] reduce_indptr: (sum(qo_seqlen_blk_count) + 1), diff --git a/csrc/kernels/mla/metadata/v1_2_device.cuh b/csrc/kernels/mla/metadata/v1_2_device.cuh index b6d2d61784..0d37f9f81b 100644 --- a/csrc/kernels/mla/metadata/v1_2_device.cuh +++ b/csrc/kernels/mla/metadata/v1_2_device.cuh @@ -91,7 +91,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ const int32_t num_blocks = integer_divide_ceil_power2( seqlen_kv, params.kv_granularity, params.kv_granularity_log2); const int32_t num_qo_tiles = get_num_qo_tiles(bid); - sum_blocks += (num_blocks + params.k_fixed_over_head_num_blocks) * num_qo_tiles; + sum_blocks += (num_blocks + params.fixed_over_head_num_blocks) * num_qo_tiles; if constexpr(QoState::is_unique() == false) { @@ -116,7 +116,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ // expected payload handled by each cu part. const int32_t payload = ck_tile::integer_divide_ceil(sum_blocks, params.num_splits) + - params.k_fixed_over_head_num_blocks; + params.fixed_over_head_num_blocks; const int32_t page_size = params.page_size; int32_t curr_batch = 0; // batch ID of the batch which is under review int32_t curr_kv_block = 0; // #blocks handled by previous cu part(s) @@ -149,7 +149,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ const int32_t remain_kv_blocks = num_kv_blocks - curr_kv_block; // If current cu part is able to handle this batch of seqences - if(remain_payload >= (remain_kv_blocks + params.k_fixed_over_head_num_blocks)) + if(remain_payload >= (remain_kv_blocks + params.fixed_over_head_num_blocks)) { const int32_t num_splits = curr_n_split_idx + 1; @@ -234,7 +234,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ tot_qo_tiles += 1; num_works += 1; - remain_payload -= (remain_kv_blocks + params.k_fixed_over_head_num_blocks); + remain_payload -= (remain_kv_blocks + params.fixed_over_head_num_blocks); // update state curr_qo_tile_idx = @@ -282,10 +282,10 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) 
__global__ } else { - if(remain_payload > params.k_fixed_over_head_num_blocks) + if(remain_payload > params.fixed_over_head_num_blocks) { const int32_t consuming_blks = - remain_payload - params.k_fixed_over_head_num_blocks; + remain_payload - params.fixed_over_head_num_blocks; auto fill_work_info = [&]() { MlaWorkInfo work_info{}; @@ -470,7 +470,7 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba params.is_causal = is_causal; params.topk = (topk < 0) ? topk : (topk + page_size - 1) / page_size; params.qk_batch_ratio = qk_batch_ratio; - params.k_fixed_over_head_num_blocks = max(1, (16 + page_size - 1) / page_size); + params.fixed_over_head_num_blocks = max(1, (16 + page_size - 1) / page_size); // launch kernel MLA_METADATA_DISPATCHER( diff --git a/csrc/kernels/mla/metadata/v1_comm.cuh b/csrc/kernels/mla/metadata/v1_comm.cuh index cdfb45664b..c1df3ac8c5 100644 --- a/csrc/kernels/mla/metadata/v1_comm.cuh +++ b/csrc/kernels/mla/metadata/v1_comm.cuh @@ -61,7 +61,7 @@ struct MlaMetadataV1KernelParameter int32_t qk_batch_ratio; int32_t num_splits; bool is_causal; - int32_t k_fixed_over_head_num_blocks; + int32_t fixed_over_head_num_blocks; }; struct PaMetadataV1KernelParameter : MlaMetadataV1KernelParameter