From c1dae3a7680c34a0c4cf40bf86b36a7e6ca2e636 Mon Sep 17 00:00:00 2001 From: irexyc Date: Tue, 16 Sep 2025 13:05:41 +0000 Subject: [PATCH 01/31] use driver flag --- src/turbomind/models/llama/LlamaBatch.cc | 47 ++++++++++++------------ src/turbomind/models/llama/LlamaBatch.h | 1 + 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 3c33ea133d..a26512956a 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -218,7 +218,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector& int idx = 0; for (const auto& r : reqs) { - if (tp_rank_ == 0) { + if (is_driver_) { TM_LOG_INFO("[ProcessInferRequests] Request for %llu received.", r->id); } @@ -246,7 +246,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector& s = ptr->tokens.size(); } else if (s > ptr->tokens.size()) { - if (tp_rank_ == 0) { + if (is_driver_) { TM_LOG_WARNING("[ProcessInferRequests] Skipping invalid step (%d) setting for ID %lu", s, ptr->id); } s = ptr->tokens.size(); @@ -379,7 +379,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector& // the actual sequence length is seq_limit_len + 1, hence seq_limit_len must truncated to session_len - 1 if (state.seq_len_limit[idx] >= session_len_) { state.seq_len_limit[idx] = session_len_ - 1; - if (tp_rank_ == 0) { + if (is_driver_) { const int trunc_output_len = state.seq_len_limit[idx] - state.h_context_length[idx]; TM_LOG_WARNING( "[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `max_new_tokens` is truncated to %d", @@ -870,6 +870,7 @@ LlamaBatch::LlamaBatch(DataType data_type, tp_rank_(model->tp_rank_), data_type_(data_type), debug_(isDebug()), + is_driver_(param.attn_tp_rank == 0), stream_(ctx->stream), context_(std::move(ctx)), model_(std::move(model)), @@ -998,7 +999,7 @@ void LlamaBatch::ComputeAndOutputLogits(const Tensor& hidden_states, int first, auto logits = model_->postDecodeEmbedding(hidden_states, symm_logits_buf_.buffer()); - if (tp_rank_ == 0) { + if (is_driver_) { OutputLogits(logits, first, last, GenerationConfig::kAll); } } @@ -1159,7 +1160,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) } // ! Only rank-0 writes to output - if (tp_rank_ == 0 && output_logprobs) { + if (is_driver_ && output_logprobs) { NvtxScope scope("logprobs"); float* sampled_logprobs_ptr = h_sampled_logprobs_.data(); uint32_t* sampled_indexes_ptr = h_sampled_indexes_.data(); @@ -1186,7 +1187,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) } // ! Only rank-0 writes to output - if (tp_rank_ == 0) { + if (is_driver_) { NvtxScope scope("output_ids"); for (int i = 0; i < batch_size - g.partial; ++i) { if (auto& r = state_->requests[i]) { @@ -1202,7 +1203,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) // Cache computed blocks to block trie sequence_manager_->CachePrompt(state_->sequences, batch_size); - if (debug_ && tp_rank_ == 0) { + if (debug_ && is_driver_) { for (int i = 0; i < batch_size; ++i) { // ss << (i ? 
", " : "") << "(" << state_->h_context_length[i] << "," << state_->h_finished[i] << ")"; std::vector tokens(state_->h_context_length[i]); @@ -1243,7 +1244,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) // Interrupt should reset r FT_CHECK(!r); } - else if (r->stream_output && tp_rank_ == 0) { + else if (r->stream_output && is_driver_) { const auto seq_len = *r->sequence_length.data(); // Create signals by copying the request handles for non-finished streaming requests signals.push_back([this, r, seq_len] { // @@ -1270,11 +1271,11 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) auto LlamaBatch::Interrupt(int index, bool force_stop) -> Signal { - if (tp_rank_ == 0) { + if (is_driver_) { TM_LOG_INFO("[Interrupt] slot %d, request %llu, stop %d", index, state_->requests[index]->id, force_stop); } - if (debug_ && tp_rank_ == 0) { + if (debug_ && is_driver_) { std::vector tokens(state_->h_context_length[index]); core::Copy(state_->output_ids.data() + index * session_len_, tokens.size(), tokens.data()); cudaStreamSynchronize(stream_); @@ -1350,7 +1351,7 @@ void LlamaBatch::InternalThreadEntry() std::shared_ptr req; - if (tp_rank_ == 0) { + if (is_driver_) { req = std::make_shared(); { NvtxScope _("pop"); @@ -1394,7 +1395,7 @@ void LlamaBatch::InternalThreadEntry() ProcessCancelRequests(req->cancel, signals); - if (tp_rank_ == 0) { + if (is_driver_) { gateway_->notify(std::move(signals)); } @@ -1418,7 +1419,7 @@ void LlamaBatch::InternalThreadEntry() comm_.h_tp_group->Sync(); } - if (tp_rank_ == 0) { + if (is_driver_) { gateway_->notify(std::move(signals)); } } @@ -1451,7 +1452,7 @@ bool LlamaBatch::Forward(GenerationState& g) const int active_size = state_->active_size; constexpr int kLogInterval = 10; - if (tp_rank_ == 0 && (g.step - 1) % kLogInterval == 0) { + if (is_driver_ && (g.step - 1) % kLogInterval == 0) { TM_LOG_INFO("------------------------- step = %d -------------------------", g.step - 1); } @@ -1531,7 +1532,7 @@ bool LlamaBatch::Forward(GenerationState& g) const int dc_batch_size = p ? 
0 : pf_offset; const int pf_batch_size = mini_batch_size - dc_batch_size; - if (tp_rank_ == 0) { + if (is_driver_) { if (pf_batch_size) { const auto max_q = *std::max_element(h_input_length_buf_.data() + first, h_input_length_buf_.data() + last); @@ -1647,7 +1648,7 @@ bool LlamaBatch::Forward(GenerationState& g) }); AnomalyHandler::instance().Reset(); - if (debug_ && tp_rank_ == 0) { + if (debug_ && is_driver_) { std::vector curr(active_size); core::Copy(token_ids_buf_.data() + g.step * active_size, active_size, curr.data()); cudaStreamSynchronize(stream_); @@ -1704,7 +1705,7 @@ void LlamaBatch::Warmup() if (auto str = std::getenv("TM_GEMM_IMPORT")) { std::ifstream ifs(str); const int n_imported = linear.Import(ifs); - if (tp_rank_ == 0) { + if (is_driver_) { TM_LOG_INFO("[Gemm2] %d records imported", n_imported); } return; @@ -1722,7 +1723,7 @@ void LlamaBatch::Warmup() bss.push_back(max_forward_token_num_); } - if (tp_rank_ == 0) { + if (is_driver_) { auto str = Join(bss.begin(), bss.end(), ", "); TM_LOG_INFO("[Gemm2] Tuning sequence: %s", str.c_str()); } @@ -1745,7 +1746,7 @@ void LlamaBatch::Warmup() /// NOTE: No explicit barrier can be used here as internal threads are waiting on it now for (auto token_num : bss) { - if (tp_rank_ == 0) { + if (is_driver_) { TM_LOG_INFO("[Gemm2] %d", token_num); } @@ -1774,7 +1775,7 @@ void LlamaBatch::Warmup() auto tock = std::chrono::steady_clock::now(); - if (tp_rank_ == 0) { + if (is_driver_) { TM_LOG_INFO("[Gemm2] Tuning finished in %.2f seconds.", std::chrono::duration>(tock - tick).count()); } @@ -1784,7 +1785,7 @@ void LlamaBatch::Warmup() check_cuda_error(cudaStreamSynchronize(stream_)); // Only rank-0 exports the dispatch cache - if (tp_rank_ == 0) { + if (is_driver_) { if (auto path = std::getenv("TM_GEMM_EXPORT")) { std::ofstream ofs(path); const auto n_records = context_->linear->Export(ofs); @@ -1830,7 +1831,7 @@ void LlamaBatch::InitializeBufferAndKVCache() const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len; if (max_session_len < session_len_) { - if (tp_rank_ == 0) { + if (is_driver_) { TM_LOG_WARNING("No enough blocks for `session_len` (%d), `session_len` truncated to %d.", session_len_, max_session_len); @@ -1915,7 +1916,7 @@ void LlamaBatch::DestroyCommunicators() void LlamaBatch::UpdateMetrics() { - if (tp_rank_ == 0 && param_.enable_metrics) { + if (is_driver_ && param_.enable_metrics) { // update schedule metrics int total_seqs, active_seqs, cached_seqs; std::tie(total_seqs, active_seqs, cached_seqs) = sequence_manager_->seq_stats(); diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index cf604a0a4f..2280ef739f 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -220,6 +220,7 @@ class LlamaBatch { const int tp_rank_; const DataType data_type_; const bool debug_; + const bool is_driver_; // Refs into `Context` cudaStream_t const stream_{}; From bb27b62074aa921f826dad91138f2875ddcaab50 Mon Sep 17 00:00:00 2001 From: irexyc Date: Fri, 19 Sep 2025 13:25:35 +0000 Subject: [PATCH 02/31] update --- lmdeploy/messages.py | 1 + lmdeploy/turbomind/turbomind.py | 11 +- src/turbomind/comm/device_comm.h | 12 ++ src/turbomind/comm/nccl/nccl.cu | 20 ++- .../kernels/attention/attention_params.h | 7 + .../kernels/attention/attention_template.h | 3 + .../kernels/attention/attention_universal.h | 49 +++++-- .../kernels/attention/decoding_template.h | 3 + .../kernels/attention/kv_cache_utils_v2.cu | 49 +++++-- 
.../kernels/attention/kv_cache_utils_v2.h | 8 ++ .../kernels/attention/mainloop_sm70.h | 11 +- .../kernels/attention/mainloop_sm80.h | 11 +- src/turbomind/kernels/attention/reduce.cu | 9 ++ src/turbomind/kernels/attention/reduce.h | 3 + .../kernels/attention/reduce_kernel.h | 21 ++- src/turbomind/models/llama/CMakeLists.txt | 4 +- src/turbomind/models/llama/LlamaBatch.cc | 35 +++-- src/turbomind/models/llama/LlamaBatch.h | 5 + src/turbomind/models/llama/LlamaV2.cc | 6 + src/turbomind/models/llama/LlamaV2.h | 3 + src/turbomind/models/llama/SequenceManager.cc | 5 +- src/turbomind/models/llama/SequenceManager.h | 2 + src/turbomind/models/llama/context.h | 2 + src/turbomind/models/llama/cp_utils.cu | 121 ++++++++++++++++++ src/turbomind/models/llama/cp_utils.h | 20 +++ src/turbomind/models/llama/llama_params.h | 2 + .../models/llama/unified_attention_layer.cc | 119 ++++++++++++++++- .../models/llama/unified_attention_layer.h | 12 ++ src/turbomind/models/llama/unified_decoder.cc | 3 +- src/turbomind/models/llama/unified_decoder.h | 2 + .../triton_backend/llama/LlamaTritonModel.cc | 53 ++++++-- 31 files changed, 564 insertions(+), 48 deletions(-) create mode 100644 src/turbomind/models/llama/cp_utils.cu create mode 100644 src/turbomind/models/llama/cp_utils.h diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 73c02e2914..7065106b92 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -237,6 +237,7 @@ class TurbomindEngineConfig: dp: int = 1 device_num: int = None attn_tp_size: int = None + attn_cp_size: int = None attn_dp_size: int = None mlp_tp_size: int = None mlp_dp_size: int = None diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index dac5325364..09ccb5076c 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -84,6 +84,7 @@ def complete_parallel_config(cfg: TurbomindEngineConfig): def update_parallel_config(cfg: TurbomindEngineConfig): if not complete_parallel_config(cfg): + attn_cp_size = cfg.attn_cp_size or 1 total = cfg.dp * cfg.tp if not cfg.device_num: count = torch.cuda.device_count() @@ -97,11 +98,12 @@ def update_parallel_config(cfg: TurbomindEngineConfig): inner_tp_size = cfg.tp // mlp_tp_size cfg.outer_dp_size = cfg.dp // attn_dp_size cfg.attn_dp_size = attn_dp_size - cfg.attn_tp_size = inner_tp_size + cfg.attn_tp_size = inner_tp_size // attn_cp_size + cfg.attn_cp_size = attn_cp_size cfg.mlp_dp_size = 1 cfg.mlp_tp_size = mlp_tp_size * inner_tp_size - assert cfg.attn_dp_size * cfg.attn_tp_size == cfg.mlp_dp_size * cfg.mlp_tp_size - assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.outer_dp_size == cfg.device_num + assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size == cfg.mlp_dp_size * cfg.mlp_tp_size + assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size * cfg.outer_dp_size == cfg.device_num cfg.devices = cfg.devices or list(range(cfg.device_num)) @@ -231,6 +233,8 @@ def _get_params(device_id, que): tm_params[k] = [v] else: tm_params[k].append(v) + # for k, v in tm_params.items(): + # print(k, len(v)) logger.warning(f'get {len(tm_params)} model params') def _postprocess_config(self, tm_config: TurbomindModelConfig, engine_config: TurbomindEngineConfig): @@ -269,6 +273,7 @@ def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig): self._postprocess_config(tm_model.tm_config, engine_config) + print(yaml.safe_dump(self.config_dict)) model_comm = _tm.AbstractTransformerModel.create_llama_model(model_dir='', config=yaml.safe_dump(self.config_dict), 
weight_type=self.config.model_config.weight_type) diff --git a/src/turbomind/comm/device_comm.h b/src/turbomind/comm/device_comm.h index a6948762df..8a5960c3af 100644 --- a/src/turbomind/comm/device_comm.h +++ b/src/turbomind/comm/device_comm.h @@ -54,6 +54,18 @@ class DeviceCommImpl { int group, cudaStream_t stream) = 0; + virtual void AllGatherCP(const void* send_M, + void* recv_M, + const void* send_L, + void* recv_L, + size_t sendcount, + DataType type, + int group, + cudaStream_t stream) + { + throw std::runtime_error("not implemented"); + } + virtual void ReduceScatter(const void* sendbuff, // void* recvbuff, size_t recvcount, diff --git a/src/turbomind/comm/nccl/nccl.cu b/src/turbomind/comm/nccl/nccl.cu index 44b6e8d55b..9fc7d694fa 100644 --- a/src/turbomind/comm/nccl/nccl.cu +++ b/src/turbomind/comm/nccl/nccl.cu @@ -222,10 +222,11 @@ public: int Split(int color, int key, int group) override { - auto split_fn = TM_CHECK_NOTNULL(nccl_apis().ncclCommSplit); + // auto split_fn = TM_CHECK_NOTNULL(nccl_apis().ncclCommSplit); ncclComm_t comm{}; - NCCLCHECK(split_fn(groups_.at(group), color, key, &comm, nullptr)); + // NCCLCHECK(split_fn(groups_.at(group), color, key, &comm, nullptr)); + NCCLCHECK(ncclCommSplit(groups_.at(group), color, key, &comm, nullptr)); int index = groups_.size(); groups_.push_back(comm); @@ -260,6 +261,21 @@ public: NCCLCHECK(ncclGroupEnd()); } + void AllGatherCP(const void* send_M, + void* recv_M, + const void* send_L, + void* recv_L, + size_t sendcount, + DataType type, + int group, + cudaStream_t stream) + { + NCCLCHECK(ncclGroupStart()); + NCCLCHECK(ncclAllGather(send_M, recv_M, sendcount, to_nccl_dtype(type), groups_.at(group), stream)); + NCCLCHECK(ncclAllGather(send_L, recv_L, sendcount, to_nccl_dtype(type), groups_.at(group), stream)); + NCCLCHECK(ncclGroupEnd()); + } + void ReduceScatter( const void* sendbuff, void* recvbuff, size_t recvcount, DataType type, int group, cudaStream_t stream) override { diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h index 59a04368fa..b2d43d5fc9 100644 --- a/src/turbomind/kernels/attention/attention_params.h +++ b/src/turbomind/kernels/attention/attention_params.h @@ -79,6 +79,13 @@ struct AttentionParams { float* partial_L; int* locks; + // context parallel + int cp_rank{0}; + int cp_size{1}; + float* cp_M{nullptr}; + float* cp_L{nullptr}; + float* cp_O{nullptr}; + int arch; cudaStream_t stream; diff --git a/src/turbomind/kernels/attention/attention_template.h b/src/turbomind/kernels/attention/attention_template.h index 02dd8d20af..7bff780396 100644 --- a/src/turbomind/kernels/attention/attention_template.h +++ b/src/turbomind/kernels/attention/attention_template.h @@ -85,6 +85,9 @@ void invokeAttention(const typename Kernel::ParamType& params) params.partial_M, params.partial_L, params.partial_O, + params.cp_M, + params.cp_L, + params.cp_O, params.split_cnt, params.max_split_k, split_cnt, diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index 5a1a9e7605..53cce315ac 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -276,7 +276,10 @@ struct AttentionUniversal { } iterator.block_head_.with( - iterator.block_ptrs_, ti, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { + iterator.block_ptrs_, ti / params.cp_size, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { + if (ti % params.cp_size != 
params.cp_rank) { + return; + } PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { const int di = offset.x + c * Map::kDeltaC; @@ -371,11 +374,18 @@ struct AttentionUniversal { const int context_len = params.cu_k_len[batch_idx + 1] - params.cu_k_len[batch_idx]; const int history_len = context_len - input_len; + auto get_cp_len = [&](int length) -> int { + return (length / params.cp_size + (length % params.cp_size > params.cp_rank ? 1 : 0)); + }; + const int last_K = history_len + min(query_idx + CTA_Q, input_len); - const int last_K_tile = (last_K - 1) / CTA_S + 1; // past-the-end index to past-the-end tile index conversion + const int last_K_tile = (get_cp_len(last_K) - 1) / CTA_S + 1; + // const int last_K_tile = (last_K - 1) / CTA_S + 1; // past-the-end index to past-the-end tile index + // conversion const int first_K = max(history_len + query_idx - (params.window_size - 1), 0); - const int first_K_tile = first_K / CTA_S; + const int first_K_tile = get_cp_len(first_K) / CTA_S; + // const int first_K_tile = first_K / CTA_S; const int tile_count = last_K_tile - first_K_tile; @@ -417,7 +427,8 @@ struct AttentionUniversal { const int offset_K = (first_K_tile + iter_end - 1) * CTA_S; // This is for avoiding OOB access only - const int max_K = min(context_len, (first_K_tile + iter_end) * CTA_S); + // const int max_K = min(context_len, (first_K_tile + iter_end) * CTA_S); + const int max_K = min(get_cp_len(context_len), (first_K_tile + iter_end) * CTA_S); int tile_iter = iter_end - iter_begin; @@ -430,6 +441,9 @@ struct AttentionUniversal { // -> x * CTA_S >= offset_Q + CTA_Q - offset_K + tile_iter * CTA_S - w int mask_iter_front = cdiv(max(0, offset_Q + CTA_Q - offset_K + tile_iter * CTA_S - params.window_size), CTA_S); + // TODO: mask all iter for simplicity, use accurate mask_iter + mask_iter_back = 999999; + mask_iter_front = 999999; #if 0 if (threadIdx.x == 0) { printf( @@ -453,6 +467,7 @@ struct AttentionUniversal { cache_iter.SetTile(first_K_tile + iter_end - 1); Mainloop mainloop; + mainloop.SetCpInfo(params.cp_size, params.cp_rank); mainloop(frag_Q, cache_iter, frag_O, @@ -491,12 +506,12 @@ struct AttentionUniversal { } } - if (iter_begin == 0 && iter_end == tile_count) { + if (iter_begin == 0 && iter_end == tile_count && params.cp_size == 1) { StoreO(frag_O, frag_L, qi_begin, qi_end, head_idx, params, storage); } else { StorePartial(frag_O, frag_M, frag_L, qi_begin, qi_end, head_idx, split_idx, params, storage); - if (!separate_reduce) { + if (!separate_reduce && split_cnt > 1) { Reduce(qi_begin, head_idx, split_idx, iter_end == tile_count, params, cta_map, smem_buf); } } @@ -527,6 +542,9 @@ struct AttentionUniversal { params.partial_M, params.partial_L, params.partial_O, + params.cp_M, + params.cp_L, + params.cp_O, qi_begin, head_idx, params.num_heads, @@ -598,15 +616,28 @@ struct AttentionUniversal { Impl::StoreO(frag_O, frag_L, storage, [&](int hi, int qi, int di, const auto& vec) { if (qi_begin + qi < qi_end && check_h(hi)) { - Store(¶ms.partial_O[get_index(hi, qi) * kHeadDim + di], vec); + if (params.max_split_k > 1) { // decode + Store(¶ms.partial_O[get_index(hi, qi) * kHeadDim + di], vec); + } + if (params.cp_size > 1 && split_idx == 0) { + const int index = ((qi_begin + qi) * params.num_heads + (head_idx + hi)) * kHeadDim + di; + Store(¶ms.cp_O[index], vec); + } } }); Impl::ForeachML(frag_M, frag_L, [&](int hi, int qi, int ri, float M, float L) { const int index = get_index(hi, qi); if (qi_begin + qi < qi_end && ri == 0 && check_h(hi)) { - params.partial_M[index] = M; - 
params.partial_L[index] = L; + if (params.max_split_k > 1) { // decode + params.partial_M[index] = M; + params.partial_L[index] = L; + } + if (params.cp_size > 1 && split_idx == 0) { + const int index = (qi_begin + qi) * params.num_heads + (head_idx + hi); + params.cp_M[index] = M; + params.cp_L[index] = L; + } } }); } diff --git a/src/turbomind/kernels/attention/decoding_template.h b/src/turbomind/kernels/attention/decoding_template.h index 37f6baebe3..5be8c12c08 100644 --- a/src/turbomind/kernels/attention/decoding_template.h +++ b/src/turbomind/kernels/attention/decoding_template.h @@ -84,6 +84,9 @@ bool invokeDecoding(const typename Kernel::ParamType& params) params.partial_M, params.partial_L, params.partial_O, + params.cp_M, + params.cp_L, + params.cp_O, params.split_cnt, params.max_split_k, split_cnt, diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu index adb697e8c4..be9f5b7430 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu @@ -28,6 +28,8 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, int64_t stride_h, int64_t stride_s, int layer_id, + int cp_size, + int cp_rank, BlockLayout block_layout) { @@ -159,9 +161,9 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { const int qi = offset.y + s * Map::kDeltaS + token_idx; // local offset into `input_length` - if (qi < q_len) { - const int ti = history_len + qi; // timestep - block_head.with((char**)blocks, ti, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { + const int ti = history_len + qi; // timestep + if (qi < q_len && (ti % cp_size == cp_rank)) { + block_head.with((char**)blocks, ti / cp_size, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { int di = offset.x + c * Map::kDeltaC; @@ -198,6 +200,8 @@ void invokeProcessKV_v2(char** blocks, int64_t stride_s, int block_seq_len, int layer_id, + int cp_size, + int cp_rank, int max_q_len, int head_num, int head_dim, @@ -233,6 +237,8 @@ void invokeProcessKV_v2(char** blocks, stride_h, stride_s, layer_id, + cp_size, + cp_rank, block_layout); }; @@ -276,6 +282,8 @@ void invokeProcessKV_v2(char** blocks, int64_t stride_s, \ int block_seq_len, \ int layer_id, \ + int cp_size, \ + int cp_rank, \ int max_q_len, \ int head_num, \ int head_dim, \ @@ -300,6 +308,8 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, int64_t stride_h, int64_t stride_s, int layer_id, + int cp_size, + int cp_rank, BlockLayout block_layout) { constexpr int kVecSize = sizeof(uint4) / sizeof(T); @@ -344,8 +354,8 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { const int si = offset.y + s * Map::kDeltaS + token_idx; - if (si < seq_len) { - block_head.with((char**)blocks, si, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { + if (si < seq_len && (si % cp_size == cp_rank)) { + block_head.with((char**)blocks, si / cp_size, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { int di = offset.x + c * Map::kDeltaC; @@ -389,14 +399,27 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, for (int s = 0; s < ITER_S; ++s) { PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { - const int si = offset.y + s * Map::kDeltaS + token_idx; - const int di = offset.x + c * Map::kDeltaC; - const int64_t index = 
- (batch_idx * stride_b + ti_beg * stride_c + si * stride_s + head_idx * stride_h) * HeadDim + di; - if (si < seq_len) { + const int si = offset.y + s * Map::kDeltaS + token_idx; + const int di = offset.x + c * Map::kDeltaC; + // save first + if (si < seq_len && si % cp_size == cp_rank) { + const int64_t index = + (batch_idx * stride_b + ti_beg * stride_c + si / cp_size * stride_s + head_idx * stride_h) * HeadDim + + di; Store(&k[index], out_K[s][c]); Store(&v[index], out_V[s][c]); } + + // const int64_t index = + // (batch_idx * stride_b + ti_beg * stride_c + si * stride_s + head_idx * stride_h) * HeadDim + di; + // if (si < seq_len) { + // if (si % cp_size != cp_rank) { + // clear(out_K[s][c]); + // clear(out_V[s][c]); + // } + // Store(&k[index], out_K[s][c]); + // Store(&v[index], out_V[s][c]); + // } } } } @@ -414,6 +437,8 @@ void invokeFlattenKV_v2(T* k, int64_t stride_s, int block_seq_len, int layer_id, + int cp_size, + int cp_rank, int max_seq_len, int head_num, int head_dim, @@ -446,6 +471,8 @@ void invokeFlattenKV_v2(T* k, stride_h, stride_s, layer_id, + cp_size, + cp_rank, block_layout); }; @@ -486,6 +513,8 @@ void invokeFlattenKV_v2(T* k, int64_t stride_s, \ int block_seq_len, \ int layer_id, \ + int cp_size, \ + int cp_rank, \ int max_seq_len, \ int head_num, \ int head_dim, \ diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.h b/src/turbomind/kernels/attention/kv_cache_utils_v2.h index 01525f5596..5419979c29 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.h +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.h @@ -23,6 +23,8 @@ void invokeProcessKV_v2(char** blocks, int64_t stride_s, int block_seq_len, int layer_id, + int cp_size, + int cp_rank, int max_q_len, int head_num, int head_dim, @@ -48,6 +50,8 @@ void invokeProcessKV_v2_(const AttentionParams& params) params.stride / params.size_per_head, // stride s params.block_iter_params.block_len, params.block_iter_params.layer_id, + params.cp_size, + params.cp_rank, params.max_q_len, params.num_kv_heads, params.size_per_head, @@ -69,6 +73,8 @@ void invokeFlattenKV_v2(T* k, int64_t stride_s, int block_seq_len, int layer_id, + int cp_size, + int cp_rank, int max_seq_len, int head_num, int head_dim, @@ -93,6 +99,8 @@ void invokeFlattenKV_v2_(const AttentionParams& params, int sum_k_len) 1, params.block_iter_params.block_len, params.block_iter_params.layer_id, + params.cp_size, + params.cp_rank, params.max_k_len, params.num_kv_heads, params.size_per_head, diff --git a/src/turbomind/kernels/attention/mainloop_sm70.h b/src/turbomind/kernels/attention/mainloop_sm70.h index c4d2e5afeb..a030d372de 100644 --- a/src/turbomind/kernels/attention/mainloop_sm70.h +++ b/src/turbomind/kernels/attention/mainloop_sm70.h @@ -40,6 +40,15 @@ struct Mainloop { static constexpr int CTA_S = Impl::CTA_S; + int cp_size_{1}; + int cp_rank_{0}; + + __device__ void SetCpInfo(int cp_size, int cp_rank) + { + cp_size_ = cp_size; + cp_rank_ = cp_rank; + } + template __device__ void operator()(FragQ& frag_Q, CacheIter& cache_iter, @@ -128,7 +137,7 @@ struct Mainloop { __device__ void ApplyCasualMask(FragS& frag_S, int offset_Q, int offset_K, int window_size) { Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float& score) { - int w = (offset_Q + qi) - (offset_K + si); + int w = (offset_Q + qi) - ((offset_K + si) * cp_size_ + cp_rank_); if (0 <= w && w < window_size) {} else { score -= std::numeric_limits::infinity(); diff --git a/src/turbomind/kernels/attention/mainloop_sm80.h 
b/src/turbomind/kernels/attention/mainloop_sm80.h index 997a6aa9fc..3b07b717e4 100644 --- a/src/turbomind/kernels/attention/mainloop_sm80.h +++ b/src/turbomind/kernels/attention/mainloop_sm80.h @@ -49,6 +49,15 @@ struct Mainloop, Impl_> { using SharedStorage = typename Impl::SharedStorage; + int cp_size_{1}; + int cp_rank_{0}; + + __device__ void SetCpInfo(int cp_size, int cp_rank) + { + cp_size_ = cp_size; + cp_rank_ = cp_rank; + } + template __device__ void operator()(Args&&... args) { @@ -442,7 +451,7 @@ struct Mainloop, Impl_> { __device__ void ApplyCasualMask(FragS& frag_S, int offset_Q, int offset_K, int window_size) { Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float& score) { - int w = (offset_Q + qi) - (offset_K + si); + int w = (offset_Q + qi) - ((offset_K + si) * cp_size_ + cp_rank_); if (0 <= w && w < window_size) {} else { score -= std::numeric_limits::infinity(); diff --git a/src/turbomind/kernels/attention/reduce.cu b/src/turbomind/kernels/attention/reduce.cu index c654f40d05..53bbda59ce 100644 --- a/src/turbomind/kernels/attention/reduce.cu +++ b/src/turbomind/kernels/attention/reduce.cu @@ -12,6 +12,9 @@ void invokeReduce(T* out, float* partial_M, float* partial_L, float* partial_O, + float* cp_M, + float* cp_L, + float* cp_O, const int* split_cnt, int partial_len, int max_split_cnt, @@ -34,6 +37,9 @@ void invokeReduce(T* out, partial_M, partial_L, partial_O, + cp_M, + cp_L, + cp_O, nullptr, split_cnt, partial_len, @@ -58,6 +64,9 @@ void invokeReduce(T* out, float* partial_M, \ float* partial_L, \ float* partial_O, \ + float* cp_M, \ + float* cp_L, \ + float* cp_O, \ const int* split_cnt, \ int partial_len, \ int max_split_cnt, \ diff --git a/src/turbomind/kernels/attention/reduce.h b/src/turbomind/kernels/attention/reduce.h index c078de5958..8fe17f5fa7 100644 --- a/src/turbomind/kernels/attention/reduce.h +++ b/src/turbomind/kernels/attention/reduce.h @@ -16,6 +16,9 @@ void invokeReduce(T* out, float* partial_M, float* partial_L, float* partial_O, + float* cp_M, + float* cp_L, + float* cp_O, const int* split_cnt, int partial_len, int max_split_cnt, diff --git a/src/turbomind/kernels/attention/reduce_kernel.h b/src/turbomind/kernels/attention/reduce_kernel.h index b4c9064cfe..c246631e78 100644 --- a/src/turbomind/kernels/attention/reduce_kernel.h +++ b/src/turbomind/kernels/attention/reduce_kernel.h @@ -27,6 +27,9 @@ struct Reduce { float* partial_M, float* partial_L, float* partial_O, + float* cp_M, + float* cp_L, + float* cp_O, int query_idx, int head_idx, int head_num, @@ -102,7 +105,7 @@ struct Reduce { Array scale; PRAGMA_UNROLL for (int k = 0; k < K; ++k) { - scale[k] = IsFinal ? expdiff_M[k] / block_L : expdiff_M[k]; + scale[k] = (IsFinal && cp_O == nullptr) ? 
expdiff_M[k] / block_L : expdiff_M[k]; } if (hi < CTA_H) { @@ -124,6 +127,13 @@ struct Reduce { } } } + else { + if (cp_M != nullptr && cp_L != nullptr && lane_id % L == 0 && hi < hi_end) { + const int idx = query_idx * head_num + head_idx + hi; + cp_M[idx] = block_M; + cp_L[idx] = block_L; + } + } } __syncthreads(); @@ -195,6 +205,9 @@ struct Reduce { if (ki == 0 && hi < hi_end) { if constexpr (IsFinal) { const int offset = (query_idx * head_num + head_idx + hi) * HeadDim + di; + if (cp_O != nullptr) { + Store(&cp_O[offset], (Vec&)storage.O[hi][ki][di]); + } Store(&out[offset], cast((Vec&)storage.O[hi][ki][di])); } else { @@ -212,6 +225,9 @@ __global__ void reduce_kernel(typename Reduce::T* out, float* partial_M, float* partial_L, float* partial_O, + float* cp_M, + float* cp_L, + float* cp_O, int* signals, const int* split_cnt_, int max_split_cnt, @@ -238,6 +254,9 @@ __global__ void reduce_kernel(typename Reduce::T* out, partial_M, partial_L, partial_O, + cp_M, + cp_L, + cp_O, query_idx, head_idx, head_num, diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt index 1b767d1a13..b58d4e3a85 100644 --- a/src/turbomind/models/llama/CMakeLists.txt +++ b/src/turbomind/models/llama/CMakeLists.txt @@ -21,7 +21,9 @@ add_library(Llama STATIC unified_attention_layer.cc llama_kernels.cu llama_utils.cu - mla_utils.cu) + mla_utils.cu + cp_utils.cu +) set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(Llama PUBLIC CUDA::cudart diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index a26512956a..8d72fd955e 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -827,12 +827,24 @@ void LlamaBatch::AllocSymmBuffers() symm_hidden_states_buf_ = {{max_forward_token_num_ * param_.attn_dp_size, hidden_units}, data_type_, symm_alloc_}; symm_logits_buf_ = {{max_batch_size_, vocab_size_padded}, data_type_, symm_alloc_}; + + if (param_.attn_cp_size > 1) { + symm_cp_M_ = {{param_.attn_cp_size, max_forward_token_num_, (int)model_->local_head_num_}, symm_alloc_}; + symm_cp_L_ = {{param_.attn_cp_size, max_forward_token_num_, (int)model_->local_head_num_}, symm_alloc_}; + symm_cp_O_ = { + {param_.attn_cp_size, max_forward_token_num_, (int)model_->local_head_num_, (int)model_->size_per_head_}, + symm_alloc_}; + } } void LlamaBatch::FreeSymmBuffers() { symm_hidden_states_buf_ = {}; symm_logits_buf_ = {}; + + symm_cp_M_ = {}; + symm_cp_L_ = {}; + symm_cp_O_ = {}; } LlamaBatch::~LlamaBatch() @@ -870,7 +882,7 @@ LlamaBatch::LlamaBatch(DataType data_type, tp_rank_(model->tp_rank_), data_type_(data_type), debug_(isDebug()), - is_driver_(param.attn_tp_rank == 0), + is_driver_(param.attn_tp_rank == 0 && param.attn_cp_rank == 0), stream_(ctx->stream), context_(std::move(ctx)), model_(std::move(model)), @@ -988,12 +1000,12 @@ void LlamaBatch::ComputeAndOutputLogits(const Tensor& hidden_states, int first, if (symm_logits_buf_.shape(0) < token_num) { if (tp_size_ > 1) { check_cuda_error(cudaStreamSynchronize(stream_)); - comm_.h_tp_group->Sync(); + comm_.h_tp_cp_group->Sync(); } symm_logits_buf_ = {{token_num, vocab_size_padded}, data_type_, symm_alloc_}; if (tp_size_ > 1) { check_cuda_error(cudaStreamSynchronize(stream_)); - comm_.h_tp_group->Sync(); + comm_.h_tp_cp_group->Sync(); } } @@ -1230,7 +1242,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) } if (need_sync) { // 
Release updates on request output buffers to all ranks (`Interrupt` will use it) - comm_.h_tp_group->Sync(); + comm_.h_tp_cp_group->Sync(); } } @@ -1368,14 +1380,14 @@ void LlamaBatch::InternalThreadEntry() if (state_->size == g.finished_count) { // Batch is empty, use blocking sync to avoid spinning - comm_.h_tp_group->Sync(true); + comm_.h_tp_cp_group->Sync(true); } NvtxScope scope("mainloop"); // 1. Wait while rank-0 is dequeueing // 2. Broadcast `ec` from rank-0 - Broadcast(comm_.h_tp_group, req, 0); + Broadcast(comm_.h_tp_cp_group, req, 0); if (req->abort) { TM_LOG_INFO("[InternalThreadEntry] stop requested."); @@ -1416,7 +1428,7 @@ void LlamaBatch::InternalThreadEntry() // Finished requests and corresponding output tensors will be released when notified // wait for all ranks to ensure no rank (except for output thread) will access related // resources - comm_.h_tp_group->Sync(); + comm_.h_tp_cp_group->Sync(); } if (is_driver_) { @@ -1573,6 +1585,9 @@ bool LlamaBatch::Forward(GenerationState& g) state_->h_context_length.slice(first, mini_batch_size), rope_theta_.slice(first, mini_batch_size), &mrope, + symm_cp_M_, + symm_cp_L_, + symm_cp_O_, finished_buf_.slice(first, mini_batch_size), Buffer(local_token_nums.data(), local_token_nums.size(), kCPU), lora_mask_buf_, @@ -1765,6 +1780,9 @@ void LlamaBatch::Warmup() Buffer{&input_length, 1, kCPU}, rope_theta_.slice(0, bsz), nullptr, // mrope + symm_cp_M_, + symm_cp_L_, + symm_cp_O_, finished_buf_.slice(0, bsz), Buffer{local_token_nums.data(), (int)local_token_nums.size(), kCPU}, Buffer{}, @@ -1817,7 +1835,7 @@ void LlamaBatch::InitializeBufferAndKVCache() const auto get_free_size = [&] { // size_t free{}, total{}; check_cuda_error(cudaMemGetInfo(&free, &total)); - return AllReduce(model_->comm_->h_tp_group, free, comm::RedOp::kMin); + return AllReduce(model_->comm_->h_tp_cp_group, free, comm::RedOp::kMin); }; sequence_manager_.reset(new SequenceManager{model_->layer_num_, @@ -1826,6 +1844,7 @@ void LlamaBatch::InitializeBufferAndKVCache() param_.cache_chunk_size, param_.enable_prefix_caching, tp_rank_, + param_.attn_cp_size, core::Context::alloc(kDEVICE), get_free_size}); diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 2280ef739f..379155649a 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -245,6 +245,11 @@ class LlamaBatch { Tensor symm_hidden_states_buf_; Tensor symm_logits_buf_; + // context parallel + Tensor_ symm_cp_O_; + Tensor_ symm_cp_M_; + Tensor_ symm_cp_L_; + Tensor decoder_output_buf_; Tensor_ sampling_logits_; diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index a5c364a0a0..7947157b45 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -163,6 +163,9 @@ void LlamaV2::Forward(Buffer_ input_ids, Buffer_ h_context_length, Buffer rope_base, MropeRope* mrope, + Tensor cp_M, + Tensor cp_L, + Tensor cp_O, Buffer finished, Buffer local_token_nums, Buffer lora_mask, @@ -258,6 +261,9 @@ void LlamaV2::Forward(Buffer_ input_ids, {"decode_num", Buffer{&decode_num, 1, kCPU}}, {"prefil_num", Buffer{&prefil_num, 1, kCPU}}, {"rope_base", rope_base}, + {"cp_M", cp_M}, + {"cp_L", cp_L}, + {"cp_O", cp_O}, {"cu_block_nums", cu_block_nums}, {"kv_block_ptrs", kv_block_ptrs}, {"local_token_nums", local_token_nums}}; diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index 304fb97fd3..b51b2e56d9 100644 --- 
a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -69,6 +69,9 @@ class LlamaV2 { Buffer_ h_context_length, Buffer rope_base, MropeRope* mrope, + Tensor cp_M, + Tensor cp_L, + Tensor cp_O, Buffer finished, Buffer local_token_nums, Buffer lora_mask, diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc index 963fccb38f..13efe09478 100644 --- a/src/turbomind/models/llama/SequenceManager.cc +++ b/src/turbomind/models/llama/SequenceManager.cc @@ -34,9 +34,10 @@ SequenceManager::SequenceManager(size_t layer_num, int chunk_size, bool enable_prefix_caching, int rank, + int attn_cp_size, core::Allocator allocator, GetFreeMemSize get_free_size): - block_seq_len_(block_config.block_len_), rank_(rank) + block_seq_len_(block_config.block_len_), rank_(rank), attn_cp_size_(attn_cp_size) { block::Layout layout{block_config}; // dump(layout); @@ -385,7 +386,7 @@ std::vector SequenceManager::CountRequiredBlocks(const Sequences& se { std::vector required(sequences.size()); for (int i = 0; i < sequences.size(); ++i) { - int seq_len = context_lengths[i] + step_length; + int seq_len = (context_lengths[i] + step_length + attn_cp_size_ - 1) / attn_cp_size_; int count = (seq_len + block_seq_len_ - 1) / block_seq_len_ - static_cast(sequences[i]->blocks.size()); required[i] = std::max(0, count); } diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h index 5cbdc4a426..a1c4f1615a 100644 --- a/src/turbomind/models/llama/SequenceManager.h +++ b/src/turbomind/models/llama/SequenceManager.h @@ -81,6 +81,7 @@ class SequenceManager { int chunk_size, bool enable_prefix_caching, int rank, + int attn_cp_size, core::Allocator allocator, GetFreeMemSize get_free_size); @@ -186,6 +187,7 @@ class SequenceManager { private: int block_seq_len_; int rank_; + int attn_cp_size_; // Use `std::map` to avoid reference invalidation std::map sequences_; diff --git a/src/turbomind/models/llama/context.h b/src/turbomind/models/llama/context.h index 33b7be29ac..d5e7891077 100644 --- a/src/turbomind/models/llama/context.h +++ b/src/turbomind/models/llama/context.h @@ -17,11 +17,13 @@ namespace turbomind { struct Communicators { comm::HostComm h_comm; + comm::HostComm h_tp_cp_group; comm::HostComm h_tp_group; comm::HostComm h_dp_group; comm::DeviceComm d_comm; int d_tp_group; + int d_cp_group; }; // Execution context for the model diff --git a/src/turbomind/models/llama/cp_utils.cu b/src/turbomind/models/llama/cp_utils.cu new file mode 100644 index 0000000000..8072b6cfe8 --- /dev/null +++ b/src/turbomind/models/llama/cp_utils.cu @@ -0,0 +1,121 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
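Note (illustrative sketch, not part of the patch): the kv-cache changes above shard cached tokens round-robin over the attention context-parallel (cp) ranks: timestep ti is written by rank ti % cp_size into local slot ti / cp_size (see the ProcessKV_v2 and flattenKV_v2 hunks), and SequenceManager sizes block counts from the per-rank length ceil((context_len + step) / cp_size). A minimal sketch of the assumed mapping, with hypothetical helper names:

// Owner rank and local slot of global timestep `ti` under round-robin sharding.
struct CpSlot {
    int rank;  // context-parallel rank that stores the token
    int slot;  // position inside that rank's local KV cache
};

inline CpSlot cp_slot(int ti, int cp_size)
{
    return {ti % cp_size, ti / cp_size};
}

// Number of the first `len` timesteps held by `cp_rank`; matches `get_cp_len`
// introduced in attention_universal.h (len / cp_size, plus one more if the
// remainder has not yet wrapped past this rank).
inline int cp_local_len(int len, int cp_size, int cp_rank)
{
    return len / cp_size + (len % cp_size > cp_rank ? 1 : 0);
}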
+ +#include "src/turbomind/models/llama/cp_utils.h" + +namespace turbomind { + +template +__global__ void CpReduce(T* out, + float* O, + float* M, + float* L, + int token_num, + int head_num, + int size_per_head, + int cp_size, + int cp_rank, + float exp_scale) +{ + __shared__ float scale[WARP_SIZE]; + float frag_M = -std::numeric_limits::infinity(); + float frag_L = 0.0f; + + const int token_idx = blockIdx.x; + const int head_idx = blockIdx.y; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (warp_id == 0 && lane_id < cp_size) { + const int index = lane_id * token_num * head_num + token_idx * head_num + head_idx; + frag_M = M[index]; + frag_L = L[index]; + } + + float block_M = frag_M; + PRAGMA_UNROLL + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask)); + } + + float expdiff_M = exp2f((frag_M - block_M) * exp_scale); + + float block_L = frag_L * expdiff_M; + PRAGMA_UNROLL + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + block_L += __shfl_xor_sync(uint32_t(-1), block_L, mask); + } + + if (warp_id == 0 && lane_id < cp_size) { + scale[lane_id] = expdiff_M / block_L; + } + + __syncthreads(); + + // for (int i = threadIdx.x; i < size_per_head; i += blockDim.x) { + // float flag_O = 0; + // for (int j = 0; j < cp_size; ++j) { + // int index = j * token_num * head_num * size_per_head + token_idx * head_num * size_per_head + // + head_idx * size_per_head + i; + // flag_O += O[index] * scale[j]; + // } + // int out_index = token_idx * head_num * size_per_head + head_idx * size_per_head + i; // q, h, d + // // out[out_index] = (T)flag_O; + // out[out_index] = (T)(flag_O / cp_size); + // } + + for (int i = threadIdx.x; i < size_per_head; i += blockDim.x) { + int src_index = cp_rank * token_num * head_num * size_per_head + token_idx * head_num * size_per_head + + head_idx * size_per_head + i; + int dst_index = token_idx * head_num * size_per_head + head_idx * size_per_head + i; // q, h, d + out[dst_index] = (T)(O[src_index] * scale[cp_rank]); + } +} + +template +void invokeCpReduce(T* out, + float* O, + float* M, + float* L, + int token_num, + int head_num, + int size_per_head, + int cp_size, + int cp_rank, + float exp_scale, + cudaStream_t stream) +{ + TM_CHECK(cp_size <= WARP_SIZE); + const dim3 block = 4 * WARP_SIZE; + const dim3 grid(token_num, head_num); + size_t smem_size = sizeof(float) * WARP_SIZE * 2; + CpReduce<<>>( + out, O, M, L, token_num, head_num, size_per_head, cp_size, cp_rank, exp_scale); + sync_check_cuda_error(); +} + +template void invokeCpReduce(half* out, + float* O, + float* M, + float* L, + int token_num, + int head_num, + int size_per_head, + int cp_size, + int cp_rank, + float exp_scale, + cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeCpReduce(__nv_bfloat16* out, + float* O, + float* M, + float* L, + int token_num, + int head_num, + int size_per_head, + int cp_size, + int cp_rank, + float exp_scale, + cudaStream_t stream); +#endif + +} // namespace turbomind diff --git a/src/turbomind/models/llama/cp_utils.h b/src/turbomind/models/llama/cp_utils.h new file mode 100644 index 0000000000..4ae640a51a --- /dev/null +++ b/src/turbomind/models/llama/cp_utils.h @@ -0,0 +1,20 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
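Note (illustrative sketch, not part of the patch): cp_utils.cu above merges the per-rank partial attention results. Each cp rank contributes an un-normalized output accumulator O_r together with its local softmax row max M_r and row sum L_r (the single-rank Reduce path skips the division by block_L when the cp_* buffers are present, deferring normalization to this step). The merged row is the usual log-sum-exp combination, (sum_r O_r * exp(M_r - M_max)) / (sum_r L_r * exp(M_r - M_max)) with M_max = max_r M_r; CpReduce evaluates it in the log2 domain via exp2f and exp_scale, and the committed kernel writes only the local rank's scaled contribution (the full cross-rank sum appears in the commented-out variant), apparently leaving the summation to a later collective. A host-side sketch of the merge for one (token, head) row:

#include <algorithm>
#include <cmath>
#include <vector>

// O[r]: rank r's un-normalized output of length `dim`; M[r]/L[r]: its local
// softmax max and sum. Natural exp is used for clarity instead of exp2f.
std::vector<float> merge_cp_partials(const std::vector<std::vector<float>>& O,
                                     const std::vector<float>&              M,
                                     const std::vector<float>&              L,
                                     int                                    dim)
{
    const int   cp_size  = static_cast<int>(M.size());
    const float global_M = *std::max_element(M.begin(), M.end());

    float global_L = 0.f;
    for (int r = 0; r < cp_size; ++r) {
        global_L += L[r] * std::exp(M[r] - global_M);
    }

    std::vector<float> out(dim, 0.f);
    for (int r = 0; r < cp_size; ++r) {
        const float scale = std::exp(M[r] - global_M) / global_L;
        for (int d = 0; d < dim; ++d) {
            out[d] += O[r][d] * scale;
        }
    }
    return out;
}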
+ +#include "src/turbomind/core/core.h" + +namespace turbomind { + +template +void invokeCpReduce(T* out, + float* O, + float* M, + float* L, + int token_num, + int head_num, + int size_per_head, + int cp_size, + int cp_rank, + float exp_scale, + cudaStream_t stream); + +} // namespace turbomind diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index 9dad3fd2fc..b406ffada0 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -106,6 +106,8 @@ struct EngineParam { int attn_dp_rank; int attn_tp_size; int attn_tp_rank; + int attn_cp_size; + int attn_cp_rank; int mlp_tp_size; int mlp_tp_rank; diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 5808541001..7b51605ec9 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -23,6 +23,9 @@ #include #include +#include "src/turbomind/models/llama/cp_utils.h" +#include + #include "src/turbomind/core/check.h" #include "src/turbomind/core/data_type.h" #include "src/turbomind/core/tensor.h" @@ -72,6 +75,9 @@ UnifiedAttentionLayer::UnifiedAttentionLayer(const ModelParam& model, local_kv_head_num_(model.kv_head_num / tp_size), param_(attn), model_param_(model), + engine_param_(engine), + attn_cp_group_(ctx.comm.d_cp_group), + d_comm_(ctx.comm.d_comm), lora_param_(lora), context_(ctx), stream_(ctx.stream), @@ -136,6 +142,10 @@ void UnifiedAttentionLayer::Initialize(TensorMap& args) cu_block_nums_ = args.at("cu_block_nums").buffer(); kv_block_ptrs_ = args.at("kv_block_ptrs").buffer(); + cp_M_ = args.at("cp_M").borrow(); + cp_L_ = args.at("cp_L").borrow(); + cp_O_ = args.at("cp_O").borrow(); + // rotary embedding, add offest when forward if (rope_param_.type == RopeType::kDynamic) { rope_param_.base = const_cast(rope_base_.data()); @@ -226,6 +236,53 @@ void UnifiedAttentionLayer::Forward(ForwardParam p) sync_check_cuda_error(); } +template +void UnifiedAttentionLayer::cp_postprocess(Tensor& attn) +{ + + const int token_num = attn.shape(0); + const int count = token_num * local_head_num_; + d_comm_->AllGatherCP(cp_M_.data() + count * engine_param_.attn_cp_rank, + cp_M_.data(), + cp_L_.data() + count * engine_param_.attn_cp_rank, + cp_L_.data(), + count, + kFloat32, + attn_cp_group_, + stream_); + sync_check_cuda_error(); + + // auto allgather = [&](float* src, float* dst, int count) { + // d_comm_->AllGather(src + count * engine_param_.attn_cp_rank, dst, count, kFloat32, attn_cp_group_, stream_); + // sync_check_cuda_error(); + // }; + // allgather(cp_O_.data(), cp_O_.data(), + // token_num * local_head_num_ * size_per_head_); // (cp, q, h, d) + // allgather(cp_M_.data(), cp_M_.data(), token_num * local_head_num_); // (cp, q, h) + // allgather(cp_L_.data(), cp_L_.data(), token_num * local_head_num_); // (cp, q, h) + + float inv_sqrt_dh = (float)std::log2(expf(1.)); + if (param_.softmax_scale) { + inv_sqrt_dh *= param_.softmax_scale; + } + else { + inv_sqrt_dh /= std::sqrt((float)size_per_head_); + } + + invokeCpReduce(attn.data(), + cp_O_.data(), + cp_M_.data(), + cp_L_.data(), + token_num, + local_head_num_, + size_per_head_, + engine_param_.attn_cp_size, + engine_param_.attn_cp_rank, + inv_sqrt_dh, + stream_); + sync_check_cuda_error(); +} + template Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, const WeightType& weights) { @@ -240,7 +297,7 @@ Tensor 
UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_; Tensor attn{{q_count, (int)local_head_num_ * (int)size_per_head_}, dtype, device}; - Tensor tmp_kv{{2, (int)local_kv_head_num_, k_count + MAX_CTA_S, (int)size_per_head_}, dtype, device}; + Tensor tmp_kv{{(int)local_kv_head_num_, 2, k_count + MAX_CTA_S, (int)size_per_head_}, dtype, device}; auto stream_ptr = streams_.data(); @@ -327,6 +384,17 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, params.locks = barriers_.data(); params.max_split_k = std::min(std::max(1, kMaxWorkspaceTokens / params.token_num), max_kv_splits); + // context parallel + params.cp_rank = engine_param_.attn_cp_rank; + params.cp_size = engine_param_.attn_cp_size; + if (params.cp_size > 1) { + const int off_ML = q_count * local_head_num_ * engine_param_.attn_cp_rank; + const int off_O = q_count * local_head_num_ * size_per_head_ * engine_param_.attn_cp_rank; + params.cp_M = cp_M_.data() + off_ML; + params.cp_L = cp_L_.data() + off_ML; + params.cp_O = cp_O_.data() + off_O; + } + params.arch = arch_; params.stream = stream; @@ -375,6 +443,55 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, check_cuda_error(cudaStreamWaitEvent(stream_, aux_event_)); } + if ((decode_num_ || prefil_num_) && !isTuning() && engine_param_.attn_cp_size > 1) { + cp_postprocess(attn); + + if (0) { + auto save_tensor = [&](const std::string& name, const Tensor& ten) { + std::stringstream ss; + for (auto& s : ten.shape()) { + ss << s << "_"; + } + TM_LOG_ERROR("name=%s, shape=%s", name.c_str(), ss.str().c_str()); + std::ofstream ofs(name + ".bin", std::ios::binary); + ofs.write((const char*)ten.raw_data(), ten.byte_size()); + }; + // out + Tensor_ dattn = {attn.data(), {(int)q_count, local_head_num_, size_per_head_}, kDEVICE}; + Tensor_ hattn = empty_like(dattn, kCPU); + Copy(dattn, hattn); + + // // k, v + // Tensor_ dkv = {tmp_kv.data(), {local_kv_head_num_, 2, k_count, size_per_head_}, kDEVICE}; + // Tensor_ hkv = empty_like(dkv, kCPU); + // Copy(dkv, hkv); + + const int off_ML = q_count * local_head_num_ * engine_param_.attn_cp_rank; + + // q, h, + Tensor_ dl = {cp_M_.data() + off_ML, {q_count, local_head_num_}, kDEVICE}; + Tensor_ dm = {cp_L_.data() + off_ML, {q_count, local_head_num_}, kDEVICE}; + Tensor_ hl = empty_like(dl, kCPU); + Tensor_ hm = empty_like(dm, kCPU); + Copy(dl, hl); + Copy(dm, hm); + cudaDeviceSynchronize(); + + save_tensor("attn_" + std::to_string(engine_param_.attn_tp_rank) + + std::to_string(engine_param_.attn_cp_rank), + hattn); + // save_tensor("hkv_" + std::to_string(engine_param_.attn_tp_rank) + // + std::to_string(engine_param_.attn_cp_rank), + // hkv); + save_tensor("hl_" + std::to_string(engine_param_.attn_tp_rank) + std::to_string(engine_param_.attn_cp_rank), + hl); + save_tensor("hm_" + std::to_string(engine_param_.attn_tp_rank) + std::to_string(engine_param_.attn_cp_rank), + hm); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + exit(0); + } + } + if (isTuning()) { rng_.set_stream(stream_); rng_.GenerateUniform(attn.data(), attn.size(), .02f, -.01f); diff --git a/src/turbomind/models/llama/unified_attention_layer.h b/src/turbomind/models/llama/unified_attention_layer.h index a498b3b881..42c8e81b0e 100644 --- a/src/turbomind/models/llama/unified_attention_layer.h +++ b/src/turbomind/models/llama/unified_attention_layer.h @@ -76,6 +76,9 @@ class UnifiedAttentionLayer { template Tensor 
core_attention(Tensor& qkv, const ForwardParam& p, const WeightType& weights); + template + void cp_postprocess(Tensor& attn); + void qk_norm(Tensor& qkv, const WeightType& weights); private: @@ -87,6 +90,7 @@ class UnifiedAttentionLayer { const int local_kv_head_num_; const AttentionParam param_; + const EngineParam engine_param_; const ModelParam model_param_; const LoraParam lora_param_; const Context& context_; @@ -99,6 +103,9 @@ class UnifiedAttentionLayer { cudaEvent_t qkv_event_; cudaEvent_t aux_event_; + const int attn_cp_group_; + comm::DeviceCommImpl* const d_comm_; + std::array streams_; RNG rng_; @@ -116,6 +123,11 @@ class UnifiedAttentionLayer { Tensor_ split_cnt_; Tensor_ barriers_; // always zero + // context parallel + Tensor_ cp_M_; + Tensor_ cp_L_; + Tensor_ cp_O_; + Event event_; Buffer_ h_q_len_; diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index b771f0f00d..6405a3e225 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -33,6 +33,7 @@ UnifiedDecoder::UnifiedDecoder(const ModelParam& model, rmsnorm_eps_(model.norm_eps), stream_(ctx.stream), d_comm_(ctx.comm.d_comm), + engine_param_(engine), tune_layer_num_(model.tune_layer_num) { attn_layer_ = std::make_unique(model, attn, engine, lora, attn_tp_size_, ctx); @@ -57,7 +58,7 @@ void UnifiedDecoder::AllreduceResidualRMSnorm(Tensor& hidden_states, { const auto dtype = hidden_states.dtype(); if (0) {} - else if (group0 || group1) { + else if (engine_param_.attn_dp_size > 1 && engine_param_.attn_cp_size == 1) { d_comm_->AllreduceResidualBiasRMSnormEx(hidden_states.raw_data(), residual.data_or((void*)nullptr), bias.data_or((void*)nullptr), diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h index dd03293744..2d001c9bc3 100644 --- a/src/turbomind/models/llama/unified_decoder.h +++ b/src/turbomind/models/llama/unified_decoder.h @@ -35,6 +35,8 @@ class UnifiedDecoder { const int attn_tp_group_; + const EngineParam engine_param_; + const float rmsnorm_eps_; cudaStream_t const stream_; diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index e99e34a41e..05fd5d9f82 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -379,6 +379,8 @@ LlamaTritonModel::LlamaTritonModel(std::string model_ engine_param_.attn_dp_rank = 0; engine_param_.attn_tp_size = engine_reader["attn_tp_size"].as(); engine_param_.attn_tp_rank = 0; + engine_param_.attn_cp_size = engine_reader["attn_cp_size"].as(); + engine_param_.attn_cp_rank = 0; engine_param_.mlp_tp_size = engine_reader["mlp_tp_size"].as(); engine_param_.mlp_tp_rank = 0; @@ -389,7 +391,7 @@ LlamaTritonModel::LlamaTritonModel(std::string model_ engine_param_.max_forward_token_num = ((size_t)max_forward_token_num + tp - 1) / tp * tp; } - comm_size_ = engine_param_.attn_dp_size * engine_param_.attn_tp_size; + comm_size_ = engine_param_.attn_dp_size * engine_param_.attn_tp_size * engine_param_.attn_cp_size; FT_CHECK(engine_param_.mlp_tp_size == comm_size_); communicator_ = engine_reader["communicator"].as(); @@ -444,14 +446,21 @@ LlamaTritonModel::LlamaTritonModel(std::string model_ } const int device_num = engine_param_.outer_dp_size * comm_size_; + const int tp_cp_size = engine_param_.attn_tp_size * engine_param_.attn_cp_size; + // comm layout: outer_dp x inner(dp, tp, cp) 
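Note (illustrative sketch, not part of the patch): with context parallelism the inner communicator is laid out as (dp, tp, cp) with cp varying fastest, so a flat inner rank decomposes as sketched below; for the d2t2c3 example in createCommSplits the ranks enumerate as d0t0c0, d0t0c1, d0t0c2, d0t1c0, d0t1c1, d0t1c2.

// Decompose a flat rank within one outer-dp replica into (attn_dp, attn_tp, attn_cp),
// mirroring the assignments in the loop that follows.
struct AttnRanks {
    int dp, tp, cp;
};

inline AttnRanks decompose_inner_rank(int inner_rank, int attn_tp_size, int attn_cp_size)
{
    const int tp_cp_size = attn_tp_size * attn_cp_size;
    return {inner_rank / tp_cp_size,                 // attn_dp_rank
            inner_rank % tp_cp_size / attn_cp_size,  // attn_tp_rank
            inner_rank % attn_cp_size};              // attn_cp_rank
}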
engine_params_.resize(device_num, engine_param_); for (int i = 0; i < device_num; ++i) { auto& e = engine_params_[i]; e.outer_dp_rank = i / comm_size_; - e.attn_tp_rank = i % comm_size_ % e.attn_tp_size; - e.attn_dp_rank = i % comm_size_ / e.attn_tp_size; - e.mlp_tp_rank = i % comm_size_; + // e.attn_tp_rank = i % comm_size_ % e.attn_tp_size; + // e.attn_dp_rank = i % comm_size_ / e.attn_tp_size; + // e.mlp_tp_rank = i % comm_size_; + + e.attn_cp_rank = i % comm_size_ % e.attn_cp_size; + e.attn_tp_rank = i % tp_cp_size / e.attn_cp_size; + e.attn_dp_rank = i % comm_size_ / tp_cp_size; + e.mlp_tp_rank = i % comm_size_; } TM_LOG_INFO("%s", toString().c_str()); @@ -501,17 +510,45 @@ Communicators LlamaTritonModel::createCommSplits(int rank) const int outer_rank = rank / comm_size_; const int inner_rank = rank % comm_size_; + const int tp_cp_size = engine_param_.attn_tp_size * engine_param_.attn_cp_size; + const int color_tp_cp = inner_rank / tp_cp_size; + const int color_cp = + inner_rank / engine_param_.attn_cp_size + (inner_rank / tp_cp_size) * engine_param_.attn_tp_size; + const int color_tp = + inner_rank % engine_param_.attn_cp_size + (inner_rank / tp_cp_size) * engine_param_.attn_cp_size; + TM_LOG_ERROR("[split] rank=%d, tp_cp_size=%d, color_cp=%d, color_tp=%d, comm_size=%d, inner_rank=%d", + rank, + tp_cp_size, + color_cp, + color_tp, + comm_size_, + inner_rank); + comm.h_comm = group_ids_[outer_rank]->CreateCommunicator(comm_size_, inner_rank); - comm.h_tp_group = comm.h_comm->Split(inner_rank / engine_param_.attn_tp_size, 0); - comm.h_dp_group = comm.h_comm->Split(inner_rank % engine_param_.attn_tp_size, 0); + // comm.h_tp_group = comm.h_comm->Split(inner_rank / engine_param_.attn_tp_size, 0); + // comm.h_dp_group = comm.h_comm->Split(inner_rank % engine_param_.attn_tp_size, 0); + + comm.h_tp_cp_group = comm.h_comm->Split(color_tp_cp, 0); + comm.h_tp_group = comm.h_comm->Split(color_tp, 0); + comm.h_dp_group = comm.h_comm->Split(inner_rank % tp_cp_size, 0); if (comm_size_ > 1) { comm.d_comm = CreateDeviceCommunicator(communicator_, comm_size_, inner_rank, comm.h_comm); // comm.d_tp_group = 0; if (engine_param_.attn_tp_size != comm_size_) { - comm.d_tp_group = comm.d_comm->Split(inner_rank / engine_param_.attn_tp_size, 0, 0); + // comm.d_tp_group = comm.d_comm->Split(inner_rank / engine_param_.attn_tp_size, 0, 0); + + comm.d_cp_group = comm.d_comm->Split(color_cp, 0, 0); + comm.d_tp_group = comm.d_comm->Split(color_tp, 0, 0); + + // d2t2c3 example + // d0t0c0, d0t0c1, d0t0c2, d0t1c0, d0t1c1, d0t1c2 + // c 0 0 0 1 1 1 + // t 0 1 2 0 1 2 + // c inner_rank / attn_cp_size + (inner_rank / tp_cp_size) * attn_tp_size + // t inner_rank % attn_cp_size + (inner_rank / tp_cp_size) * attn_cp_size } } @@ -574,7 +611,7 @@ void LlamaTritonModel::createEngine(int device_id, int rank) if (first_create) { try { - engine.Warmup(); + // engine.Warmup(); } catch (const std::exception& e) { TM_LOG_ERROR("[Engine][Warmup] %s", e.what()); From 0fe88bc5d8c943d3787e76e341044f89d1be55d5 Mon Sep 17 00:00:00 2001 From: irexyc Date: Mon, 22 Sep 2025 07:49:48 +0000 Subject: [PATCH 03/31] accurate mask iter --- .../kernels/attention/attention_universal.h | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index 53cce315ac..1cce94f980 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -441,9 +441,19 
@@ struct AttentionUniversal { // -> x * CTA_S >= offset_Q + CTA_Q - offset_K + tile_iter * CTA_S - w int mask_iter_front = cdiv(max(0, offset_Q + CTA_Q - offset_K + tile_iter * CTA_S - params.window_size), CTA_S); - // TODO: mask all iter for simplicity, use accurate mask_iter - mask_iter_back = 999999; - mask_iter_front = 999999; + if (params.cp_size > 1) { + // mask all iter for simplicity + // mask_iter_back = 1 << 30; + // mask_iter_front = 1 << 30; + // TODO: use accurate mask_iter + mask_iter_back = + cdiv(max(0, params.cp_size * (offset_K + CTA_S) - offset_Q + params.cp_rank), params.cp_size * CTA_S); + mask_iter_front = cdiv(max(0, + offset_Q + CTA_Q - params.window_size - params.cp_rank + - params.cp_size * (offset_K - tile_iter * CTA_S)), + params.cp_size * CTA_S); + } + #if 0 if (threadIdx.x == 0) { printf( From 5c0277933e732a52a02bb3c6d446e4ed049eff93 Mon Sep 17 00:00:00 2001 From: irexyc Date: Mon, 22 Sep 2025 08:59:30 +0000 Subject: [PATCH 04/31] use fast divmod --- .../kernels/attention/CMakeLists.txt | 2 +- .../kernels/attention/attention_universal.h | 24 ++++++++----- .../kernels/attention/kv_cache_utils_v2.cu | 36 +++++++++---------- 3 files changed, 34 insertions(+), 28 deletions(-) diff --git a/src/turbomind/kernels/attention/CMakeLists.txt b/src/turbomind/kernels/attention/CMakeLists.txt index d9711f112c..d1fee315cc 100644 --- a/src/turbomind/kernels/attention/CMakeLists.txt +++ b/src/turbomind/kernels/attention/CMakeLists.txt @@ -45,7 +45,7 @@ set_property(TARGET attention PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_compile_options(attention PRIVATE -O3 $<$:-use_fast_math --expt-relaxed-constexpr>) - +target_link_libraries(attention PRIVATE nvidia::cutlass::cutlass) if (BUILD_TEST) target_compile_options(attention PRIVATE diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index 1cce94f980..7efb12c699 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -2,6 +2,8 @@ #pragma once +#include "cutlass/fast_math.h" + #include "quantization.h" #include "src/turbomind/kernels/attention/reduce_kernel.h" #include "src/turbomind/kernels/attention/rotary_embedding.h" @@ -256,6 +258,10 @@ struct AttentionUniversal { const int qi = offset.y / CTA_H; const int ti = history_len; + cutlass::FastDivmod cp_divmod{params.cp_size}; + int cp_quo, cp_rem; + cp_divmod(cp_quo, cp_rem, ti); + Array param_K[1]; Array param_V[1]; @@ -276,8 +282,8 @@ struct AttentionUniversal { } iterator.block_head_.with( - iterator.block_ptrs_, ti / params.cp_size, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { - if (ti % params.cp_size != params.cp_rank) { + iterator.block_ptrs_, cp_quo, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { + if (cp_rem != params.cp_rank) { return; } PRAGMA_UNROLL @@ -374,18 +380,20 @@ struct AttentionUniversal { const int context_len = params.cu_k_len[batch_idx + 1] - params.cu_k_len[batch_idx]; const int history_len = context_len - input_len; + cutlass::FastDivmod cp_divmod{params.cp_size}; + auto get_cp_len = [&](int length) -> int { - return (length / params.cp_size + (length % params.cp_size > params.cp_rank ? 1 : 0)); + int cp_quo, cp_rem; + cp_divmod(cp_quo, cp_rem, length); + return (cp_quo + (cp_rem > params.cp_rank ? 
1 : 0)); }; - const int last_K = history_len + min(query_idx + CTA_Q, input_len); - const int last_K_tile = (get_cp_len(last_K) - 1) / CTA_S + 1; - // const int last_K_tile = (last_K - 1) / CTA_S + 1; // past-the-end index to past-the-end tile index - // conversion + const int last_K = history_len + min(query_idx + CTA_Q, input_len); + const int last_K_tile = + (get_cp_len(last_K) - 1) / CTA_S + 1; // past-the-end index to past-the-end tile index conversion const int first_K = max(history_len + query_idx - (params.window_size - 1), 0); const int first_K_tile = get_cp_len(first_K) / CTA_S; - // const int first_K_tile = first_K / CTA_S; const int tile_count = last_K_tile - first_K_tile; diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu index be9f5b7430..ffa8a41b42 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu @@ -2,6 +2,8 @@ #include +#include "cutlass/fast_math.h" + #include "src/turbomind/kernels/attention/block.h" #include "src/turbomind/kernels/attention/kv_cache_utils_v2.h" #include "src/turbomind/kernels/attention/quantization.h" @@ -154,6 +156,9 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, } } + cutlass::FastDivmod cp_divmod{cp_size}; + int cp_quo, cp_rem; + blocks += cu_block_num[batch_idx]; block::Head block_head{block_layout, layer_id, head_idx}; @@ -162,8 +167,9 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, for (int s = 0; s < ITER_S; ++s) { const int qi = offset.y + s * Map::kDeltaS + token_idx; // local offset into `input_length` const int ti = history_len + qi; // timestep - if (qi < q_len && (ti % cp_size == cp_rank)) { - block_head.with((char**)blocks, ti / cp_size, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { + cp_divmod(cp_quo, cp_rem, ti); + if (qi < q_len && cp_rem == cp_rank) { + block_head.with((char**)blocks, cp_quo, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { int di = offset.x + c * Map::kDeltaC; @@ -351,11 +357,15 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, Array param_K[ITER_S]; Array param_V[ITER_S]; + cutlass::FastDivmod cp_divmod{cp_size}; + int cp_quo, cp_rem; + PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { const int si = offset.y + s * Map::kDeltaS + token_idx; - if (si < seq_len && (si % cp_size == cp_rank)) { - block_head.with((char**)blocks, si / cp_size, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { + cp_divmod(cp_quo, cp_rem, si); + if (si < seq_len && cp_rem == cp_rank) { + block_head.with((char**)blocks, cp_quo, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { int di = offset.x + c * Map::kDeltaC; @@ -401,25 +411,13 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, for (int c = 0; c < ITER_C; ++c) { const int si = offset.y + s * Map::kDeltaS + token_idx; const int di = offset.x + c * Map::kDeltaC; - // save first - if (si < seq_len && si % cp_size == cp_rank) { + cp_divmod(cp_quo, cp_rem, si); + if (si < seq_len && cp_rem == cp_rank) { const int64_t index = - (batch_idx * stride_b + ti_beg * stride_c + si / cp_size * stride_s + head_idx * stride_h) * HeadDim - + di; + (batch_idx * stride_b + ti_beg * stride_c + cp_quo * stride_s + head_idx * stride_h) * HeadDim + di; Store(&k[index], out_K[s][c]); Store(&v[index], out_V[s][c]); } - - // const int64_t index = - // (batch_idx * stride_b 
+ ti_beg * stride_c + si * stride_s + head_idx * stride_h) * HeadDim + di; - // if (si < seq_len) { - // if (si % cp_size != cp_rank) { - // clear(out_K[s][c]); - // clear(out_V[s][c]); - // } - // Store(&k[index], out_K[s][c]); - // Store(&v[index], out_V[s][c]); - // } } } } From 53654add1e6915c90f06e986ca6913ccfd182eb0 Mon Sep 17 00:00:00 2001 From: irexyc Date: Mon, 22 Sep 2025 13:38:23 +0000 Subject: [PATCH 05/31] remove cp_O --- .../kernels/attention/attention_params.h | 1 - .../kernels/attention/attention_template.h | 1 - .../kernels/attention/attention_universal.h | 3 +-- .../kernels/attention/decoding_template.h | 1 - src/turbomind/kernels/attention/reduce.cu | 3 --- src/turbomind/kernels/attention/reduce.h | 1 - .../kernels/attention/reduce_kernel.h | 8 +----- src/turbomind/models/llama/LlamaBatch.cc | 6 ----- src/turbomind/models/llama/LlamaBatch.h | 1 - src/turbomind/models/llama/LlamaV2.cc | 2 -- src/turbomind/models/llama/LlamaV2.h | 1 - src/turbomind/models/llama/cp_utils.cu | 26 +++---------------- src/turbomind/models/llama/cp_utils.h | 1 - .../models/llama/unified_attention_layer.cc | 12 --------- .../models/llama/unified_attention_layer.h | 1 - 15 files changed, 6 insertions(+), 62 deletions(-) diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h index b2d43d5fc9..d0f8d1fcff 100644 --- a/src/turbomind/kernels/attention/attention_params.h +++ b/src/turbomind/kernels/attention/attention_params.h @@ -84,7 +84,6 @@ struct AttentionParams { int cp_size{1}; float* cp_M{nullptr}; float* cp_L{nullptr}; - float* cp_O{nullptr}; int arch; cudaStream_t stream; diff --git a/src/turbomind/kernels/attention/attention_template.h b/src/turbomind/kernels/attention/attention_template.h index 7bff780396..b7257fa334 100644 --- a/src/turbomind/kernels/attention/attention_template.h +++ b/src/turbomind/kernels/attention/attention_template.h @@ -87,7 +87,6 @@ void invokeAttention(const typename Kernel::ParamType& params) params.partial_O, params.cp_M, params.cp_L, - params.cp_O, params.split_cnt, params.max_split_k, split_cnt, diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index 7efb12c699..fdc9c5650f 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -562,7 +562,6 @@ struct AttentionUniversal { params.partial_O, params.cp_M, params.cp_L, - params.cp_O, qi_begin, head_idx, params.num_heads, @@ -639,7 +638,7 @@ struct AttentionUniversal { } if (params.cp_size > 1 && split_idx == 0) { const int index = ((qi_begin + qi) * params.num_heads + (head_idx + hi)) * kHeadDim + di; - Store(¶ms.cp_O[index], vec); + Store(¶ms.out[index], cast(vec)); } } }); diff --git a/src/turbomind/kernels/attention/decoding_template.h b/src/turbomind/kernels/attention/decoding_template.h index 5be8c12c08..d22217dc6c 100644 --- a/src/turbomind/kernels/attention/decoding_template.h +++ b/src/turbomind/kernels/attention/decoding_template.h @@ -86,7 +86,6 @@ bool invokeDecoding(const typename Kernel::ParamType& params) params.partial_O, params.cp_M, params.cp_L, - params.cp_O, params.split_cnt, params.max_split_k, split_cnt, diff --git a/src/turbomind/kernels/attention/reduce.cu b/src/turbomind/kernels/attention/reduce.cu index 53bbda59ce..493c113b4c 100644 --- a/src/turbomind/kernels/attention/reduce.cu +++ b/src/turbomind/kernels/attention/reduce.cu @@ -14,7 +14,6 @@ void invokeReduce(T* out, float* 
partial_O, float* cp_M, float* cp_L, - float* cp_O, const int* split_cnt, int partial_len, int max_split_cnt, @@ -39,7 +38,6 @@ void invokeReduce(T* out, partial_O, cp_M, cp_L, - cp_O, nullptr, split_cnt, partial_len, @@ -66,7 +64,6 @@ void invokeReduce(T* out, float* partial_O, \ float* cp_M, \ float* cp_L, \ - float* cp_O, \ const int* split_cnt, \ int partial_len, \ int max_split_cnt, \ diff --git a/src/turbomind/kernels/attention/reduce.h b/src/turbomind/kernels/attention/reduce.h index 8fe17f5fa7..d1f06a075c 100644 --- a/src/turbomind/kernels/attention/reduce.h +++ b/src/turbomind/kernels/attention/reduce.h @@ -18,7 +18,6 @@ void invokeReduce(T* out, float* partial_O, float* cp_M, float* cp_L, - float* cp_O, const int* split_cnt, int partial_len, int max_split_cnt, diff --git a/src/turbomind/kernels/attention/reduce_kernel.h b/src/turbomind/kernels/attention/reduce_kernel.h index c246631e78..48c89e940b 100644 --- a/src/turbomind/kernels/attention/reduce_kernel.h +++ b/src/turbomind/kernels/attention/reduce_kernel.h @@ -29,7 +29,6 @@ struct Reduce { float* partial_O, float* cp_M, float* cp_L, - float* cp_O, int query_idx, int head_idx, int head_num, @@ -105,7 +104,7 @@ struct Reduce { Array scale; PRAGMA_UNROLL for (int k = 0; k < K; ++k) { - scale[k] = (IsFinal && cp_O == nullptr) ? expdiff_M[k] / block_L : expdiff_M[k]; + scale[k] = (IsFinal && cp_M == nullptr) ? expdiff_M[k] / block_L : expdiff_M[k]; } if (hi < CTA_H) { @@ -205,9 +204,6 @@ struct Reduce { if (ki == 0 && hi < hi_end) { if constexpr (IsFinal) { const int offset = (query_idx * head_num + head_idx + hi) * HeadDim + di; - if (cp_O != nullptr) { - Store(&cp_O[offset], (Vec&)storage.O[hi][ki][di]); - } Store(&out[offset], cast((Vec&)storage.O[hi][ki][di])); } else { @@ -227,7 +223,6 @@ __global__ void reduce_kernel(typename Reduce::T* out, float* partial_O, float* cp_M, float* cp_L, - float* cp_O, int* signals, const int* split_cnt_, int max_split_cnt, @@ -256,7 +251,6 @@ __global__ void reduce_kernel(typename Reduce::T* out, partial_O, cp_M, cp_L, - cp_O, query_idx, head_idx, head_num, diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 8d72fd955e..d278605450 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -831,9 +831,6 @@ void LlamaBatch::AllocSymmBuffers() if (param_.attn_cp_size > 1) { symm_cp_M_ = {{param_.attn_cp_size, max_forward_token_num_, (int)model_->local_head_num_}, symm_alloc_}; symm_cp_L_ = {{param_.attn_cp_size, max_forward_token_num_, (int)model_->local_head_num_}, symm_alloc_}; - symm_cp_O_ = { - {param_.attn_cp_size, max_forward_token_num_, (int)model_->local_head_num_, (int)model_->size_per_head_}, - symm_alloc_}; } } @@ -844,7 +841,6 @@ void LlamaBatch::FreeSymmBuffers() symm_cp_M_ = {}; symm_cp_L_ = {}; - symm_cp_O_ = {}; } LlamaBatch::~LlamaBatch() @@ -1587,7 +1583,6 @@ bool LlamaBatch::Forward(GenerationState& g) &mrope, symm_cp_M_, symm_cp_L_, - symm_cp_O_, finished_buf_.slice(first, mini_batch_size), Buffer(local_token_nums.data(), local_token_nums.size(), kCPU), lora_mask_buf_, @@ -1782,7 +1777,6 @@ void LlamaBatch::Warmup() nullptr, // mrope symm_cp_M_, symm_cp_L_, - symm_cp_O_, finished_buf_.slice(0, bsz), Buffer{local_token_nums.data(), (int)local_token_nums.size(), kCPU}, Buffer{}, diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 379155649a..03c86b7f3c 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ 
b/src/turbomind/models/llama/LlamaBatch.h @@ -246,7 +246,6 @@ class LlamaBatch { Tensor symm_logits_buf_; // context parallel - Tensor_ symm_cp_O_; Tensor_ symm_cp_M_; Tensor_ symm_cp_L_; diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 7947157b45..671f0d9549 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -165,7 +165,6 @@ void LlamaV2::Forward(Buffer_ input_ids, MropeRope* mrope, Tensor cp_M, Tensor cp_L, - Tensor cp_O, Buffer finished, Buffer local_token_nums, Buffer lora_mask, @@ -263,7 +262,6 @@ void LlamaV2::Forward(Buffer_ input_ids, {"rope_base", rope_base}, {"cp_M", cp_M}, {"cp_L", cp_L}, - {"cp_O", cp_O}, {"cu_block_nums", cu_block_nums}, {"kv_block_ptrs", kv_block_ptrs}, {"local_token_nums", local_token_nums}}; diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index b51b2e56d9..bac1503db7 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -71,7 +71,6 @@ class LlamaV2 { MropeRope* mrope, Tensor cp_M, Tensor cp_L, - Tensor cp_O, Buffer finished, Buffer local_token_nums, Buffer lora_mask, diff --git a/src/turbomind/models/llama/cp_utils.cu b/src/turbomind/models/llama/cp_utils.cu index 8072b6cfe8..dabb70c291 100644 --- a/src/turbomind/models/llama/cp_utils.cu +++ b/src/turbomind/models/llama/cp_utils.cu @@ -6,7 +6,6 @@ namespace turbomind { template __global__ void CpReduce(T* out, - float* O, float* M, float* L, int token_num, @@ -51,29 +50,14 @@ __global__ void CpReduce(T* out, __syncthreads(); - // for (int i = threadIdx.x; i < size_per_head; i += blockDim.x) { - // float flag_O = 0; - // for (int j = 0; j < cp_size; ++j) { - // int index = j * token_num * head_num * size_per_head + token_idx * head_num * size_per_head - // + head_idx * size_per_head + i; - // flag_O += O[index] * scale[j]; - // } - // int out_index = token_idx * head_num * size_per_head + head_idx * size_per_head + i; // q, h, d - // // out[out_index] = (T)flag_O; - // out[out_index] = (T)(flag_O / cp_size); - // } - for (int i = threadIdx.x; i < size_per_head; i += blockDim.x) { - int src_index = cp_rank * token_num * head_num * size_per_head + token_idx * head_num * size_per_head - + head_idx * size_per_head + i; - int dst_index = token_idx * head_num * size_per_head + head_idx * size_per_head + i; // q, h, d - out[dst_index] = (T)(O[src_index] * scale[cp_rank]); + int index = token_idx * head_num * size_per_head + head_idx * size_per_head + i; + out[index] = (T)((float)out[index] * scale[cp_rank]); } } template void invokeCpReduce(T* out, - float* O, float* M, float* L, int token_num, @@ -87,14 +71,13 @@ void invokeCpReduce(T* out, TM_CHECK(cp_size <= WARP_SIZE); const dim3 block = 4 * WARP_SIZE; const dim3 grid(token_num, head_num); - size_t smem_size = sizeof(float) * WARP_SIZE * 2; + size_t smem_size = sizeof(float) * WARP_SIZE; CpReduce<<>>( - out, O, M, L, token_num, head_num, size_per_head, cp_size, cp_rank, exp_scale); + out, M, L, token_num, head_num, size_per_head, cp_size, cp_rank, exp_scale); sync_check_cuda_error(); } template void invokeCpReduce(half* out, - float* O, float* M, float* L, int token_num, @@ -106,7 +89,6 @@ template void invokeCpReduce(half* out, cudaStream_t stream); #ifdef ENABLE_BF16 template void invokeCpReduce(__nv_bfloat16* out, - float* O, float* M, float* L, int token_num, diff --git a/src/turbomind/models/llama/cp_utils.h b/src/turbomind/models/llama/cp_utils.h index 4ae640a51a..6ac06e0ffd 
100644 --- a/src/turbomind/models/llama/cp_utils.h +++ b/src/turbomind/models/llama/cp_utils.h @@ -6,7 +6,6 @@ namespace turbomind { template void invokeCpReduce(T* out, - float* O, float* M, float* L, int token_num, diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 7b51605ec9..34190d15e1 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -144,7 +144,6 @@ void UnifiedAttentionLayer::Initialize(TensorMap& args) cp_M_ = args.at("cp_M").borrow(); cp_L_ = args.at("cp_L").borrow(); - cp_O_ = args.at("cp_O").borrow(); // rotary embedding, add offest when forward if (rope_param_.type == RopeType::kDynamic) { @@ -252,15 +251,6 @@ void UnifiedAttentionLayer::cp_postprocess(Tensor& attn) stream_); sync_check_cuda_error(); - // auto allgather = [&](float* src, float* dst, int count) { - // d_comm_->AllGather(src + count * engine_param_.attn_cp_rank, dst, count, kFloat32, attn_cp_group_, stream_); - // sync_check_cuda_error(); - // }; - // allgather(cp_O_.data(), cp_O_.data(), - // token_num * local_head_num_ * size_per_head_); // (cp, q, h, d) - // allgather(cp_M_.data(), cp_M_.data(), token_num * local_head_num_); // (cp, q, h) - // allgather(cp_L_.data(), cp_L_.data(), token_num * local_head_num_); // (cp, q, h) - float inv_sqrt_dh = (float)std::log2(expf(1.)); if (param_.softmax_scale) { inv_sqrt_dh *= param_.softmax_scale; @@ -270,7 +260,6 @@ void UnifiedAttentionLayer::cp_postprocess(Tensor& attn) } invokeCpReduce(attn.data(), - cp_O_.data(), cp_M_.data(), cp_L_.data(), token_num, @@ -392,7 +381,6 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, const int off_O = q_count * local_head_num_ * size_per_head_ * engine_param_.attn_cp_rank; params.cp_M = cp_M_.data() + off_ML; params.cp_L = cp_L_.data() + off_ML; - params.cp_O = cp_O_.data() + off_O; } params.arch = arch_; diff --git a/src/turbomind/models/llama/unified_attention_layer.h b/src/turbomind/models/llama/unified_attention_layer.h index 42c8e81b0e..ec4c4e8ff1 100644 --- a/src/turbomind/models/llama/unified_attention_layer.h +++ b/src/turbomind/models/llama/unified_attention_layer.h @@ -126,7 +126,6 @@ class UnifiedAttentionLayer { // context parallel Tensor_ cp_M_; Tensor_ cp_L_; - Tensor_ cp_O_; Event event_; From e3dd4f7fe5d536920f5a233a47b2cc69d53de188 Mon Sep 17 00:00:00 2001 From: irexyc Date: Mon, 22 Sep 2025 13:45:38 +0000 Subject: [PATCH 06/31] remove unused --- .../models/llama/unified_attention_layer.cc | 49 +------------------ 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 34190d15e1..db54d56bbd 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -23,9 +23,6 @@ #include #include -#include "src/turbomind/models/llama/cp_utils.h" -#include - #include "src/turbomind/core/check.h" #include "src/turbomind/core/data_type.h" #include "src/turbomind/core/tensor.h" @@ -37,6 +34,7 @@ #include "src/turbomind/macro.h" +#include "src/turbomind/models/llama/cp_utils.h" #include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/models/llama/mla_utils.h" #include "src/turbomind/models/llama/unified_attention_layer.h" @@ -433,51 +431,6 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, if ((decode_num_ 
|| prefil_num_) && !isTuning() && engine_param_.attn_cp_size > 1) { cp_postprocess(attn); - - if (0) { - auto save_tensor = [&](const std::string& name, const Tensor& ten) { - std::stringstream ss; - for (auto& s : ten.shape()) { - ss << s << "_"; - } - TM_LOG_ERROR("name=%s, shape=%s", name.c_str(), ss.str().c_str()); - std::ofstream ofs(name + ".bin", std::ios::binary); - ofs.write((const char*)ten.raw_data(), ten.byte_size()); - }; - // out - Tensor_ dattn = {attn.data(), {(int)q_count, local_head_num_, size_per_head_}, kDEVICE}; - Tensor_ hattn = empty_like(dattn, kCPU); - Copy(dattn, hattn); - - // // k, v - // Tensor_ dkv = {tmp_kv.data(), {local_kv_head_num_, 2, k_count, size_per_head_}, kDEVICE}; - // Tensor_ hkv = empty_like(dkv, kCPU); - // Copy(dkv, hkv); - - const int off_ML = q_count * local_head_num_ * engine_param_.attn_cp_rank; - - // q, h, - Tensor_ dl = {cp_M_.data() + off_ML, {q_count, local_head_num_}, kDEVICE}; - Tensor_ dm = {cp_L_.data() + off_ML, {q_count, local_head_num_}, kDEVICE}; - Tensor_ hl = empty_like(dl, kCPU); - Tensor_ hm = empty_like(dm, kCPU); - Copy(dl, hl); - Copy(dm, hm); - cudaDeviceSynchronize(); - - save_tensor("attn_" + std::to_string(engine_param_.attn_tp_rank) - + std::to_string(engine_param_.attn_cp_rank), - hattn); - // save_tensor("hkv_" + std::to_string(engine_param_.attn_tp_rank) - // + std::to_string(engine_param_.attn_cp_rank), - // hkv); - save_tensor("hl_" + std::to_string(engine_param_.attn_tp_rank) + std::to_string(engine_param_.attn_cp_rank), - hl); - save_tensor("hm_" + std::to_string(engine_param_.attn_tp_rank) + std::to_string(engine_param_.attn_cp_rank), - hm); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - exit(0); - } } if (isTuning()) { From 1f75dd64c447c87f73e9f34f7f100056259388de Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Mon, 22 Sep 2025 20:23:34 +0800 Subject: [PATCH 07/31] return the last token's logprobs if include_stop_str_in_output is requested (#4000) --- lmdeploy/serve/async_engine.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index be9b3adb58..0e7bbdf7eb 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -846,6 +846,8 @@ def is_error(status): if outputs.logprobs is not None: log_offset = ids_offset - start_ids_offset out.logprobs = outputs.logprobs[log_offset:] + if hit_stop_token: + out.logprobs = out.logprobs[:-hit_stop_token] if outputs.last_hidden_state is not None: out.last_hidden_state = outputs.last_hidden_state if hit_stop_token: @@ -865,11 +867,15 @@ def is_error(status): if not response.endswith('�'): # avoid returning the last response twice response = '' - token_ids = [] + token_ids, logits, last_hidden_state, logprobs = [], None, None, None if gen_config.include_stop_str_in_output and finish_reason == 'stop': - # return the eos token id (MUST be in a list) and its string + # return the eos token id (MUST be in a list), eos string, eos token's logits and so on token_ids = outputs.token_ids[-1:] response = self.tokenizer.decode(token_ids, skip_special_tokens=False) + logits = outputs.logits[-1:] if outputs.logits else None + last_hidden_state = outputs.last_hidden_state[-1:] if outputs.last_hidden_state else None + logprobs = outputs.logprobs[-1:] if outputs.logprobs else None + logger.info(f'session {session_id} finished, reason ' f'"{finish_reason}", input_tokens ' f'{len(input_ids)}, output_tokens {gen_len}') @@ -879,6 +885,9 @@ def 
is_error(status): gen_len, finish_reason, token_ids=token_ids, + logprobs=logprobs, + logits=logits, + last_hidden_state=last_hidden_state, cache_block_ids=outputs.cache_block_ids) # Update a session's sequence only when it is in finished status if outputs.status == ResponseType.FINISH: From be504d387d898ed48554a203ef25d9017adbf170 Mon Sep 17 00:00:00 2001 From: CyCle1024 Date: Mon, 22 Sep 2025 21:20:30 +0800 Subject: [PATCH 08/31] [Fix] device args in chat cli when using pytorch engine (#3999) * [Fix] device args in chat cli when using pytorch engine * [Fix] change device into device_type in chat cli --- lmdeploy/cli/chat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lmdeploy/cli/chat.py b/lmdeploy/cli/chat.py index 0e612453d0..401f178a7a 100644 --- a/lmdeploy/cli/chat.py +++ b/lmdeploy/cli/chat.py @@ -25,6 +25,7 @@ def build_pipe(model_path, backend, **kwargs): else: engine_config = PytorchEngineConfig() for key, value in kwargs.items(): + key = 'device_type' if key == 'device' else key if hasattr(PytorchEngineConfig, key): setattr(engine_config, key, value) if kwargs.get('adapters', None): From 77ef52a19a956c9c87158f7b03c4ea52a7435f78 Mon Sep 17 00:00:00 2001 From: irexyc Date: Tue, 23 Sep 2025 11:19:49 +0000 Subject: [PATCH 09/31] fix NULL raw data --- .../models/llama/unified_attention_layer.cc | 6 ++++-- .../triton_backend/llama/LlamaTritonModel.cc | 20 +------------------ 2 files changed, 5 insertions(+), 21 deletions(-) diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index db54d56bbd..7723a6339e 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -140,8 +140,10 @@ void UnifiedAttentionLayer::Initialize(TensorMap& args) cu_block_nums_ = args.at("cu_block_nums").buffer(); kv_block_ptrs_ = args.at("kv_block_ptrs").buffer(); - cp_M_ = args.at("cp_M").borrow(); - cp_L_ = args.at("cp_L").borrow(); + if (engine_param_.attn_cp_size > 1) { + cp_M_ = args.at("cp_M").borrow(); + cp_L_ = args.at("cp_L").borrow(); + } // rotary embedding, add offest when forward if (rope_param_.type == RopeType::kDynamic) { diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 05fd5d9f82..15aa9b133c 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -516,19 +516,9 @@ Communicators LlamaTritonModel::createCommSplits(int rank) inner_rank / engine_param_.attn_cp_size + (inner_rank / tp_cp_size) * engine_param_.attn_tp_size; const int color_tp = inner_rank % engine_param_.attn_cp_size + (inner_rank / tp_cp_size) * engine_param_.attn_cp_size; - TM_LOG_ERROR("[split] rank=%d, tp_cp_size=%d, color_cp=%d, color_tp=%d, comm_size=%d, inner_rank=%d", - rank, - tp_cp_size, - color_cp, - color_tp, - comm_size_, - inner_rank); comm.h_comm = group_ids_[outer_rank]->CreateCommunicator(comm_size_, inner_rank); - // comm.h_tp_group = comm.h_comm->Split(inner_rank / engine_param_.attn_tp_size, 0); - // comm.h_dp_group = comm.h_comm->Split(inner_rank % engine_param_.attn_tp_size, 0); - comm.h_tp_cp_group = comm.h_comm->Split(color_tp_cp, 0); comm.h_tp_group = comm.h_comm->Split(color_tp, 0); comm.h_dp_group = comm.h_comm->Split(inner_rank % tp_cp_size, 0); @@ -537,18 +527,10 @@ Communicators LlamaTritonModel::createCommSplits(int rank) comm.d_comm = CreateDeviceCommunicator(communicator_, comm_size_, inner_rank, 
comm.h_comm); // comm.d_tp_group = 0; + comm.d_cp_group = 0; if (engine_param_.attn_tp_size != comm_size_) { - // comm.d_tp_group = comm.d_comm->Split(inner_rank / engine_param_.attn_tp_size, 0, 0); - comm.d_cp_group = comm.d_comm->Split(color_cp, 0, 0); comm.d_tp_group = comm.d_comm->Split(color_tp, 0, 0); - - // d2t2c3 example - // d0t0c0, d0t0c1, d0t0c2, d0t1c0, d0t1c1, d0t1c2 - // c 0 0 0 1 1 1 - // t 0 1 2 0 1 2 - // c inner_rank / attn_cp_size + (inner_rank / tp_cp_size) * attn_tp_size - // t inner_rank % attn_cp_size + (inner_rank / tp_cp_size) * attn_cp_size } } From 29cf813693f48a1f908b4731de942bb1e46d88a0 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 24 Sep 2025 07:07:30 +0000 Subject: [PATCH 10/31] add attn_cp_size to cli --- benchmark/profile_throughput.py | 2 ++ lmdeploy/cli/cli.py | 1 + lmdeploy/cli/serve.py | 2 ++ lmdeploy/cli/utils.py | 10 ++++++++++ 4 files changed, 15 insertions(+) diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index 6e4243cca6..76e828d005 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -329,6 +329,7 @@ def parse_args(): tb_group._group_actions.append(dtype_act) ArgumentHelper.dp(tb_group) + ArgumentHelper.attn_cp_size(tb_group) ArgumentHelper.model_format(tb_group, default='hf') ArgumentHelper.num_tokens_per_iter(tb_group) ArgumentHelper.max_prefill_iters(tb_group) @@ -346,6 +347,7 @@ def main(): max_batch_size=args.concurrency // args.dp, tp=args.tp, dp=args.dp, + attn_cp_size=args.attn_cp_size, cache_max_entry_count=args.cache_max_entry_count, cache_block_seq_len=args.cache_block_seq_len, model_format=args.model_format, diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py index d71198791f..d6982982d6 100644 --- a/lmdeploy/cli/cli.py +++ b/lmdeploy/cli/cli.py @@ -76,6 +76,7 @@ def add_parser_chat(): ArgumentHelper.model_format(tb_group) ArgumentHelper.rope_scaling_factor(tb_group) ArgumentHelper.communicator(tb_group) + ArgumentHelper.attn_cp_size(tb_group) @staticmethod def add_parser_checkenv(): diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 6a9e9f2b13..a3e58713dc 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -135,6 +135,7 @@ def add_parser_api_server(): tb_group._group_actions.append(model_format) tb_group._group_actions.append(hf_overrides) tb_group._group_actions.append(enable_metrics) + ArgumentHelper.attn_cp_size(tb_group) ArgumentHelper.rope_scaling_factor(tb_group) ArgumentHelper.num_tokens_per_iter(tb_group) ArgumentHelper.max_prefill_iters(tb_group) @@ -232,6 +233,7 @@ def api_server(args): from lmdeploy.messages import TurbomindEngineConfig backend_config = TurbomindEngineConfig(dtype=args.dtype, tp=args.tp, + attn_cp_size=args.attn_cp_size, max_batch_size=max_batch_size, session_len=args.session_len, model_format=args.model_format, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index bfd94182d0..929938a95b 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -188,6 +188,16 @@ def ep(parser): default=1, help='expert parallelism. dp is required when pytorch engine is used.') + @staticmethod + def attn_cp_size(parser): + """Add argument attn_cp_size to parser.""" + + return parser.add_argument( + '--attn-cp-size', + type=int, + default=1, + help='context parallelism size in attention for turbomind backend. 
Should divide tp.') + @staticmethod def dp_rank(parser): """Add argument dp_rank to parser.""" From 0044d4f42903d8e04a4be38101fd6c5608e19934 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 24 Sep 2025 08:12:35 +0000 Subject: [PATCH 11/31] build cutlass::FastDivmod on host --- .../kernels/attention/attention_params.h | 10 ++++--- .../kernels/attention/attention_universal.h | 11 +++----- .../kernels/attention/kv_cache_utils_v2.cu | 26 +++++++++---------- .../kernels/attention/kv_cache_utils_v2.h | 8 +++--- src/turbomind/models/llama/CMakeLists.txt | 1 + .../models/llama/unified_attention_layer.cc | 3 ++- 6 files changed, 28 insertions(+), 31 deletions(-) diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h index d0f8d1fcff..5d15a0eec3 100644 --- a/src/turbomind/kernels/attention/attention_params.h +++ b/src/turbomind/kernels/attention/attention_params.h @@ -2,6 +2,7 @@ #pragma once +#include "cutlass/fast_math.h" #include #include @@ -80,10 +81,11 @@ struct AttentionParams { int* locks; // context parallel - int cp_rank{0}; - int cp_size{1}; - float* cp_M{nullptr}; - float* cp_L{nullptr}; + int cp_rank{0}; + int cp_size{1}; + cutlass::FastDivmod cp_divmod{1}; + float* cp_M{nullptr}; + float* cp_L{nullptr}; int arch; cudaStream_t stream; diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index fdc9c5650f..a841d1cb1c 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -2,8 +2,6 @@ #pragma once -#include "cutlass/fast_math.h" - #include "quantization.h" #include "src/turbomind/kernels/attention/reduce_kernel.h" #include "src/turbomind/kernels/attention/rotary_embedding.h" @@ -258,9 +256,8 @@ struct AttentionUniversal { const int qi = offset.y / CTA_H; const int ti = history_len; - cutlass::FastDivmod cp_divmod{params.cp_size}; - int cp_quo, cp_rem; - cp_divmod(cp_quo, cp_rem, ti); + int cp_quo, cp_rem; + params.cp_divmod(cp_quo, cp_rem, ti); Array param_K[1]; Array param_V[1]; @@ -380,11 +377,9 @@ struct AttentionUniversal { const int context_len = params.cu_k_len[batch_idx + 1] - params.cu_k_len[batch_idx]; const int history_len = context_len - input_len; - cutlass::FastDivmod cp_divmod{params.cp_size}; - auto get_cp_len = [&](int length) -> int { int cp_quo, cp_rem; - cp_divmod(cp_quo, cp_rem, length); + params.cp_divmod(cp_quo, cp_rem, length); return (cp_quo + (cp_rem > params.cp_rank ? 
1 : 0)); }; diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu index ffa8a41b42..18457af18d 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu @@ -2,8 +2,6 @@ #include -#include "cutlass/fast_math.h" - #include "src/turbomind/kernels/attention/block.h" #include "src/turbomind/kernels/attention/kv_cache_utils_v2.h" #include "src/turbomind/kernels/attention/quantization.h" @@ -15,6 +13,8 @@ namespace turbomind { +using cutlass::FastDivmod; + template __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, const T* k, @@ -30,8 +30,8 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, int64_t stride_h, int64_t stride_s, int layer_id, - int cp_size, int cp_rank, + FastDivmod cp_divmod, BlockLayout block_layout) { @@ -156,8 +156,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, } } - cutlass::FastDivmod cp_divmod{cp_size}; - int cp_quo, cp_rem; + int cp_quo, cp_rem; blocks += cu_block_num[batch_idx]; @@ -206,8 +205,8 @@ void invokeProcessKV_v2(char** blocks, int64_t stride_s, int block_seq_len, int layer_id, - int cp_size, int cp_rank, + FastDivmod cp_divmod, int max_q_len, int head_num, int head_dim, @@ -243,8 +242,8 @@ void invokeProcessKV_v2(char** blocks, stride_h, stride_s, layer_id, - cp_size, cp_rank, + cp_divmod, block_layout); }; @@ -288,8 +287,8 @@ void invokeProcessKV_v2(char** blocks, int64_t stride_s, \ int block_seq_len, \ int layer_id, \ - int cp_size, \ int cp_rank, \ + FastDivmod cp_divmod, \ int max_q_len, \ int head_num, \ int head_dim, \ @@ -314,8 +313,8 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, int64_t stride_h, int64_t stride_s, int layer_id, - int cp_size, int cp_rank, + FastDivmod cp_divmod, BlockLayout block_layout) { constexpr int kVecSize = sizeof(uint4) / sizeof(T); @@ -357,8 +356,7 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, Array param_K[ITER_S]; Array param_V[ITER_S]; - cutlass::FastDivmod cp_divmod{cp_size}; - int cp_quo, cp_rem; + int cp_quo, cp_rem; PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { @@ -435,8 +433,8 @@ void invokeFlattenKV_v2(T* k, int64_t stride_s, int block_seq_len, int layer_id, - int cp_size, int cp_rank, + FastDivmod cp_divmod, int max_seq_len, int head_num, int head_dim, @@ -469,8 +467,8 @@ void invokeFlattenKV_v2(T* k, stride_h, stride_s, layer_id, - cp_size, cp_rank, + cp_divmod, block_layout); }; @@ -511,8 +509,8 @@ void invokeFlattenKV_v2(T* k, int64_t stride_s, \ int block_seq_len, \ int layer_id, \ - int cp_size, \ int cp_rank, \ + FastDivmod cp_divmod, \ int max_seq_len, \ int head_num, \ int head_dim, \ diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.h b/src/turbomind/kernels/attention/kv_cache_utils_v2.h index 5419979c29..c959a9c9bf 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.h +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.h @@ -23,8 +23,8 @@ void invokeProcessKV_v2(char** blocks, int64_t stride_s, int block_seq_len, int layer_id, - int cp_size, int cp_rank, + cutlass::FastDivmod cp_divmod, int max_q_len, int head_num, int head_dim, @@ -50,8 +50,8 @@ void invokeProcessKV_v2_(const AttentionParams& params) params.stride / params.size_per_head, // stride s params.block_iter_params.block_len, params.block_iter_params.layer_id, - params.cp_size, params.cp_rank, + params.cp_divmod, params.max_q_len, params.num_kv_heads, params.size_per_head, @@ -73,8 +73,8 @@ void 
invokeFlattenKV_v2(T* k, int64_t stride_s, int block_seq_len, int layer_id, - int cp_size, int cp_rank, + cutlass::FastDivmod cp_divmod, int max_seq_len, int head_num, int head_dim, @@ -99,8 +99,8 @@ void invokeFlattenKV_v2_(const AttentionParams& params, int sum_k_len) 1, params.block_iter_params.block_len, params.block_iter_params.layer_id, - params.cp_size, params.cp_rank, + params.cp_divmod, params.max_k_len, params.num_kv_heads, params.size_per_head, diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt index b58d4e3a85..2186850712 100644 --- a/src/turbomind/models/llama/CMakeLists.txt +++ b/src/turbomind/models/llama/CMakeLists.txt @@ -31,6 +31,7 @@ target_link_libraries(Llama PUBLIC CUDA::cudart core gemm2 CUDA::cublas + nvidia::cutlass::cutlass rms_norm DynamicDecodeLayer activation_kernels diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 7723a6339e..ff7ec245fb 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -238,7 +238,7 @@ void UnifiedAttentionLayer::Forward(ForwardParam p) template void UnifiedAttentionLayer::cp_postprocess(Tensor& attn) { - + NvtxScope scope("cp"); const int token_num = attn.shape(0); const int count = token_num * local_head_num_; d_comm_->AllGatherCP(cp_M_.data() + count * engine_param_.attn_cp_rank, @@ -377,6 +377,7 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, params.cp_rank = engine_param_.attn_cp_rank; params.cp_size = engine_param_.attn_cp_size; if (params.cp_size > 1) { + params.cp_divmod = cutlass::FastDivmod(params.cp_size); const int off_ML = q_count * local_head_num_ * engine_param_.attn_cp_rank; const int off_O = q_count * local_head_num_ * size_per_head_ * engine_param_.attn_cp_rank; params.cp_M = cp_M_.data() + off_ML; From e4050a4de7b664495a5530ec64da3ef797d56a0d Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 25 Sep 2025 11:51:39 +0000 Subject: [PATCH 12/31] use single buffer --- .../kernels/attention/attention_params.h | 3 +- .../kernels/attention/attention_template.h | 3 +- .../kernels/attention/attention_universal.h | 9 +++--- .../kernels/attention/decoding_template.h | 3 +- src/turbomind/kernels/attention/reduce.cu | 9 ++---- src/turbomind/kernels/attention/reduce.h | 3 +- .../kernels/attention/reduce_kernel.h | 19 ++++++------- src/turbomind/models/llama/LlamaBatch.cc | 12 +++----- src/turbomind/models/llama/LlamaBatch.h | 3 +- src/turbomind/models/llama/LlamaV2.cc | 6 ++-- src/turbomind/models/llama/LlamaV2.h | 3 +- src/turbomind/models/llama/cp_utils.cu | 28 ++++++------------- src/turbomind/models/llama/cp_utils.h | 3 +- .../models/llama/unified_attention_layer.cc | 24 +++++----------- .../models/llama/unified_attention_layer.h | 3 +- 15 files changed, 45 insertions(+), 86 deletions(-) diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h index 5d15a0eec3..d3a0b30b27 100644 --- a/src/turbomind/kernels/attention/attention_params.h +++ b/src/turbomind/kernels/attention/attention_params.h @@ -84,8 +84,7 @@ struct AttentionParams { int cp_rank{0}; int cp_size{1}; cutlass::FastDivmod cp_divmod{1}; - float* cp_M{nullptr}; - float* cp_L{nullptr}; + float* cp_ML{nullptr}; int arch; cudaStream_t stream; diff --git a/src/turbomind/kernels/attention/attention_template.h b/src/turbomind/kernels/attention/attention_template.h index 
b7257fa334..c4474a7ad3 100644 --- a/src/turbomind/kernels/attention/attention_template.h +++ b/src/turbomind/kernels/attention/attention_template.h @@ -85,8 +85,7 @@ void invokeAttention(const typename Kernel::ParamType& params) params.partial_M, params.partial_L, params.partial_O, - params.cp_M, - params.cp_L, + params.cp_ML, params.split_cnt, params.max_split_k, split_cnt, diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index a841d1cb1c..ce61410a0c 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -555,8 +555,7 @@ struct AttentionUniversal { params.partial_M, params.partial_L, params.partial_O, - params.cp_M, - params.cp_L, + params.cp_ML, qi_begin, head_idx, params.num_heads, @@ -646,9 +645,9 @@ struct AttentionUniversal { params.partial_L[index] = L; } if (params.cp_size > 1 && split_idx == 0) { - const int index = (qi_begin + qi) * params.num_heads + (head_idx + hi); - params.cp_M[index] = M; - params.cp_L[index] = L; + const int index = ((qi_begin + qi) * params.num_heads + (head_idx + hi)) * 2; + params.cp_ML[index] = M; + params.cp_ML[index + 1] = L; } } }); diff --git a/src/turbomind/kernels/attention/decoding_template.h b/src/turbomind/kernels/attention/decoding_template.h index d22217dc6c..44cf31f22d 100644 --- a/src/turbomind/kernels/attention/decoding_template.h +++ b/src/turbomind/kernels/attention/decoding_template.h @@ -84,8 +84,7 @@ bool invokeDecoding(const typename Kernel::ParamType& params) params.partial_M, params.partial_L, params.partial_O, - params.cp_M, - params.cp_L, + params.cp_ML, params.split_cnt, params.max_split_k, split_cnt, diff --git a/src/turbomind/kernels/attention/reduce.cu b/src/turbomind/kernels/attention/reduce.cu index 493c113b4c..44f5a7aa8e 100644 --- a/src/turbomind/kernels/attention/reduce.cu +++ b/src/turbomind/kernels/attention/reduce.cu @@ -12,8 +12,7 @@ void invokeReduce(T* out, float* partial_M, float* partial_L, float* partial_O, - float* cp_M, - float* cp_L, + float* cp_ML, const int* split_cnt, int partial_len, int max_split_cnt, @@ -36,8 +35,7 @@ void invokeReduce(T* out, partial_M, partial_L, partial_O, - cp_M, - cp_L, + cp_ML, nullptr, split_cnt, partial_len, @@ -62,8 +60,7 @@ void invokeReduce(T* out, float* partial_M, \ float* partial_L, \ float* partial_O, \ - float* cp_M, \ - float* cp_L, \ + float* cp_ML, \ const int* split_cnt, \ int partial_len, \ int max_split_cnt, \ diff --git a/src/turbomind/kernels/attention/reduce.h b/src/turbomind/kernels/attention/reduce.h index d1f06a075c..f8e3ac9318 100644 --- a/src/turbomind/kernels/attention/reduce.h +++ b/src/turbomind/kernels/attention/reduce.h @@ -16,8 +16,7 @@ void invokeReduce(T* out, float* partial_M, float* partial_L, float* partial_O, - float* cp_M, - float* cp_L, + float* cp_ML, const int* split_cnt, int partial_len, int max_split_cnt, diff --git a/src/turbomind/kernels/attention/reduce_kernel.h b/src/turbomind/kernels/attention/reduce_kernel.h index 48c89e940b..68871aa065 100644 --- a/src/turbomind/kernels/attention/reduce_kernel.h +++ b/src/turbomind/kernels/attention/reduce_kernel.h @@ -27,8 +27,7 @@ struct Reduce { float* partial_M, float* partial_L, float* partial_O, - float* cp_M, - float* cp_L, + float* cp_ML, int query_idx, int head_idx, int head_num, @@ -104,7 +103,7 @@ struct Reduce { Array scale; PRAGMA_UNROLL for (int k = 0; k < K; ++k) { - scale[k] = (IsFinal && cp_M == nullptr) ? 
expdiff_M[k] / block_L : expdiff_M[k]; + scale[k] = (IsFinal && cp_ML == nullptr) ? expdiff_M[k] / block_L : expdiff_M[k]; } if (hi < CTA_H) { @@ -127,10 +126,10 @@ struct Reduce { } } else { - if (cp_M != nullptr && cp_L != nullptr && lane_id % L == 0 && hi < hi_end) { - const int idx = query_idx * head_num + head_idx + hi; - cp_M[idx] = block_M; - cp_L[idx] = block_L; + if (cp_ML != nullptr && lane_id % L == 0 && hi < hi_end) { + const int idx = (query_idx * head_num + head_idx + hi) * 2; + cp_ML[idx] = block_M; + cp_ML[idx + 1] = block_L; } } } @@ -221,8 +220,7 @@ __global__ void reduce_kernel(typename Reduce::T* out, float* partial_M, float* partial_L, float* partial_O, - float* cp_M, - float* cp_L, + float* cp_ML, int* signals, const int* split_cnt_, int max_split_cnt, @@ -249,8 +247,7 @@ __global__ void reduce_kernel(typename Reduce::T* out, partial_M, partial_L, partial_O, - cp_M, - cp_L, + cp_ML, query_idx, head_idx, head_num, diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index d278605450..6a3d7c45d9 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -829,8 +829,7 @@ void LlamaBatch::AllocSymmBuffers() symm_logits_buf_ = {{max_batch_size_, vocab_size_padded}, data_type_, symm_alloc_}; if (param_.attn_cp_size > 1) { - symm_cp_M_ = {{param_.attn_cp_size, max_forward_token_num_, (int)model_->local_head_num_}, symm_alloc_}; - symm_cp_L_ = {{param_.attn_cp_size, max_forward_token_num_, (int)model_->local_head_num_}, symm_alloc_}; + symm_cp_ML_ = {{param_.attn_cp_size, max_forward_token_num_, (int)model_->local_head_num_, 2}, symm_alloc_}; } } @@ -839,8 +838,7 @@ void LlamaBatch::FreeSymmBuffers() symm_hidden_states_buf_ = {}; symm_logits_buf_ = {}; - symm_cp_M_ = {}; - symm_cp_L_ = {}; + symm_cp_ML_ = {}; } LlamaBatch::~LlamaBatch() @@ -1581,8 +1579,7 @@ bool LlamaBatch::Forward(GenerationState& g) state_->h_context_length.slice(first, mini_batch_size), rope_theta_.slice(first, mini_batch_size), &mrope, - symm_cp_M_, - symm_cp_L_, + symm_cp_ML_, finished_buf_.slice(first, mini_batch_size), Buffer(local_token_nums.data(), local_token_nums.size(), kCPU), lora_mask_buf_, @@ -1775,8 +1772,7 @@ void LlamaBatch::Warmup() Buffer{&input_length, 1, kCPU}, rope_theta_.slice(0, bsz), nullptr, // mrope - symm_cp_M_, - symm_cp_L_, + symm_cp_ML_, finished_buf_.slice(0, bsz), Buffer{local_token_nums.data(), (int)local_token_nums.size(), kCPU}, Buffer{}, diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 03c86b7f3c..23bfdfc7a9 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -246,8 +246,7 @@ class LlamaBatch { Tensor symm_logits_buf_; // context parallel - Tensor_ symm_cp_M_; - Tensor_ symm_cp_L_; + Tensor_ symm_cp_ML_; Tensor decoder_output_buf_; diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 671f0d9549..549e981a6a 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -163,8 +163,7 @@ void LlamaV2::Forward(Buffer_ input_ids, Buffer_ h_context_length, Buffer rope_base, MropeRope* mrope, - Tensor cp_M, - Tensor cp_L, + Tensor cp_ML, Buffer finished, Buffer local_token_nums, Buffer lora_mask, @@ -260,8 +259,7 @@ void LlamaV2::Forward(Buffer_ input_ids, {"decode_num", Buffer{&decode_num, 1, kCPU}}, {"prefil_num", Buffer{&prefil_num, 1, kCPU}}, {"rope_base", rope_base}, - {"cp_M", cp_M}, - {"cp_L", cp_L}, + 
{"cp_ML", cp_ML}, {"cu_block_nums", cu_block_nums}, {"kv_block_ptrs", kv_block_ptrs}, {"local_token_nums", local_token_nums}}; diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index bac1503db7..9283d5967b 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -69,8 +69,7 @@ class LlamaV2 { Buffer_ h_context_length, Buffer rope_base, MropeRope* mrope, - Tensor cp_M, - Tensor cp_L, + Tensor cp_ML, Buffer finished, Buffer local_token_nums, Buffer lora_mask, diff --git a/src/turbomind/models/llama/cp_utils.cu b/src/turbomind/models/llama/cp_utils.cu index dabb70c291..9aaaad8e12 100644 --- a/src/turbomind/models/llama/cp_utils.cu +++ b/src/turbomind/models/llama/cp_utils.cu @@ -5,15 +5,8 @@ namespace turbomind { template -__global__ void CpReduce(T* out, - float* M, - float* L, - int token_num, - int head_num, - int size_per_head, - int cp_size, - int cp_rank, - float exp_scale) +__global__ void +CpReduce(T* out, float* ML, int token_num, int head_num, int size_per_head, int cp_size, int cp_rank, float exp_scale) { __shared__ float scale[WARP_SIZE]; float frag_M = -std::numeric_limits::infinity(); @@ -25,9 +18,9 @@ __global__ void CpReduce(T* out, const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; if (warp_id == 0 && lane_id < cp_size) { - const int index = lane_id * token_num * head_num + token_idx * head_num + head_idx; - frag_M = M[index]; - frag_L = L[index]; + const int index = (lane_id * token_num * head_num + token_idx * head_num + head_idx) * 2; + frag_M = ML[index]; + frag_L = ML[index + 1]; } float block_M = frag_M; @@ -58,8 +51,7 @@ __global__ void CpReduce(T* out, template void invokeCpReduce(T* out, - float* M, - float* L, + float* ML, int token_num, int head_num, int size_per_head, @@ -73,13 +65,12 @@ void invokeCpReduce(T* out, const dim3 grid(token_num, head_num); size_t smem_size = sizeof(float) * WARP_SIZE; CpReduce<<>>( - out, M, L, token_num, head_num, size_per_head, cp_size, cp_rank, exp_scale); + out, ML, token_num, head_num, size_per_head, cp_size, cp_rank, exp_scale); sync_check_cuda_error(); } template void invokeCpReduce(half* out, - float* M, - float* L, + float* ML, int token_num, int head_num, int size_per_head, @@ -89,8 +80,7 @@ template void invokeCpReduce(half* out, cudaStream_t stream); #ifdef ENABLE_BF16 template void invokeCpReduce(__nv_bfloat16* out, - float* M, - float* L, + float* ML, int token_num, int head_num, int size_per_head, diff --git a/src/turbomind/models/llama/cp_utils.h b/src/turbomind/models/llama/cp_utils.h index 6ac06e0ffd..b7992e9f1f 100644 --- a/src/turbomind/models/llama/cp_utils.h +++ b/src/turbomind/models/llama/cp_utils.h @@ -6,8 +6,7 @@ namespace turbomind { template void invokeCpReduce(T* out, - float* M, - float* L, + float* ML, int token_num, int head_num, int size_per_head, diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index ff7ec245fb..4e8877b398 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -141,8 +141,7 @@ void UnifiedAttentionLayer::Initialize(TensorMap& args) kv_block_ptrs_ = args.at("kv_block_ptrs").buffer(); if (engine_param_.attn_cp_size > 1) { - cp_M_ = args.at("cp_M").borrow(); - cp_L_ = args.at("cp_L").borrow(); + cp_ML_ = args.at("cp_ML").borrow(); } // rotary embedding, add offest when forward @@ -240,15 +239,9 @@ void 
UnifiedAttentionLayer::cp_postprocess(Tensor& attn) { NvtxScope scope("cp"); const int token_num = attn.shape(0); - const int count = token_num * local_head_num_; - d_comm_->AllGatherCP(cp_M_.data() + count * engine_param_.attn_cp_rank, - cp_M_.data(), - cp_L_.data() + count * engine_param_.attn_cp_rank, - cp_L_.data(), - count, - kFloat32, - attn_cp_group_, - stream_); + const int count = token_num * local_head_num_ * 2; + d_comm_->AllGather( + cp_ML_.data() + count * engine_param_.attn_cp_rank, cp_ML_.data(), count, kFloat32, attn_cp_group_, stream_); sync_check_cuda_error(); float inv_sqrt_dh = (float)std::log2(expf(1.)); @@ -260,8 +253,7 @@ void UnifiedAttentionLayer::cp_postprocess(Tensor& attn) } invokeCpReduce(attn.data(), - cp_M_.data(), - cp_L_.data(), + cp_ML_.data(), token_num, local_head_num_, size_per_head_, @@ -378,10 +370,8 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, params.cp_size = engine_param_.attn_cp_size; if (params.cp_size > 1) { params.cp_divmod = cutlass::FastDivmod(params.cp_size); - const int off_ML = q_count * local_head_num_ * engine_param_.attn_cp_rank; - const int off_O = q_count * local_head_num_ * size_per_head_ * engine_param_.attn_cp_rank; - params.cp_M = cp_M_.data() + off_ML; - params.cp_L = cp_L_.data() + off_ML; + const int off_ML = engine_param_.attn_cp_rank * q_count * local_head_num_ * 2; + params.cp_ML = cp_ML_.data() + off_ML; } params.arch = arch_; diff --git a/src/turbomind/models/llama/unified_attention_layer.h b/src/turbomind/models/llama/unified_attention_layer.h index ec4c4e8ff1..84164a6f0f 100644 --- a/src/turbomind/models/llama/unified_attention_layer.h +++ b/src/turbomind/models/llama/unified_attention_layer.h @@ -124,8 +124,7 @@ class UnifiedAttentionLayer { Tensor_ barriers_; // always zero // context parallel - Tensor_ cp_M_; - Tensor_ cp_L_; + Tensor_ cp_ML_; Event event_; From f44ef96eb2d33632798d80a4999051bede4b2e8d Mon Sep 17 00:00:00 2001 From: irexyc Date: Fri, 26 Sep 2025 04:04:58 +0000 Subject: [PATCH 13/31] udpate comm --- src/turbomind/models/llama/context.h | 2 +- src/turbomind/models/llama/unified_decoder.cc | 8 ++++---- src/turbomind/models/llama/unified_decoder.h | 2 +- .../triton_backend/llama/LlamaTritonModel.cc | 17 ++++++++--------- 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/src/turbomind/models/llama/context.h b/src/turbomind/models/llama/context.h index d5e7891077..e1d7cee7a9 100644 --- a/src/turbomind/models/llama/context.h +++ b/src/turbomind/models/llama/context.h @@ -18,10 +18,10 @@ namespace turbomind { struct Communicators { comm::HostComm h_comm; comm::HostComm h_tp_cp_group; - comm::HostComm h_tp_group; comm::HostComm h_dp_group; comm::DeviceComm d_comm; + int d_tp_cp_group; int d_tp_group; int d_cp_group; }; diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index 6405a3e225..daf8c905eb 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -29,7 +29,7 @@ UnifiedDecoder::UnifiedDecoder(const ModelParam& model, attn_dp_size_(engine.attn_dp_size), attn_dp_rank_(engine.attn_dp_rank), mlp_tp_size_(engine.mlp_tp_size), - attn_tp_group_(ctx.comm.d_tp_group), + attn_tp_cp_group_(ctx.comm.d_tp_cp_group), rmsnorm_eps_(model.norm_eps), stream_(ctx.stream), d_comm_(ctx.comm.d_comm), @@ -58,7 +58,7 @@ void UnifiedDecoder::AllreduceResidualRMSnorm(Tensor& hidden_states, { const auto dtype = hidden_states.dtype(); if (0) {} - else if 
(engine_param_.attn_dp_size > 1 && engine_param_.attn_cp_size == 1) { + else if (group0 || group1) { d_comm_->AllreduceResidualBiasRMSnormEx(hidden_states.raw_data(), residual.data_or((void*)nullptr), bias.data_or((void*)nullptr), @@ -175,7 +175,7 @@ void UnifiedDecoder::Forward(TensorMap& args, const std::vector& we weights.at(layer)->self_attn_weights->output.bias, weights.at(layer)->ffn_norm, local_token_num, - attn_tp_group_, + attn_tp_cp_group_, 0, local_token_nums.data()); @@ -217,7 +217,7 @@ void UnifiedDecoder::Forward(TensorMap& args, const std::vector& we scale_weight, local_token_num, 0, - attn_tp_group_, + attn_tp_cp_group_, local_token_nums.data()); sync_check_cuda_error(); diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h index 2d001c9bc3..775ed01dbf 100644 --- a/src/turbomind/models/llama/unified_decoder.h +++ b/src/turbomind/models/llama/unified_decoder.h @@ -33,7 +33,7 @@ class UnifiedDecoder { const int attn_dp_rank_; const int mlp_tp_size_; - const int attn_tp_group_; + const int attn_tp_cp_group_; const EngineParam engine_param_; diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 15aa9b133c..4d519c88ea 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -512,25 +512,24 @@ Communicators LlamaTritonModel::createCommSplits(int rank) const int tp_cp_size = engine_param_.attn_tp_size * engine_param_.attn_cp_size; const int color_tp_cp = inner_rank / tp_cp_size; - const int color_cp = - inner_rank / engine_param_.attn_cp_size + (inner_rank / tp_cp_size) * engine_param_.attn_tp_size; - const int color_tp = - inner_rank % engine_param_.attn_cp_size + (inner_rank / tp_cp_size) * engine_param_.attn_cp_size; + const int color_tp = inner_rank % tp_cp_size % engine_param_.attn_cp_size; + const int color_cp = inner_rank % tp_cp_size / engine_param_.attn_cp_size; comm.h_comm = group_ids_[outer_rank]->CreateCommunicator(comm_size_, inner_rank); comm.h_tp_cp_group = comm.h_comm->Split(color_tp_cp, 0); - comm.h_tp_group = comm.h_comm->Split(color_tp, 0); comm.h_dp_group = comm.h_comm->Split(inner_rank % tp_cp_size, 0); if (comm_size_ > 1) { comm.d_comm = CreateDeviceCommunicator(communicator_, comm_size_, inner_rank, comm.h_comm); // - comm.d_tp_group = 0; - comm.d_cp_group = 0; + comm.d_tp_cp_group = 0; + comm.d_tp_group = 0; + comm.d_cp_group = 0; if (engine_param_.attn_tp_size != comm_size_) { - comm.d_cp_group = comm.d_comm->Split(color_cp, 0, 0); - comm.d_tp_group = comm.d_comm->Split(color_tp, 0, 0); + comm.d_tp_cp_group = comm.d_comm->Split(color_tp_cp, 0, 0); + comm.d_tp_group = comm.d_comm->Split(color_tp, 0, comm.d_tp_cp_group); + comm.d_cp_group = comm.d_comm->Split(color_cp, 0, comm.d_tp_cp_group); } } From a329b29f268825aebd8061bbbf42b649e69e85bb Mon Sep 17 00:00:00 2001 From: irexyc Date: Fri, 24 Oct 2025 02:29:17 +0000 Subject: [PATCH 14/31] use two stage reduce --- .../kernels/attention/attention_params.h | 8 +- .../kernels/attention/attention_template.h | 10 +- .../kernels/attention/attention_universal.h | 28 +- .../kernels/attention/decoding_template.h | 10 +- src/turbomind/kernels/attention/reduce.cu | 6 +- src/turbomind/kernels/attention/reduce.h | 1 - .../kernels/attention/reduce_kernel.h | 16 +- src/turbomind/models/llama/cp_utils.cu | 507 ++++++++++++++++-- src/turbomind/models/llama/cp_utils.h | 25 +- .../models/llama/unified_attention_layer.cc | 56 +- 
.../models/llama/unified_attention_layer.h | 8 +- 11 files changed, 549 insertions(+), 126 deletions(-) diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h index d3a0b30b27..151a3ed9eb 100644 --- a/src/turbomind/kernels/attention/attention_params.h +++ b/src/turbomind/kernels/attention/attention_params.h @@ -24,6 +24,8 @@ struct BlockIteratorParams { int block_len; }; +typedef void (*cp_post_fn)(void* context, int split_cnt); + /// TODO: Rename to attention::Param template struct AttentionParams { @@ -84,7 +86,11 @@ struct AttentionParams { int cp_rank{0}; int cp_size{1}; cutlass::FastDivmod cp_divmod{1}; - float* cp_ML{nullptr}; + int cp_q_offset{0}; // decode offset + float* cp_ML{nullptr}; // cp, q, h, 2 + float* cp_k_ML{nullptr}; // q, h, k, 2 + cp_post_fn cp_fn{nullptr}; + void* cp_fn_ctx{nullptr}; int arch; cudaStream_t stream; diff --git a/src/turbomind/kernels/attention/attention_template.h b/src/turbomind/kernels/attention/attention_template.h index c4474a7ad3..e75568c936 100644 --- a/src/turbomind/kernels/attention/attention_template.h +++ b/src/turbomind/kernels/attention/attention_template.h @@ -45,7 +45,8 @@ void invokeAttention(const typename Kernel::ParamType& params) return int2{sm_count, max_active_ctas}; }(); - const int tile_count = cdiv(std::min(params.max_k_len, params.window_size), Kernel::CTA_S); + const int max_cp_k_len = (params.max_k_len + params.cp_size - 1) / params.cp_size; + const int tile_count = cdiv(std::min(max_cp_k_len, params.window_size), Kernel::CTA_S); const int max_split_count = std::min(params.max_split_k, tile_count); typename Kernel::CtaMap cta_map{ @@ -80,12 +81,15 @@ void invokeAttention(const typename Kernel::ParamType& params) std::abort(); } - if (split_cnt > 1 && Kernel::need_separate_reduce(split_cnt)) { + if (params.cp_fn) { + int split_k = Kernel::need_separate_reduce(split_cnt) ? 
split_cnt : 1; + params.cp_fn(params.cp_fn_ctx, split_k); + } + else if (split_cnt > 1 && Kernel::need_separate_reduce(split_cnt)) { attention::invokeReduce(params.out, params.partial_M, params.partial_L, params.partial_O, - params.cp_ML, params.split_cnt, params.max_split_k, split_cnt, diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index ce61410a0c..b7170591b6 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -523,7 +523,7 @@ struct AttentionUniversal { StoreO(frag_O, frag_L, qi_begin, qi_end, head_idx, params, storage); } else { - StorePartial(frag_O, frag_M, frag_L, qi_begin, qi_end, head_idx, split_idx, params, storage); + StorePartial(frag_O, frag_M, frag_L, split_cnt, qi_begin, qi_end, head_idx, split_idx, params, storage); if (!separate_reduce && split_cnt > 1) { Reduce(qi_begin, head_idx, split_idx, iter_end == tile_count, params, cta_map, smem_buf); } @@ -556,6 +556,8 @@ struct AttentionUniversal { params.partial_L, params.partial_O, params.cp_ML, + params.cp_k_ML, + params.cp_q_offset, qi_begin, head_idx, params.num_heads, @@ -612,6 +614,7 @@ struct AttentionUniversal { __device__ void StorePartial(FragO& frag_O, FragM& frag_M, FragL& frag_L, + int split_cnt, int qi_begin, int qi_end, int head_idx, @@ -627,10 +630,10 @@ struct AttentionUniversal { Impl::StoreO(frag_O, frag_L, storage, [&](int hi, int qi, int di, const auto& vec) { if (qi_begin + qi < qi_end && check_h(hi)) { - if (params.max_split_k > 1) { // decode + if (split_cnt > 1) { // decode Store(¶ms.partial_O[get_index(hi, qi) * kHeadDim + di], vec); } - if (params.cp_size > 1 && split_idx == 0) { + if (params.cp_size > 1 && split_cnt == 1) { const int index = ((qi_begin + qi) * params.num_heads + (head_idx + hi)) * kHeadDim + di; Store(¶ms.out[index], cast(vec)); } @@ -640,14 +643,23 @@ struct AttentionUniversal { Impl::ForeachML(frag_M, frag_L, [&](int hi, int qi, int ri, float M, float L) { const int index = get_index(hi, qi); if (qi_begin + qi < qi_end && ri == 0 && check_h(hi)) { - if (params.max_split_k > 1) { // decode + if (split_cnt > 1) { // decode params.partial_M[index] = M; params.partial_L[index] = L; } - if (params.cp_size > 1 && split_idx == 0) { - const int index = ((qi_begin + qi) * params.num_heads + (head_idx + hi)) * 2; - params.cp_ML[index] = M; - params.cp_ML[index + 1] = L; + + auto save_cp_stats = [&](int max_split_k, int split_idx, float* ml, float M, float L) { + const int q = qi_begin + qi - params.cp_q_offset; + const int index = (q * params.num_heads + (head_idx + hi)) * max_split_k + split_idx; + ml[index * 2] = M; + ml[index * 2 + 1] = L; + }; + + if (params.cp_size > 1) { + if (split_cnt == 1) { + save_cp_stats(1, 0, params.cp_ML, M, L); + } + save_cp_stats(params.max_split_k, split_idx, params.cp_k_ML, M, L); } } }); diff --git a/src/turbomind/kernels/attention/decoding_template.h b/src/turbomind/kernels/attention/decoding_template.h index 44cf31f22d..25e72605d4 100644 --- a/src/turbomind/kernels/attention/decoding_template.h +++ b/src/turbomind/kernels/attention/decoding_template.h @@ -25,7 +25,8 @@ bool invokeDecoding(const typename Kernel::ParamType& params) }(); } - const int tile_count = cdiv(std::min(params.max_k_len, params.window_size), Kernel::CTA_S); + const int max_cp_k_len = (params.max_k_len + params.cp_size - 1) / params.cp_size; + const int tile_count = cdiv(std::min(max_cp_k_len, params.window_size), Kernel::CTA_S); 
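Both the prefill (`invokeAttention`) and decoding launchers now derive the tile count from the per-rank key length instead of the full context length. Because the kv_cache is stored interleaved (token `i` lives on rank `i % cp_size`), each rank holds at most `ceil(max_k_len / cp_size)` keys. A minimal sketch of that arithmetic, assuming the interleaved layout this series introduces; the `cdiv` helper below is local to the sketch, not the kernel's own:

```cpp
// Editor's sketch: per-rank key count under interleaved context parallelism.
// Rank r holds tokens {r, r + cp_size, r + 2 * cp_size, ...}, i.e.
// ceil((max_k_len - r) / cp_size) keys, bounded above by ceil(max_k_len / cp_size).
#include <cassert>

constexpr int cdiv(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int max_k_len = 9, cp_size = 2;       // matches the docs example later in this series
    assert(cdiv(max_k_len, cp_size) == 5);      // max_cp_k_len; rank 0 holds tokens 0,2,4,6,8
    assert(cdiv(max_k_len - 1, cp_size) == 4);  // rank 1 holds tokens 1,3,5,7
    return 0;
}
```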
const int max_split_count = std::min(params.max_split_k, tile_count); using CtaMap = typename Kernel::CtaMap; @@ -79,12 +80,15 @@ bool invokeDecoding(const typename Kernel::ParamType& params) std::abort(); } - if (Kernel::need_separate_reduce(split_cnt)) { + if (params.cp_fn) { + int split_k = Kernel::need_separate_reduce(split_cnt) ? split_cnt : 1; + params.cp_fn(params.cp_fn_ctx, split_k); + } + else if (Kernel::need_separate_reduce(split_cnt)) { attention::invokeReduce(params.out, params.partial_M, params.partial_L, params.partial_O, - params.cp_ML, params.split_cnt, params.max_split_k, split_cnt, diff --git a/src/turbomind/kernels/attention/reduce.cu b/src/turbomind/kernels/attention/reduce.cu index 44f5a7aa8e..23f7547372 100644 --- a/src/turbomind/kernels/attention/reduce.cu +++ b/src/turbomind/kernels/attention/reduce.cu @@ -12,7 +12,6 @@ void invokeReduce(T* out, float* partial_M, float* partial_L, float* partial_O, - float* cp_ML, const int* split_cnt, int partial_len, int max_split_cnt, @@ -35,7 +34,9 @@ void invokeReduce(T* out, partial_M, partial_L, partial_O, - cp_ML, + nullptr, + nullptr, + 0, nullptr, split_cnt, partial_len, @@ -60,7 +61,6 @@ void invokeReduce(T* out, float* partial_M, \ float* partial_L, \ float* partial_O, \ - float* cp_ML, \ const int* split_cnt, \ int partial_len, \ int max_split_cnt, \ diff --git a/src/turbomind/kernels/attention/reduce.h b/src/turbomind/kernels/attention/reduce.h index f8e3ac9318..c078de5958 100644 --- a/src/turbomind/kernels/attention/reduce.h +++ b/src/turbomind/kernels/attention/reduce.h @@ -16,7 +16,6 @@ void invokeReduce(T* out, float* partial_M, float* partial_L, float* partial_O, - float* cp_ML, const int* split_cnt, int partial_len, int max_split_cnt, diff --git a/src/turbomind/kernels/attention/reduce_kernel.h b/src/turbomind/kernels/attention/reduce_kernel.h index 68871aa065..2986119e62 100644 --- a/src/turbomind/kernels/attention/reduce_kernel.h +++ b/src/turbomind/kernels/attention/reduce_kernel.h @@ -28,6 +28,8 @@ struct Reduce { float* partial_L, float* partial_O, float* cp_ML, + float* cp_k_ML, + int cp_q_offset, int query_idx, int head_idx, int head_num, @@ -127,9 +129,13 @@ struct Reduce { } else { if (cp_ML != nullptr && lane_id % L == 0 && hi < hi_end) { - const int idx = (query_idx * head_num + head_idx + hi) * 2; - cp_ML[idx] = block_M; - cp_ML[idx + 1] = block_L; + const int idx1 = ((query_idx - cp_q_offset) * head_num + head_idx + hi) * 2; + cp_ML[idx1] = block_M; + cp_ML[idx1 + 1] = block_L; + + const int idx2 = idx1 * max_split_cnt; + cp_k_ML[idx2] = block_M; + cp_k_ML[idx2 + 1] = block_L; } } } @@ -221,6 +227,8 @@ __global__ void reduce_kernel(typename Reduce::T* out, float* partial_L, float* partial_O, float* cp_ML, + float* cp_k_ML, + int cp_q_offset, int* signals, const int* split_cnt_, int max_split_cnt, @@ -248,6 +256,8 @@ __global__ void reduce_kernel(typename Reduce::T* out, partial_L, partial_O, cp_ML, + cp_k_ML, + cp_q_offset, query_idx, head_idx, head_num, diff --git a/src/turbomind/models/llama/cp_utils.cu b/src/turbomind/models/llama/cp_utils.cu index 9aaaad8e12..ae47e74e91 100644 --- a/src/turbomind/models/llama/cp_utils.cu +++ b/src/turbomind/models/llama/cp_utils.cu @@ -1,31 +1,180 @@ // Copyright (c) OpenMMLab. All rights reserved. 
+#include "src/turbomind/kernels/core/array.h" +#include "src/turbomind/kernels/core/array_ops.h" #include "src/turbomind/models/llama/cp_utils.h" +#include "src/turbomind/models/llama/llama_utils.h" namespace turbomind { +int next_power_of_two(int v) +{ + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} + +template +__global__ void ReduceK(float2* cp_ML, + float* partial_M, // q, h, k + float* partial_L, // q, h, k + int* split_cnt_, + int max_split_k, + int num_tokens, + int num_heads, + int stride_k, + int offset_k, + float exp_scale) +{ + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + + offset_k *= blockIdx.z; + + const int q = blockIdx.x * WarpCnt + warp_id; + const int h = blockIdx.y; + const int split_cnt = (q >= num_tokens) ? 0 : split_cnt_[q]; + if (offset_k >= split_cnt) { + return; + } + + float frag_M = -std::numeric_limits::infinity(); + float frag_L = 0.0f; + + const int ki = lane_id * stride_k + offset_k; + const bool mask = ki < split_cnt && h < num_heads; + const int index = (q * num_heads + h) * max_split_k + ki; + + if (mask) { + frag_M = partial_M[index]; + frag_L = partial_L[index]; + } + + float block_M = frag_M; + PRAGMA_UNROLL + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask)); + } + + float expdiff_M = exp2f((frag_M - block_M) * exp_scale); + float block_L = expdiff_M * frag_L; + + PRAGMA_UNROLL + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + block_L += __shfl_xor_sync(uint32_t(-1), block_L, mask); + } + + if (mask) { + partial_M[index] = block_M; + partial_L[index] = block_L; + + if (ki == 0 && gridDim.z == 1) { + cp_ML[q * num_heads + h] = {block_M, block_L}; + } + } +} + template -__global__ void -CpReduce(T* out, float* ML, int token_num, int head_num, int size_per_head, int cp_size, int cp_rank, float exp_scale) +void invokeReduceK(CpPostContext* ctx, AttentionParams* params, int split_cnt) { - __shared__ float scale[WARP_SIZE]; - float frag_M = -std::numeric_limits::infinity(); - float frag_L = 0.0f; + constexpr int MaxN = 32; + + int split_k = split_cnt; + int stride_k = 1; + int offset_k = 1; + + auto invoke = [&](auto n) { + constexpr int WarpCnt = 4; + const dim3 block(WarpCnt * WARP_SIZE); + const dim3 grid((params->token_num + WarpCnt - 1) / WarpCnt, params->num_heads, (split_k + n - 1) / n); + ReduceK<<stream>>>( // + (float2*)ctx->cp_ML + params->cp_rank * params->token_num * params->num_heads, + params->partial_M, + params->partial_L, + params->split_cnt, + params->max_split_k, + params->token_num, + params->num_heads, + stride_k, + offset_k * n, + params->inv_sqrt_dh); + sync_check_cuda_error(); + + stride_k *= n; + offset_k *= n; + split_k = (split_k + n - 1) / n; + }; + + auto dispatch_n = [&](int n) { + n = min(next_power_of_two(n), MaxN); + switch (n) { + case 2: + return invoke(std::integral_constant{}); + case 4: + return invoke(std::integral_constant{}); + case 8: + return invoke(std::integral_constant{}); + case 16: + return invoke(std::integral_constant{}); + case 32: + return invoke(std::integral_constant{}); + default: + TM_CHECK(0); + } + }; + + while (split_k > 1) { + dispatch_n(split_k); + } +} - const int token_idx = blockIdx.x; - const int head_idx = blockIdx.y; +template +__global__ void ReduceCP(float2* cp_ML, // cp, q, h, 2 + int cp_size, + int num_heads, + int total, + int stride, + int offset, + float exp_scale) +{ + __shared__ float2 
s_ML[WarpCnt][WARP_SIZE + 1]; const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; - if (warp_id == 0 && lane_id < cp_size) { - const int index = (lane_id * token_num * head_num + token_idx * head_num + head_idx) * 2; - frag_M = ML[index]; - frag_L = ML[index + 1]; + + offset *= blockIdx.y; + const int qh_offset = blockIdx.x * WARP_SIZE; + if (qh_offset >= total || offset >= cp_size) { + return; + } + + float2 ml = {-std::numeric_limits::infinity(), 0.f}; + + int qh = qh_offset + lane_id; + int ki = warp_id * stride + offset; + if (ki < cp_size && qh < total) { + ml = cp_ML[ki * total + qh]; } + s_ML[warp_id][lane_id] = ml; + + __syncthreads(); + + // Reduce + const int qh_i = lane_id / (WarpCnt * 2) * (WarpCnt * 2) + lane_id % (WarpCnt * 2) / WarpCnt + warp_id * 2; + const int wi = lane_id % WarpCnt; + + ml = s_ML[wi][qh_i]; + float frag_M = ml.x; + float frag_L = ml.y; float block_M = frag_M; PRAGMA_UNROLL - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + for (int mask = WarpCnt / 2; mask >= 1; mask /= 2) { block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask)); } @@ -33,61 +182,313 @@ CpReduce(T* out, float* ML, int token_num, int head_num, int size_per_head, int float block_L = frag_L * expdiff_M; PRAGMA_UNROLL - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + for (int mask = WarpCnt / 2; mask >= 1; mask /= 2) { block_L += __shfl_xor_sync(uint32_t(-1), block_L, mask); } - if (warp_id == 0 && lane_id < cp_size) { - scale[lane_id] = expdiff_M / block_L; + if (wi == 0 && (qh_offset + qh_i < total)) { + cp_ML[qh_offset + qh_i] = {block_M, block_L}; + } +} + +template +void invokeReduceCP(CpPostContext* ctx, AttentionParams* params) +{ + constexpr int MaxN = 8; + const int total = params->token_num * params->num_heads; + + int split_k = params->cp_size; + int stride_k = 1; + int offset_k = 1; + + auto invoke = [&](auto n) { + const dim3 block(n * WARP_SIZE); + const dim3 grid((total + WARP_SIZE - 1) / WARP_SIZE, (split_k + n - 1) / n); + const int shm_size = sizeof(float2) * n * (WARP_SIZE + 1); + ReduceCP<<stream>>>( // + (float2*)ctx->cp_ML, + params->cp_size, + params->num_heads, + total, + stride_k, + offset_k * n, + params->inv_sqrt_dh); + sync_check_cuda_error(); + + stride_k *= n; + offset_k *= n; + split_k = (split_k + n - 1) / n; + }; + + auto dispatch_n = [&](int n) { + n = min(next_power_of_two(n), MaxN); + switch (n) { + case 2: + return invoke(std::integral_constant{}); + case 4: + return invoke(std::integral_constant{}); + case 8: + return invoke(std::integral_constant{}); + default: + TM_CHECK(0); + } + }; + + while (split_k > 1) { + dispatch_n(split_k); + } +} + +template +__global__ void ReduceOutput(T* out, // + float* partial_O, + float* cp_k_ML, // q, h, k, 2 + float2* cp_ML, // q, h, 2 + cutlass::FastDivmod h_divmod, + int* split_cnt_, + int max_split_cnt, + int total, + int num_heads, + int stride_k, + int offset_k, + float exp_scale) +{ + __shared__ float s_out[WarpCnt][HeadDim]; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + + // warp_id, q, h + const int qh = blockIdx.x * M + warp_id % M; + int q, h; + h_divmod(q, h, qh); + + if (q * num_heads + h >= total) { + return; + } + + offset_k *= blockIdx.y; + const int split_cnt = (split_cnt_ != nullptr) ? 
split_cnt_[q] : 1; + if (offset_k >= split_cnt) { + return; + } + + float scale = 1.0f; + float2 global_ML; + + auto get_scale = [&](float2 ml, int ki) { + int index = (q * num_heads + h) * max_split_cnt * 2 + ki * 2; + return exp2f((cp_k_ML[index] - ml.x) * exp_scale) / ml.y; + }; + + if (stride_k == 1) { + global_ML = cp_ML[q * num_heads + h]; + } + + // HeadDim / WARP_SIZE + // 128 -> 4 + // 64, 192 -> 2 + constexpr int kVecSize = HeadDim % 128 == 0 ? 4 : 2; + constexpr int iterC = HeadDim / (WARP_SIZE * kVecSize); + + using namespace ops; + using VecF = Array; + using VecT = Array; + + // in most cases,no split_k + if constexpr (N == 1) { + VecT frag_O; + scale = get_scale(global_ML, 0); + + PRAGMA_UNROLL + for (int c = 0; c < iterC; ++c) { + Load(frag_O, &out[(q * num_heads + h) * HeadDim + lane_id * kVecSize + c * WARP_SIZE * kVecSize]); + frag_O = cast(cast(frag_O) * scale); + Store(&out[(q * num_heads + h) * HeadDim + lane_id * kVecSize + c * WARP_SIZE * kVecSize], frag_O); + } + + return; + } + + VecF accu_O[iterC]{}; + VecF frag_O[iterC]; + + PRAGMA_UNROLL + for (int k = 0; k < N; k += WarpCnt / M) { + const int ki = (warp_id / M + k) * stride_k + offset_k; + const int base = (((q * num_heads + h) * max_split_cnt + ki) * HeadDim); // q, h, k, d + + if (ki < split_cnt) { + if (stride_k == 1) { + scale = get_scale(global_ML, ki); + } + + PRAGMA_UNROLL + for (int c = 0; c < iterC; ++c) { + const int index = base + lane_id * kVecSize + c * WARP_SIZE * kVecSize; + Load(frag_O[c], &partial_O[index]); + accu_O[c] = accu_O[c] + frag_O[c] * scale; + } + } } + PRAGMA_UNROLL + for (int c = 0; c < iterC; ++c) { + Store(&s_out[warp_id][c * WARP_SIZE * kVecSize + lane_id * kVecSize], accu_O[c]); + } + + // PRAGMA_UNROLL + // for (int w = WarpCnt / 2 / M; w > 0; w /= 2) { + // const int ki = warp_id / M; + // __syncthreads(); + // if (ki < w) { + // PRAGMA_UNROLL + // for (int c = 0; c < iterC; ++c) { + // const int index = c * WARP_SIZE * kVecSize + lane_id * kVecSize; + // (VecF&)s_out[warp_id][index] = (VecF&)s_out[warp_id][index] + (VecF&)s_out[warp_id + w * M][index]; + // } + // } + // } + __syncthreads(); + if (warp_id / M == 0) { + PRAGMA_UNROLL + for (int k = 1; k < WarpCnt / M; ++k) { + for (int c = 0; c < iterC; ++c) { + const int index = c * WARP_SIZE * kVecSize + lane_id * kVecSize; + (VecF&)s_out[warp_id][index] = (VecF&)s_out[warp_id][index] + (VecF&)s_out[warp_id + k * M][index]; + } + } + } + + if (warp_id / M == 0) { + const int base = gridDim.y == 1 ? 
(q * num_heads + h) * HeadDim : + (((q * num_heads + h) * max_split_cnt + offset_k) * HeadDim); + PRAGMA_UNROLL + for (int c = 0; c < iterC; ++c) { + const int off = c * WARP_SIZE * kVecSize + lane_id * kVecSize; + if (gridDim.y == 1) { + Store(&out[base + off], cast((VecF&)s_out[warp_id][off])); + } + else { + Store(&partial_O[base + off], (VecF&)s_out[warp_id][off]); + } + } + } +} + +template +void invokeReduceOutput(CpPostContext* ctx, AttentionParams* params, int split_cnt) +{ + constexpr int MaxN = 32; + + int split_k = split_cnt; + int stride_k = 1; + int offset_k = 1; + + cutlass::FastDivmod h_divmod = cutlass::FastDivmod(params->num_heads); + + auto invoke = [&](auto n, auto head_dim) { + constexpr int WarpCnt = 4; + constexpr int M = (WarpCnt + n - 1) / n; // item per block, 1, 2, 4 + const int total = params->token_num * params->num_heads; - for (int i = threadIdx.x; i < size_per_head; i += blockDim.x) { - int index = token_idx * head_num * size_per_head + head_idx * size_per_head + i; - out[index] = (T)((float)out[index] * scale[cp_rank]); + const dim3 block(WarpCnt * WARP_SIZE); + const dim3 grid((total + M - 1) / M, (split_k + n - 1) / n); + const int shm_size = WarpCnt * sizeof(float) * head_dim; + ReduceOutput<<stream>>>( // + params->out + params->cp_q_offset * params->num_heads * params->size_per_head, + params->partial_O, + params->cp_k_ML, + (float2*)ctx->cp_ML, + h_divmod, + split_cnt > 1 ? params->split_cnt : nullptr, + params->max_split_k, + total, + params->num_heads, + stride_k, + offset_k * n, + params->inv_sqrt_dh); + + sync_check_cuda_error(); + + stride_k *= n; + offset_k *= n; + split_k = (split_k + n - 1) / n; + }; + + auto dispatch_n = [&](int split_k, auto head_dim) { + int n = min(next_power_of_two(split_k), MaxN); + + switch (n) { + case 1: + return invoke(std::integral_constant{}, head_dim); + case 2: + return invoke(std::integral_constant{}, head_dim); + case 4: + return invoke(std::integral_constant{}, head_dim); + case 8: + return invoke(std::integral_constant{}, head_dim); + case 16: + return invoke(std::integral_constant{}, head_dim); + case 32: + return invoke(std::integral_constant{}, head_dim); + default: + TM_CHECK(0); + } + }; + + auto dispatch_head_dim = [&](int split_k) { + switch (params->size_per_head) { + case 64: + return dispatch_n(split_k, std::integral_constant{}); + case 128: + return dispatch_n(split_k, std::integral_constant{}); + case 192: + return dispatch_n(split_k, std::integral_constant{}); + default: + TM_CHECK(0); + } + }; + + dispatch_head_dim(split_k); + while (split_k > 1) { + dispatch_head_dim(split_k); } } template -void invokeCpReduce(T* out, - float* ML, - int token_num, - int head_num, - int size_per_head, - int cp_size, - int cp_rank, - float exp_scale, - cudaStream_t stream) +void CpReduce(CpPostContext* ctx, AttentionParams* params, int split_cnt) { - TM_CHECK(cp_size <= WARP_SIZE); - const dim3 block = 4 * WARP_SIZE; - const dim3 grid(token_num, head_num); - size_t smem_size = sizeof(float) * WARP_SIZE; - CpReduce<<>>( - out, ML, token_num, head_num, size_per_head, cp_size, cp_rank, exp_scale); + NvtxScope scope("CpReduce"); + + if (split_cnt > 1) { + invokeReduceK(ctx, params, split_cnt); + } + + const int count = params->token_num * params->num_heads * 2; + ctx->d_comm->AllGather(ctx->cp_ML + params->cp_rank * count, // + ctx->cp_ML, + count, + DataType::kFloat, + ctx->attn_cp_group, + params->stream); sync_check_cuda_error(); + + invokeReduceCP(ctx, params); + invokeReduceOutput(ctx, params, split_cnt); } 
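For reference, the `CpReduce` pipeline above (`ReduceK` over split-k partials, an AllGather of the per-rank (M, L) statistics, `ReduceCP` across ranks, then `ReduceOutput`) performs the usual log-sum-exp merge of attention partials. Below is a minimal host-side sketch of the numerically equivalent merge, assuming each source (a cp rank and/or split-k chunk) contributes an un-normalized output `O` together with its running max `M` and running sum `L`, with `exp_scale` playing the same role as `params.inv_sqrt_dh` in the kernels; the names here are illustrative only and not part of the patch:

```cpp
// Editor's sketch (not part of the patch): reference merge of attention partials.
//   M   = max_s M_s
//   L   = sum_s exp2((M_s - M) * exp_scale) * L_s
//   out = sum_s exp2((M_s - M) * exp_scale) / L * O_s
#include <algorithm>
#include <cmath>
#include <vector>

struct Partial {
    float              M;  // running max of scaled logits (log2 domain)
    float              L;  // sum of exp2((logit - M) * exp_scale)
    std::vector<float> O;  // un-normalized output, size == head_dim
};

std::vector<float> merge_partials(const std::vector<Partial>& parts, float exp_scale)
{
    float M = -INFINITY;
    for (const auto& p : parts) {
        M = std::max(M, p.M);
    }
    float L = 0.f;
    for (const auto& p : parts) {
        L += std::exp2((p.M - M) * exp_scale) * p.L;
    }
    std::vector<float> out(parts.front().O.size(), 0.f);
    for (const auto& p : parts) {
        const float scale = std::exp2((p.M - M) * exp_scale) / L;  // per-source rescaling factor
        for (size_t d = 0; d < out.size(); ++d) {
            out[d] += scale * p.O[d];
        }
    }
    return out;
}
```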
-template void invokeCpReduce(half* out, - float* ML, - int token_num, - int head_num, - int size_per_head, - int cp_size, - int cp_rank, - float exp_scale, - cudaStream_t stream); -#ifdef ENABLE_BF16 -template void invokeCpReduce(__nv_bfloat16* out, - float* ML, - int token_num, - int head_num, - int size_per_head, - int cp_size, - int cp_rank, - float exp_scale, - cudaStream_t stream); -#endif +void CpPost(void* context, int split_cnt) +{ + auto ctx = reinterpret_cast(context); + + auto invoke = [&](auto t) { + using T = decltype(t); + CpReduce(ctx, static_cast*>(ctx->attn_param), split_cnt); + }; + + TM_DISPATCH_PRIMARY_DTYPES(ctx->attn_type, invoke); +} } // namespace turbomind diff --git a/src/turbomind/models/llama/cp_utils.h b/src/turbomind/models/llama/cp_utils.h index b7992e9f1f..f1389089d7 100644 --- a/src/turbomind/models/llama/cp_utils.h +++ b/src/turbomind/models/llama/cp_utils.h @@ -1,18 +1,23 @@ // Copyright (c) OpenMMLab. All rights reserved. +#include "src/turbomind/comm/device_comm.h" #include "src/turbomind/core/core.h" +#include "src/turbomind/kernels/attention/attention_params.h" namespace turbomind { -template -void invokeCpReduce(T* out, - float* ML, - int token_num, - int head_num, - int size_per_head, - int cp_size, - int cp_rank, - float exp_scale, - cudaStream_t stream); +struct CpPostContext { + + CpPostContext(comm::DeviceCommImpl* d_comm, int attn_cp_group): d_comm(d_comm), attn_cp_group(attn_cp_group) {} + + comm::DeviceCommImpl* d_comm; + int attn_cp_group; + + float* cp_ML; + void* attn_param; + DataType attn_type; +}; + +void CpPost(void* context, int split_cnt); } // namespace turbomind diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 4e8877b398..a1afbc9ebe 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -34,7 +34,6 @@ #include "src/turbomind/macro.h" -#include "src/turbomind/models/llama/cp_utils.h" #include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/models/llama/mla_utils.h" #include "src/turbomind/models/llama/unified_attention_layer.h" @@ -75,6 +74,7 @@ UnifiedAttentionLayer::UnifiedAttentionLayer(const ModelParam& model, model_param_(model), engine_param_(engine), attn_cp_group_(ctx.comm.d_cp_group), + cp_fn_ctx_(ctx.comm.d_comm, ctx.comm.d_cp_group), d_comm_(ctx.comm.d_comm), lora_param_(lora), context_(ctx), @@ -100,6 +100,11 @@ UnifiedAttentionLayer::UnifiedAttentionLayer(const ModelParam& model, split_cnt_ = Tensor_({kMaxWorkspaceTokens}, kDEVICE); barriers_ = Tensor_({kMaxWorkspaceTokens, local_head_num_}, kDEVICE); + if (engine_param_.attn_cp_size > 1) { + const int cp_workspace_tokens = kMaxWorkspaceTokens + engine_param_.max_forward_token_num; + cp_k_ML_ = Tensor_({cp_workspace_tokens, local_head_num_, 2}, kDEVICE); + } + Clear(split_cnt_.buffer()); Clear(barriers_.buffer()); @@ -234,36 +239,6 @@ void UnifiedAttentionLayer::Forward(ForwardParam p) sync_check_cuda_error(); } -template -void UnifiedAttentionLayer::cp_postprocess(Tensor& attn) -{ - NvtxScope scope("cp"); - const int token_num = attn.shape(0); - const int count = token_num * local_head_num_ * 2; - d_comm_->AllGather( - cp_ML_.data() + count * engine_param_.attn_cp_rank, cp_ML_.data(), count, kFloat32, attn_cp_group_, stream_); - sync_check_cuda_error(); - - float inv_sqrt_dh = (float)std::log2(expf(1.)); - if (param_.softmax_scale) { - inv_sqrt_dh *= param_.softmax_scale; - } - else { - 
inv_sqrt_dh /= std::sqrt((float)size_per_head_); - } - - invokeCpReduce(attn.data(), - cp_ML_.data(), - token_num, - local_head_num_, - size_per_head_, - engine_param_.attn_cp_size, - engine_param_.attn_cp_rank, - inv_sqrt_dh, - stream_); - sync_check_cuda_error(); -} - template Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, const WeightType& weights) { @@ -370,8 +345,19 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, params.cp_size = engine_param_.attn_cp_size; if (params.cp_size > 1) { params.cp_divmod = cutlass::FastDivmod(params.cp_size); - const int off_ML = engine_param_.attn_cp_rank * q_count * local_head_num_ * 2; - params.cp_ML = cp_ML_.data() + off_ML; + + const int offset_ML = engine_param_.attn_cp_size * offset * local_head_num_ * 2; + params.cp_ML = cp_ML_.data() + offset_ML + params.cp_rank * params.token_num * local_head_num_ * 2; + params.cp_k_ML = cp_k_ML_.data() + (offset ? kMaxWorkspaceTokens * local_head_num_ * 2 : 0); + params.cp_q_offset = offset; + + // postprocess func + params.cp_fn = CpPost; + params.cp_fn_ctx = (void*)&cp_fn_ctx_; + + cp_fn_ctx_.cp_ML = cp_ML_.data() + offset_ML; + cp_fn_ctx_.attn_param = (void*)¶ms; + cp_fn_ctx_.attn_type = attn.dtype(); } params.arch = arch_; @@ -422,10 +408,6 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, check_cuda_error(cudaStreamWaitEvent(stream_, aux_event_)); } - if ((decode_num_ || prefil_num_) && !isTuning() && engine_param_.attn_cp_size > 1) { - cp_postprocess(attn); - } - if (isTuning()) { rng_.set_stream(stream_); rng_.GenerateUniform(attn.data(), attn.size(), .02f, -.01f); diff --git a/src/turbomind/models/llama/unified_attention_layer.h b/src/turbomind/models/llama/unified_attention_layer.h index 84164a6f0f..5f805cbf13 100644 --- a/src/turbomind/models/llama/unified_attention_layer.h +++ b/src/turbomind/models/llama/unified_attention_layer.h @@ -30,6 +30,7 @@ #include "src/turbomind/models/llama/LlamaDenseWeight.h" #include "src/turbomind/models/llama/LlamaLinear.h" #include "src/turbomind/models/llama/context.h" +#include "src/turbomind/models/llama/cp_utils.h" #include "src/turbomind/models/llama/llama_params.h" #include "src/turbomind/utils/cuda_utils.h" @@ -76,9 +77,6 @@ class UnifiedAttentionLayer { template Tensor core_attention(Tensor& qkv, const ForwardParam& p, const WeightType& weights); - template - void cp_postprocess(Tensor& attn); - void qk_norm(Tensor& qkv, const WeightType& weights); private: @@ -124,7 +122,9 @@ class UnifiedAttentionLayer { Tensor_ barriers_; // always zero // context parallel - Tensor_ cp_ML_; + Tensor_ cp_ML_; // cp, (d+p), h, 2 + Tensor_ cp_k_ML_; // (d+p), h, k, 2 + CpPostContext cp_fn_ctx_; Event event_; From c9649c04d639d4e491079d62c5d832c5240a9270 Mon Sep 17 00:00:00 2001 From: irexyc Date: Fri, 24 Oct 2025 02:56:25 +0000 Subject: [PATCH 15/31] remove unused --- lmdeploy/turbomind/turbomind.py | 2 -- src/turbomind/comm/device_comm.h | 12 ------------ src/turbomind/comm/nccl/nccl.cu | 15 --------------- .../kernels/attention/attention_universal.h | 5 ----- .../triton_backend/llama/LlamaTritonModel.cc | 14 +++++--------- 5 files changed, 5 insertions(+), 43 deletions(-) diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index b13f8db3b8..ca9c36f0d5 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -236,8 +236,6 @@ def _get_params(device_id, que): tm_params[k] = [v] else: tm_params[k].append(v) - # for k, v in 
tm_params.items(): - # print(k, len(v)) logger.warning(f'get {len(tm_params)} model params') def _postprocess_config(self, tm_config: TurbomindModelConfig, engine_config: TurbomindEngineConfig): diff --git a/src/turbomind/comm/device_comm.h b/src/turbomind/comm/device_comm.h index 8a5960c3af..a6948762df 100644 --- a/src/turbomind/comm/device_comm.h +++ b/src/turbomind/comm/device_comm.h @@ -54,18 +54,6 @@ class DeviceCommImpl { int group, cudaStream_t stream) = 0; - virtual void AllGatherCP(const void* send_M, - void* recv_M, - const void* send_L, - void* recv_L, - size_t sendcount, - DataType type, - int group, - cudaStream_t stream) - { - throw std::runtime_error("not implemented"); - } - virtual void ReduceScatter(const void* sendbuff, // void* recvbuff, size_t recvcount, diff --git a/src/turbomind/comm/nccl/nccl.cu b/src/turbomind/comm/nccl/nccl.cu index 9fc7d694fa..557c63e9a9 100644 --- a/src/turbomind/comm/nccl/nccl.cu +++ b/src/turbomind/comm/nccl/nccl.cu @@ -261,21 +261,6 @@ public: NCCLCHECK(ncclGroupEnd()); } - void AllGatherCP(const void* send_M, - void* recv_M, - const void* send_L, - void* recv_L, - size_t sendcount, - DataType type, - int group, - cudaStream_t stream) - { - NCCLCHECK(ncclGroupStart()); - NCCLCHECK(ncclAllGather(send_M, recv_M, sendcount, to_nccl_dtype(type), groups_.at(group), stream)); - NCCLCHECK(ncclAllGather(send_L, recv_L, sendcount, to_nccl_dtype(type), groups_.at(group), stream)); - NCCLCHECK(ncclGroupEnd()); - } - void ReduceScatter( const void* sendbuff, void* recvbuff, size_t recvcount, DataType type, int group, cudaStream_t stream) override { diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index b7170591b6..f82f198ceb 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -430,7 +430,6 @@ struct AttentionUniversal { const int offset_K = (first_K_tile + iter_end - 1) * CTA_S; // This is for avoiding OOB access only - // const int max_K = min(context_len, (first_K_tile + iter_end) * CTA_S); const int max_K = min(get_cp_len(context_len), (first_K_tile + iter_end) * CTA_S); int tile_iter = iter_end - iter_begin; @@ -445,10 +444,6 @@ struct AttentionUniversal { int mask_iter_front = cdiv(max(0, offset_Q + CTA_Q - offset_K + tile_iter * CTA_S - params.window_size), CTA_S); if (params.cp_size > 1) { - // mask all iter for simplicity - // mask_iter_back = 1 << 30; - // mask_iter_front = 1 << 30; - // TODO: use accurate mask_iter mask_iter_back = cdiv(max(0, params.cp_size * (offset_K + CTA_S) - offset_Q + params.cp_rank), params.cp_size * CTA_S); mask_iter_front = cdiv(max(0, diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 4d519c88ea..d2c25ae9dc 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -453,14 +453,10 @@ LlamaTritonModel::LlamaTritonModel(std::string model_ for (int i = 0; i < device_num; ++i) { auto& e = engine_params_[i]; e.outer_dp_rank = i / comm_size_; - // e.attn_tp_rank = i % comm_size_ % e.attn_tp_size; - // e.attn_dp_rank = i % comm_size_ / e.attn_tp_size; - // e.mlp_tp_rank = i % comm_size_; - - e.attn_cp_rank = i % comm_size_ % e.attn_cp_size; - e.attn_tp_rank = i % tp_cp_size / e.attn_cp_size; - e.attn_dp_rank = i % comm_size_ / tp_cp_size; - e.mlp_tp_rank = i % comm_size_; + e.attn_cp_rank = i % comm_size_ % e.attn_cp_size; + 
e.attn_tp_rank = i % tp_cp_size / e.attn_cp_size; + e.attn_dp_rank = i % comm_size_ / tp_cp_size; + e.mlp_tp_rank = i % comm_size_; } TM_LOG_INFO("%s", toString().c_str()); @@ -592,7 +588,7 @@ void LlamaTritonModel::createEngine(int device_id, int rank) if (first_create) { try { - // engine.Warmup(); + engine.Warmup(); } catch (const std::exception& e) { TM_LOG_ERROR("[Engine][Warmup] %s", e.what()); From 52766d2e3d1ca25a07e186dbe3b8bdb5736b6d00 Mon Sep 17 00:00:00 2001 From: irexyc Date: Tue, 28 Oct 2025 10:10:40 +0000 Subject: [PATCH 16/31] better AllreduceResidualRMSnorm --- src/turbomind/triton_backend/llama/LlamaTritonModel.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index d2c25ae9dc..9c6977a5d1 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -523,9 +523,11 @@ Communicators LlamaTritonModel::createCommSplits(int rank) comm.d_tp_group = 0; comm.d_cp_group = 0; if (engine_param_.attn_tp_size != comm_size_) { - comm.d_tp_cp_group = comm.d_comm->Split(color_tp_cp, 0, 0); - comm.d_tp_group = comm.d_comm->Split(color_tp, 0, comm.d_tp_cp_group); - comm.d_cp_group = comm.d_comm->Split(color_cp, 0, comm.d_tp_cp_group); + if (tp_cp_size != comm_size_) { + comm.d_tp_cp_group = comm.d_comm->Split(color_tp_cp, 0, 0); + } + comm.d_tp_group = comm.d_comm->Split(color_tp, 0, comm.d_tp_cp_group); + comm.d_cp_group = comm.d_comm->Split(color_cp, 0, comm.d_tp_cp_group); } } From b783d5c29b0643aabde82aa0870a1db1012902e8 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 29 Oct 2025 05:53:12 +0000 Subject: [PATCH 17/31] fix max_session_len --- src/turbomind/models/llama/LlamaBatch.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 6a3d7c45d9..fe6e5d059b 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -1838,7 +1838,7 @@ void LlamaBatch::InitializeBufferAndKVCache() core::Context::alloc(kDEVICE), get_free_size}); - const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len; + const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len * param_.attn_cp_size; if (max_session_len < session_len_) { if (is_driver_) { TM_LOG_WARNING("No enough blocks for `session_len` (%d), `session_len` truncated to %d.", From 47a349b1844c9d87e02f3751a09a929c2bf9ffe8 Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 30 Oct 2025 18:40:25 +0800 Subject: [PATCH 18/31] update docs --- docs/en/advance/context_parallel.md | 24 ++++++++++++++++++++++++ docs/en/index.rst | 1 + docs/zh_cn/advance/context_parallel.md | 23 +++++++++++++++++++++++ docs/zh_cn/index.rst | 1 + 4 files changed, 49 insertions(+) create mode 100644 docs/en/advance/context_parallel.md create mode 100644 docs/zh_cn/advance/context_parallel.md diff --git a/docs/en/advance/context_parallel.md b/docs/en/advance/context_parallel.md new file mode 100644 index 0000000000..f890de3c64 --- /dev/null +++ b/docs/en/advance/context_parallel.md @@ -0,0 +1,24 @@ +# Context Parallel + +When the memory on a single GPU is insufficient to deploy a model, it is often deployed using tensor parallelism (TP), which generally requires `num_key_value_heads` to be divisible by `TP`. 
If you want to deploy with `TP > num_key_value_heads`, the kv-heads should be duplicated to meet the divisibility requirement. However, this has two disadvantages: + +1. The amount of available kv_cache is halved, which reducing the maximum supported session length. +2. The maximum inference batch size is reduced, leading to lower throughput. + +To address this issue, the TurboMind inference backend supports setting `attn_dp_size`, which avoids creating copies of kv-heads, but this introduces data imbalance. To eliminate data imbalance, TurboMind supports sequence parallelism, which allowing kv_cache to be stored interleaved on different cp_ranks. See the example below: + +``` +cp_rank=2, prompt_len=5, generation_len=4 +kv_cache stored on cp_rank0: 0, 2, 4, 6, 8 +kv_cache stored on cp_rank1: 1, 3, 5, 7 +``` + +## Usage + +Taking Intern-S1 / Qwen3-235B-A22B as an example, their `num_key_value_heads` is 4. If you want to deploy with `TP=8` and avoid duplication of kv_cache, you can deploy in the following way: + +``` +lmdeploy serve api_server internlm/Intern-S1 --tp 8 --attn-cp-size 2 + +lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --attn-cp-size 2 +``` diff --git a/docs/en/index.rst b/docs/en/index.rst index b64c230cb8..b28042a977 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -103,6 +103,7 @@ Documentation advance/pytorch_multinodes.md advance/pytorch_profiling.md advance/metrics.md + advance/context_parallel.md .. toctree:: :maxdepth: 1 diff --git a/docs/zh_cn/advance/context_parallel.md b/docs/zh_cn/advance/context_parallel.md new file mode 100644 index 0000000000..d5e6c5137b --- /dev/null +++ b/docs/zh_cn/advance/context_parallel.md @@ -0,0 +1,23 @@ +# 序列并行 + +在单卡显存不足以部署模型的时候,通常会以 `TP` 的方式进行部署,而这一般要求 `num_key_value_heads` 被 `TP` 整除。如果要以 `TP > num_key_value_heads` 的方式进行部署,需要创建 kv-heads 的副本,以满足整除需求。但是这样会有两个缺点: + +1. 可用的 kvcache 数量减半,进而减少请求最大推理长度 +2. 降低推理的最大 batch 数量,减少吞吐量。 + +为了解决这个问题,TurboMind 推理后端支持设置 `attn_dp_size`,避免了创建 kv-heads 的副本,但是这会引入数据的不均衡性。为了消除数据的不均衡,TurboMind 支持了序列并行,支持将 kv_cache 交错存储到不同的 cp_rank 上,例如 +``` +cp_rank=2, prompt_len=5, generation_len=4 +kv_cache stored on cp_rank0: 0, 2, 4, 6, 8 +kv_cache stored on cp_rank1: 1, 3, 5, 7 +``` + +## 使用说明 + +以 `Intern-S1` / `Qwen3-235B-A22B` 为例,他们的 `num_key_value_heads` 为 4,若要用 `TP=8` 的方式部署,并避免 kv_cache 的拷贝,可以用如下的方式部署 + +``` +lmdeploy serve api_server internlm/Intern-S1 --tp 8 --attn-cp-size 2 + +lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --attn-cp-size 2 +``` diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index bd946ba96e..733bfc585e 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -104,6 +104,7 @@ LMDeploy 工具箱提供以下核心功能: advance/pytorch_multinodes.md advance/pytorch_profiling.md advance/metrics.md + advance/context_parallel.md .. 
toctree:: :maxdepth: 1 From d83a2c70a6ffbc040864e9c1771716794db6e29c Mon Sep 17 00:00:00 2001 From: irexyc Date: Mon, 3 Nov 2025 13:26:35 +0000 Subject: [PATCH 19/31] fix embedding/lm_head split --- lmdeploy/turbomind/deploy/config.py | 1 + lmdeploy/turbomind/deploy/converter.py | 1 + lmdeploy/turbomind/deploy/module.py | 3 +- .../turbomind/deploy/target_model/base.py | 1 + src/turbomind/models/llama/LlamaBatch.cc | 14 +++++----- src/turbomind/models/llama/LlamaV2.cc | 4 +-- src/turbomind/models/llama/LlamaWeight.cc | 4 +-- src/turbomind/models/llama/context.h | 3 +- src/turbomind/models/llama/unified_decoder.cc | 6 ++-- src/turbomind/models/llama/unified_decoder.h | 2 +- .../triton_backend/llama/LlamaTritonModel.cc | 28 +++++++++---------- 11 files changed, 33 insertions(+), 34 deletions(-) diff --git a/lmdeploy/turbomind/deploy/config.py b/lmdeploy/turbomind/deploy/config.py index bc5d8c6998..ef34822e1a 100644 --- a/lmdeploy/turbomind/deploy/config.py +++ b/lmdeploy/turbomind/deploy/config.py @@ -72,6 +72,7 @@ class ModelConfig: expert_weight_type: str = None session_len: int = None attn_tp_size: int = 1 + attn_cp_size: int = 1 mlp_tp_size: int = 1 model_format: str = 'hf' expert_num: List[int] = () diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index b336dbd5e8..45bbf83dc1 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -179,6 +179,7 @@ def get_tm_model(model_path, tm_cfg.model_config.model_name = model_name tm_cfg.model_config.attn_tp_size = engine_config.attn_tp_size + tm_cfg.model_config.attn_cp_size = engine_config.attn_cp_size tm_cfg.model_config.mlp_tp_size = engine_config.mlp_tp_size output_model = OUTPUT_MODELS.get(output_model_name)(input_model=input_model, diff --git a/lmdeploy/turbomind/deploy/module.py b/lmdeploy/turbomind/deploy/module.py index 96ed4777a8..27f53ca452 100644 --- a/lmdeploy/turbomind/deploy/module.py +++ b/lmdeploy/turbomind/deploy/module.py @@ -335,14 +335,13 @@ def pad_weight(tensor: torch.Tensor, tp: int): return tensor return torch.nn.functional.pad(tensor, (0, 0, 0, pad_size), 'constant', 0) + tp = self.model.attn_tp_size * self.model.attn_cp_size if emb is not None: - tp = self.model.attn_tp_size emb = pad_weight(emb, tp=tp) self.model.save_split(emb, 'tok_embeddings.weight', split_dim=1, split_num=tp) if norm_weight is not None: self.model.export_weight(norm_weight, 'norm.weight') if output_weight is not None: - tp = self.model.attn_tp_size output_weight = pad_weight(output_weight, tp=tp) # transpose self.model.save_split(output_weight.t(), 'output.weight', split_dim=1, split_num=tp) diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index 1b468f0d84..c09a73c583 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -51,6 +51,7 @@ def __init__(self, input_model: BaseInputModel, cfg: TurbomindModelConfig, model self.attention_config = cfg.attention_config self.lora_config = cfg.lora_config self.attn_tp_size = self.model_config.attn_tp_size + self.attn_cp_size = self.model_config.attn_cp_size self.mlp_tp_size = self.model_config.mlp_tp_size self.out_dir = out_dir self.to_file = True if out_dir else False diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index fe6e5d059b..b3a9eae370 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -994,12 
+994,12 @@ void LlamaBatch::ComputeAndOutputLogits(const Tensor& hidden_states, int first, if (symm_logits_buf_.shape(0) < token_num) { if (tp_size_ > 1) { check_cuda_error(cudaStreamSynchronize(stream_)); - comm_.h_tp_cp_group->Sync(); + comm_.h_tp_group->Sync(); } symm_logits_buf_ = {{token_num, vocab_size_padded}, data_type_, symm_alloc_}; if (tp_size_ > 1) { check_cuda_error(cudaStreamSynchronize(stream_)); - comm_.h_tp_cp_group->Sync(); + comm_.h_tp_group->Sync(); } } @@ -1236,7 +1236,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) } if (need_sync) { // Release updates on request output buffers to all ranks (`Interrupt` will use it) - comm_.h_tp_cp_group->Sync(); + comm_.h_tp_group->Sync(); } } @@ -1374,14 +1374,14 @@ void LlamaBatch::InternalThreadEntry() if (state_->size == g.finished_count) { // Batch is empty, use blocking sync to avoid spinning - comm_.h_tp_cp_group->Sync(true); + comm_.h_tp_group->Sync(true); } NvtxScope scope("mainloop"); // 1. Wait while rank-0 is dequeueing // 2. Broadcast `ec` from rank-0 - Broadcast(comm_.h_tp_cp_group, req, 0); + Broadcast(comm_.h_tp_group, req, 0); if (req->abort) { TM_LOG_INFO("[InternalThreadEntry] stop requested."); @@ -1422,7 +1422,7 @@ void LlamaBatch::InternalThreadEntry() // Finished requests and corresponding output tensors will be released when notified // wait for all ranks to ensure no rank (except for output thread) will access related // resources - comm_.h_tp_cp_group->Sync(); + comm_.h_tp_group->Sync(); } if (is_driver_) { @@ -1825,7 +1825,7 @@ void LlamaBatch::InitializeBufferAndKVCache() const auto get_free_size = [&] { // size_t free{}, total{}; check_cuda_error(cudaMemGetInfo(&free, &total)); - return AllReduce(model_->comm_->h_tp_cp_group, free, comm::RedOp::kMin); + return AllReduce(model_->comm_->h_tp_group, free, comm::RedOp::kMin); }; sequence_manager_.reset(new SequenceManager{model_->layer_num_, diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 549e981a6a..247313fa06 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -64,8 +64,8 @@ LlamaV2::LlamaV2(DataType dtype, attn_param_(attn), lora_param_(lora), comm_(&ctx.comm), - tp_size_(engine.attn_tp_size), - tp_rank_(engine.attn_tp_rank), + tp_size_(engine.attn_tp_size * engine.attn_cp_size), + tp_rank_(engine.attn_tp_rank * engine.attn_cp_size + engine.attn_cp_rank), head_num_(model.head_num), size_per_head_(model.head_dim), hidden_units_(model.hidden_units), diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc index 8514c92e2c..c61894a3c8 100644 --- a/src/turbomind/models/llama/LlamaWeight.cc +++ b/src/turbomind/models/llama/LlamaWeight.cc @@ -46,8 +46,8 @@ LlamaWeight::LlamaWeight(DataType data_type, num_layer_(model.layer_num), data_type_{data_type}, weight_type_{model.weight_type}, - tp_size_(engine_param.attn_tp_size), - tp_rank_(engine_param.attn_tp_rank) + tp_size_(engine_param.attn_tp_size * engine_param.attn_cp_size), + tp_rank_(engine_param.attn_tp_rank * engine_param.attn_cp_size + engine_param.attn_cp_rank) { if (vocab_size_padded_ % tp_size_ != 0) { vocab_size_padded_ = (vocab_size_ + tp_size_ - 1) / tp_size_ * tp_size_; diff --git a/src/turbomind/models/llama/context.h b/src/turbomind/models/llama/context.h index e1d7cee7a9..666803100d 100644 --- a/src/turbomind/models/llama/context.h +++ b/src/turbomind/models/llama/context.h @@ -17,11 +17,10 @@ namespace turbomind { struct Communicators { 
comm::HostComm h_comm; - comm::HostComm h_tp_cp_group; + comm::HostComm h_tp_group; comm::HostComm h_dp_group; comm::DeviceComm d_comm; - int d_tp_cp_group; int d_tp_group; int d_cp_group; }; diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index daf8c905eb..1e42a6025b 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -29,7 +29,7 @@ UnifiedDecoder::UnifiedDecoder(const ModelParam& model, attn_dp_size_(engine.attn_dp_size), attn_dp_rank_(engine.attn_dp_rank), mlp_tp_size_(engine.mlp_tp_size), - attn_tp_cp_group_(ctx.comm.d_tp_cp_group), + attn_tp_group_(ctx.comm.d_tp_group), rmsnorm_eps_(model.norm_eps), stream_(ctx.stream), d_comm_(ctx.comm.d_comm), @@ -175,7 +175,7 @@ void UnifiedDecoder::Forward(TensorMap& args, const std::vector& we weights.at(layer)->self_attn_weights->output.bias, weights.at(layer)->ffn_norm, local_token_num, - attn_tp_cp_group_, + attn_tp_group_, 0, local_token_nums.data()); @@ -217,7 +217,7 @@ void UnifiedDecoder::Forward(TensorMap& args, const std::vector& we scale_weight, local_token_num, 0, - attn_tp_cp_group_, + attn_tp_group_, local_token_nums.data()); sync_check_cuda_error(); diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h index 775ed01dbf..2d001c9bc3 100644 --- a/src/turbomind/models/llama/unified_decoder.h +++ b/src/turbomind/models/llama/unified_decoder.h @@ -33,7 +33,7 @@ class UnifiedDecoder { const int attn_dp_rank_; const int mlp_tp_size_; - const int attn_tp_cp_group_; + const int attn_tp_group_; const EngineParam engine_param_; diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 9c6977a5d1..18e245966c 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -506,28 +506,26 @@ Communicators LlamaTritonModel::createCommSplits(int rank) const int outer_rank = rank / comm_size_; const int inner_rank = rank % comm_size_; - const int tp_cp_size = engine_param_.attn_tp_size * engine_param_.attn_cp_size; - const int color_tp_cp = inner_rank / tp_cp_size; - const int color_tp = inner_rank % tp_cp_size % engine_param_.attn_cp_size; - const int color_cp = inner_rank % tp_cp_size / engine_param_.attn_cp_size; + const int tp_cp_size = engine_param_.attn_tp_size * engine_param_.attn_cp_size; + const int color_tp = inner_rank / tp_cp_size; + const int color_cp = inner_rank / engine_param_.attn_cp_size; + const int color_dp = inner_rank % tp_cp_size; comm.h_comm = group_ids_[outer_rank]->CreateCommunicator(comm_size_, inner_rank); - comm.h_tp_cp_group = comm.h_comm->Split(color_tp_cp, 0); - comm.h_dp_group = comm.h_comm->Split(inner_rank % tp_cp_size, 0); + comm.h_tp_group = comm.h_comm->Split(color_tp, 0); + comm.h_dp_group = comm.h_comm->Split(color_dp, 0); if (comm_size_ > 1) { comm.d_comm = CreateDeviceCommunicator(communicator_, comm_size_, inner_rank, comm.h_comm); // - comm.d_tp_cp_group = 0; - comm.d_tp_group = 0; - comm.d_cp_group = 0; - if (engine_param_.attn_tp_size != comm_size_) { - if (tp_cp_size != comm_size_) { - comm.d_tp_cp_group = comm.d_comm->Split(color_tp_cp, 0, 0); - } - comm.d_tp_group = comm.d_comm->Split(color_tp, 0, comm.d_tp_cp_group); - comm.d_cp_group = comm.d_comm->Split(color_cp, 0, comm.d_tp_cp_group); + comm.d_tp_group = 0; + comm.d_cp_group = 0; + if (engine_param_.attn_dp_size > 1) { // has attn_dp + 
comm.d_tp_group = comm.d_comm->Split(color_tp, 0, 0); + } + if (engine_param_.attn_cp_size > 1) { // has attn_cp + comm.d_cp_group = comm.d_comm->Split(color_cp, 0, 0); } } From c7e1e237ca32c90c48fcc33f1120aa81a76101bc Mon Sep 17 00:00:00 2001 From: irexyc Date: Tue, 4 Nov 2025 06:35:46 +0000 Subject: [PATCH 20/31] use same split_k on different cp_rank --- docs/zh_cn/advance/context_parallel.md | 1 + src/turbomind/kernels/attention/attention_universal.h | 10 +++++----- src/turbomind/models/llama/cp_utils.cu | 5 +++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/docs/zh_cn/advance/context_parallel.md b/docs/zh_cn/advance/context_parallel.md index d5e6c5137b..68bace2181 100644 --- a/docs/zh_cn/advance/context_parallel.md +++ b/docs/zh_cn/advance/context_parallel.md @@ -6,6 +6,7 @@ 2. 降低推理的最大 batch 数量,减少吞吐量。 为了解决这个问题,TurboMind 推理后端支持设置 `attn_dp_size`,避免了创建 kv-heads 的副本,但是这会引入数据的不均衡性。为了消除数据的不均衡,TurboMind 支持了序列并行,支持将 kv_cache 交错存储到不同的 cp_rank 上,例如 + ``` cp_rank=2, prompt_len=5, generation_len=4 kv_cache stored on cp_rank0: 0, 2, 4, 6, 8 diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index f82f198ceb..28b78307b6 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -377,18 +377,18 @@ struct AttentionUniversal { const int context_len = params.cu_k_len[batch_idx + 1] - params.cu_k_len[batch_idx]; const int history_len = context_len - input_len; - auto get_cp_len = [&](int length) -> int { + auto get_cp_len = [&](int length, int rank) -> int { int cp_quo, cp_rem; params.cp_divmod(cp_quo, cp_rem, length); - return (cp_quo + (cp_rem > params.cp_rank ? 1 : 0)); + return (cp_quo + (cp_rem > rank ? 1 : 0)); }; const int last_K = history_len + min(query_idx + CTA_Q, input_len); const int last_K_tile = - (get_cp_len(last_K) - 1) / CTA_S + 1; // past-the-end index to past-the-end tile index conversion + (get_cp_len(last_K, 0) - 1) / CTA_S + 1; // past-the-end index to past-the-end tile index conversion const int first_K = max(history_len + query_idx - (params.window_size - 1), 0); - const int first_K_tile = get_cp_len(first_K) / CTA_S; + const int first_K_tile = get_cp_len(first_K, 0) / CTA_S; const int tile_count = last_K_tile - first_K_tile; @@ -430,7 +430,7 @@ struct AttentionUniversal { const int offset_K = (first_K_tile + iter_end - 1) * CTA_S; // This is for avoiding OOB access only - const int max_K = min(get_cp_len(context_len), (first_K_tile + iter_end) * CTA_S); + const int max_K = min(get_cp_len(context_len, params.cp_rank), (first_K_tile + iter_end) * CTA_S); int tile_iter = iter_end - iter_begin; diff --git a/src/turbomind/models/llama/cp_utils.cu b/src/turbomind/models/llama/cp_utils.cu index ae47e74e91..cae1deca44 100644 --- a/src/turbomind/models/llama/cp_utils.cu +++ b/src/turbomind/models/llama/cp_utils.cu @@ -61,8 +61,9 @@ __global__ void ReduceK(float2* cp_ML, block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask)); } - float expdiff_M = exp2f((frag_M - block_M) * exp_scale); - float block_L = expdiff_M * frag_L; + float expdiff_M = + (frag_M == -std::numeric_limits::infinity()) ? 
0.0f : exp2f((frag_M - block_M) * exp_scale); + float block_L = expdiff_M * frag_L; PRAGMA_UNROLL for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { From 8c5b28986967d93c7847861fe531c3de3a7ff490 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 5 Nov 2025 04:10:41 +0000 Subject: [PATCH 21/31] always use seperate reduce for cp --- .../kernels/attention/attention_template.h | 3 +- .../kernels/attention/attention_universal.h | 25 +- .../kernels/attention/decoding_template.h | 3 +- src/turbomind/kernels/attention/reduce.cu | 3 - .../kernels/attention/reduce_kernel.h | 22 +- src/turbomind/models/llama/LlamaBatch.cc | 4 +- src/turbomind/models/llama/cp_utils.cu | 352 ++++++------------ .../models/llama/unified_attention_layer.cc | 17 +- .../models/llama/unified_attention_layer.h | 3 +- 9 files changed, 127 insertions(+), 305 deletions(-) diff --git a/src/turbomind/kernels/attention/attention_template.h b/src/turbomind/kernels/attention/attention_template.h index e75568c936..ce97585b87 100644 --- a/src/turbomind/kernels/attention/attention_template.h +++ b/src/turbomind/kernels/attention/attention_template.h @@ -82,8 +82,7 @@ void invokeAttention(const typename Kernel::ParamType& params) } if (params.cp_fn) { - int split_k = Kernel::need_separate_reduce(split_cnt) ? split_cnt : 1; - params.cp_fn(params.cp_fn_ctx, split_k); + params.cp_fn(params.cp_fn_ctx, split_cnt); } else if (split_cnt > 1 && Kernel::need_separate_reduce(split_cnt)) { attention::invokeReduce(params.out, diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index 28b78307b6..d8757f6ab1 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -505,7 +505,7 @@ struct AttentionUniversal { const bool separate_reduce = need_separate_reduce(cta_map.split_count()); - if (separate_reduce && iter_end == tile_count && head_idx == 0) { + if ((separate_reduce || (params.cp_size > 1 && split_cnt > 1)) && iter_end == tile_count && head_idx == 0) { // Store actual split count, only used by separate reduction kernel for (int ti = threadIdx.x; ti < CTA_Q; ti += kWarpCount * WARP_SIZE) { if (qi_begin + ti < qi_end) { @@ -519,7 +519,7 @@ struct AttentionUniversal { } else { StorePartial(frag_O, frag_M, frag_L, split_cnt, qi_begin, qi_end, head_idx, split_idx, params, storage); - if (!separate_reduce && split_cnt > 1) { + if (!separate_reduce && params.cp_size == 1) { Reduce(qi_begin, head_idx, split_idx, iter_end == tile_count, params, cta_map, smem_buf); } } @@ -550,9 +550,6 @@ struct AttentionUniversal { params.partial_M, params.partial_L, params.partial_O, - params.cp_ML, - params.cp_k_ML, - params.cp_q_offset, qi_begin, head_idx, params.num_heads, @@ -628,7 +625,7 @@ struct AttentionUniversal { if (split_cnt > 1) { // decode Store(¶ms.partial_O[get_index(hi, qi) * kHeadDim + di], vec); } - if (params.cp_size > 1 && split_cnt == 1) { + if (params.cp_size > 1 && split_cnt == 1) { // prefill const int index = ((qi_begin + qi) * params.num_heads + (head_idx + hi)) * kHeadDim + di; Store(¶ms.out[index], cast(vec)); } @@ -643,18 +640,12 @@ struct AttentionUniversal { params.partial_L[index] = L; } - auto save_cp_stats = [&](int max_split_k, int split_idx, float* ml, float M, float L) { - const int q = qi_begin + qi - params.cp_q_offset; - const int index = (q * params.num_heads + (head_idx + hi)) * max_split_k + split_idx; - ml[index * 2] = M; - ml[index * 2 + 1] = L; - }; - if (params.cp_size > 1) { - if 
(split_cnt == 1) { - save_cp_stats(1, 0, params.cp_ML, M, L); - } - save_cp_stats(params.max_split_k, split_idx, params.cp_k_ML, M, L); + const int q = qi_begin + qi - params.cp_q_offset; + const int index = (q * params.num_heads + (head_idx + hi)) * params.max_split_k + split_idx; + + params.cp_ML[index * 2] = M; + params.cp_ML[index * 2 + 1] = L; } } }); diff --git a/src/turbomind/kernels/attention/decoding_template.h b/src/turbomind/kernels/attention/decoding_template.h index 25e72605d4..0706d82dde 100644 --- a/src/turbomind/kernels/attention/decoding_template.h +++ b/src/turbomind/kernels/attention/decoding_template.h @@ -81,8 +81,7 @@ bool invokeDecoding(const typename Kernel::ParamType& params) } if (params.cp_fn) { - int split_k = Kernel::need_separate_reduce(split_cnt) ? split_cnt : 1; - params.cp_fn(params.cp_fn_ctx, split_k); + params.cp_fn(params.cp_fn_ctx, split_cnt); } else if (Kernel::need_separate_reduce(split_cnt)) { attention::invokeReduce(params.out, diff --git a/src/turbomind/kernels/attention/reduce.cu b/src/turbomind/kernels/attention/reduce.cu index 23f7547372..c654f40d05 100644 --- a/src/turbomind/kernels/attention/reduce.cu +++ b/src/turbomind/kernels/attention/reduce.cu @@ -35,9 +35,6 @@ void invokeReduce(T* out, partial_L, partial_O, nullptr, - nullptr, - 0, - nullptr, split_cnt, partial_len, head_num, diff --git a/src/turbomind/kernels/attention/reduce_kernel.h b/src/turbomind/kernels/attention/reduce_kernel.h index 2986119e62..b4c9064cfe 100644 --- a/src/turbomind/kernels/attention/reduce_kernel.h +++ b/src/turbomind/kernels/attention/reduce_kernel.h @@ -27,9 +27,6 @@ struct Reduce { float* partial_M, float* partial_L, float* partial_O, - float* cp_ML, - float* cp_k_ML, - int cp_q_offset, int query_idx, int head_idx, int head_num, @@ -105,7 +102,7 @@ struct Reduce { Array scale; PRAGMA_UNROLL for (int k = 0; k < K; ++k) { - scale[k] = (IsFinal && cp_ML == nullptr) ? expdiff_M[k] / block_L : expdiff_M[k]; + scale[k] = IsFinal ? 
expdiff_M[k] / block_L : expdiff_M[k]; } if (hi < CTA_H) { @@ -127,17 +124,6 @@ struct Reduce { } } } - else { - if (cp_ML != nullptr && lane_id % L == 0 && hi < hi_end) { - const int idx1 = ((query_idx - cp_q_offset) * head_num + head_idx + hi) * 2; - cp_ML[idx1] = block_M; - cp_ML[idx1 + 1] = block_L; - - const int idx2 = idx1 * max_split_cnt; - cp_k_ML[idx2] = block_M; - cp_k_ML[idx2 + 1] = block_L; - } - } } __syncthreads(); @@ -226,9 +212,6 @@ __global__ void reduce_kernel(typename Reduce::T* out, float* partial_M, float* partial_L, float* partial_O, - float* cp_ML, - float* cp_k_ML, - int cp_q_offset, int* signals, const int* split_cnt_, int max_split_cnt, @@ -255,9 +238,6 @@ __global__ void reduce_kernel(typename Reduce::T* out, partial_M, partial_L, partial_O, - cp_ML, - cp_k_ML, - cp_q_offset, query_idx, head_idx, head_num, diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index b3a9eae370..5cc55af4db 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -829,7 +829,9 @@ void LlamaBatch::AllocSymmBuffers() symm_logits_buf_ = {{max_batch_size_, vocab_size_padded}, data_type_, symm_alloc_}; if (param_.attn_cp_size > 1) { - symm_cp_ML_ = {{param_.attn_cp_size, max_forward_token_num_, (int)model_->local_head_num_, 2}, symm_alloc_}; + // prefill(cp, q, h, 1, 2), decode(cp, q, h, k, 2) + const int cp_workspace_tokens = UnifiedAttentionLayer::kMaxWorkspaceTokens + max_forward_token_num_; + symm_cp_ML_ = {{param_.attn_cp_size, cp_workspace_tokens, (int)model_->local_head_num_, 2}, symm_alloc_}; } } diff --git a/src/turbomind/models/llama/cp_utils.cu b/src/turbomind/models/llama/cp_utils.cu index cae1deca44..8d4a88c46a 100644 --- a/src/turbomind/models/llama/cp_utils.cu +++ b/src/turbomind/models/llama/cp_utils.cu @@ -19,232 +19,10 @@ int next_power_of_two(int v) return v; } -template -__global__ void ReduceK(float2* cp_ML, - float* partial_M, // q, h, k - float* partial_L, // q, h, k - int* split_cnt_, - int max_split_k, - int num_tokens, - int num_heads, - int stride_k, - int offset_k, - float exp_scale) -{ - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - - offset_k *= blockIdx.z; - - const int q = blockIdx.x * WarpCnt + warp_id; - const int h = blockIdx.y; - const int split_cnt = (q >= num_tokens) ? 0 : split_cnt_[q]; - if (offset_k >= split_cnt) { - return; - } - - float frag_M = -std::numeric_limits::infinity(); - float frag_L = 0.0f; - - const int ki = lane_id * stride_k + offset_k; - const bool mask = ki < split_cnt && h < num_heads; - const int index = (q * num_heads + h) * max_split_k + ki; - - if (mask) { - frag_M = partial_M[index]; - frag_L = partial_L[index]; - } - - float block_M = frag_M; - PRAGMA_UNROLL - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask)); - } - - float expdiff_M = - (frag_M == -std::numeric_limits::infinity()) ? 
0.0f : exp2f((frag_M - block_M) * exp_scale); - float block_L = expdiff_M * frag_L; - - PRAGMA_UNROLL - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - block_L += __shfl_xor_sync(uint32_t(-1), block_L, mask); - } - - if (mask) { - partial_M[index] = block_M; - partial_L[index] = block_L; - - if (ki == 0 && gridDim.z == 1) { - cp_ML[q * num_heads + h] = {block_M, block_L}; - } - } -} - -template -void invokeReduceK(CpPostContext* ctx, AttentionParams* params, int split_cnt) -{ - constexpr int MaxN = 32; - - int split_k = split_cnt; - int stride_k = 1; - int offset_k = 1; - - auto invoke = [&](auto n) { - constexpr int WarpCnt = 4; - const dim3 block(WarpCnt * WARP_SIZE); - const dim3 grid((params->token_num + WarpCnt - 1) / WarpCnt, params->num_heads, (split_k + n - 1) / n); - ReduceK<<stream>>>( // - (float2*)ctx->cp_ML + params->cp_rank * params->token_num * params->num_heads, - params->partial_M, - params->partial_L, - params->split_cnt, - params->max_split_k, - params->token_num, - params->num_heads, - stride_k, - offset_k * n, - params->inv_sqrt_dh); - sync_check_cuda_error(); - - stride_k *= n; - offset_k *= n; - split_k = (split_k + n - 1) / n; - }; - - auto dispatch_n = [&](int n) { - n = min(next_power_of_two(n), MaxN); - switch (n) { - case 2: - return invoke(std::integral_constant{}); - case 4: - return invoke(std::integral_constant{}); - case 8: - return invoke(std::integral_constant{}); - case 16: - return invoke(std::integral_constant{}); - case 32: - return invoke(std::integral_constant{}); - default: - TM_CHECK(0); - } - }; - - while (split_k > 1) { - dispatch_n(split_k); - } -} - -template -__global__ void ReduceCP(float2* cp_ML, // cp, q, h, 2 - int cp_size, - int num_heads, - int total, - int stride, - int offset, - float exp_scale) -{ - __shared__ float2 s_ML[WarpCnt][WARP_SIZE + 1]; - - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - - offset *= blockIdx.y; - const int qh_offset = blockIdx.x * WARP_SIZE; - if (qh_offset >= total || offset >= cp_size) { - return; - } - - float2 ml = {-std::numeric_limits::infinity(), 0.f}; - - int qh = qh_offset + lane_id; - int ki = warp_id * stride + offset; - if (ki < cp_size && qh < total) { - ml = cp_ML[ki * total + qh]; - } - s_ML[warp_id][lane_id] = ml; - - __syncthreads(); - - // Reduce - const int qh_i = lane_id / (WarpCnt * 2) * (WarpCnt * 2) + lane_id % (WarpCnt * 2) / WarpCnt + warp_id * 2; - const int wi = lane_id % WarpCnt; - - ml = s_ML[wi][qh_i]; - float frag_M = ml.x; - float frag_L = ml.y; - - float block_M = frag_M; - PRAGMA_UNROLL - for (int mask = WarpCnt / 2; mask >= 1; mask /= 2) { - block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask)); - } - - float expdiff_M = exp2f((frag_M - block_M) * exp_scale); - - float block_L = frag_L * expdiff_M; - PRAGMA_UNROLL - for (int mask = WarpCnt / 2; mask >= 1; mask /= 2) { - block_L += __shfl_xor_sync(uint32_t(-1), block_L, mask); - } - - if (wi == 0 && (qh_offset + qh_i < total)) { - cp_ML[qh_offset + qh_i] = {block_M, block_L}; - } -} - -template -void invokeReduceCP(CpPostContext* ctx, AttentionParams* params) -{ - constexpr int MaxN = 8; - const int total = params->token_num * params->num_heads; - - int split_k = params->cp_size; - int stride_k = 1; - int offset_k = 1; - - auto invoke = [&](auto n) { - const dim3 block(n * WARP_SIZE); - const dim3 grid((total + WARP_SIZE - 1) / WARP_SIZE, (split_k + n - 1) / n); - const int shm_size = sizeof(float2) * n * (WARP_SIZE + 1); - ReduceCP<<stream>>>( // - 
(float2*)ctx->cp_ML, - params->cp_size, - params->num_heads, - total, - stride_k, - offset_k * n, - params->inv_sqrt_dh); - sync_check_cuda_error(); - - stride_k *= n; - offset_k *= n; - split_k = (split_k + n - 1) / n; - }; - - auto dispatch_n = [&](int n) { - n = min(next_power_of_two(n), MaxN); - switch (n) { - case 2: - return invoke(std::integral_constant{}); - case 4: - return invoke(std::integral_constant{}); - case 8: - return invoke(std::integral_constant{}); - default: - TM_CHECK(0); - } - }; - - while (split_k > 1) { - dispatch_n(split_k); - } -} - template __global__ void ReduceOutput(T* out, // float* partial_O, - float* cp_k_ML, // q, h, k, 2 - float2* cp_ML, // q, h, 2 + float* cp_ML, // q, h, k, 2 cutlass::FastDivmod h_divmod, int* split_cnt_, int max_split_cnt, @@ -269,23 +47,16 @@ __global__ void ReduceOutput(T* out, // } offset_k *= blockIdx.y; - const int split_cnt = (split_cnt_ != nullptr) ? split_cnt_[q] : 1; + const int split_cnt = (split_cnt_ != nullptr) ? max(split_cnt_[q], 1) : 1; if (offset_k >= split_cnt) { return; } - float scale = 1.0f; - float2 global_ML; - - auto get_scale = [&](float2 ml, int ki) { - int index = (q * num_heads + h) * max_split_cnt * 2 + ki * 2; - return exp2f((cp_k_ML[index] - ml.x) * exp_scale) / ml.y; + auto get_scale = [&](int q, int h, int ki) { // q, h, k, 2 + int index = ((q * num_heads + h) * max_split_cnt + ki) * 2; + return cp_ML[index]; }; - if (stride_k == 1) { - global_ML = cp_ML[q * num_heads + h]; - } - // HeadDim / WARP_SIZE // 128 -> 4 // 64, 192 -> 2 @@ -298,8 +69,8 @@ __global__ void ReduceOutput(T* out, // // in most cases,no split_k if constexpr (N == 1) { - VecT frag_O; - scale = get_scale(global_ML, 0); + VecT frag_O; + float scale = get_scale(q, h, 0); PRAGMA_UNROLL for (int c = 0; c < iterC; ++c) { @@ -320,9 +91,7 @@ __global__ void ReduceOutput(T* out, // const int base = (((q * num_heads + h) * max_split_cnt + ki) * HeadDim); // q, h, k, d if (ki < split_cnt) { - if (stride_k == 1) { - scale = get_scale(global_ML, ki); - } + float scale = (stride_k == 1) ? get_scale(q, h, ki) : 1.0f; PRAGMA_UNROLL for (int c = 0; c < iterC; ++c) { @@ -400,8 +169,7 @@ void invokeReduceOutput(CpPostContext* ctx, AttentionParams* params, int spli ReduceOutput<<stream>>>( // params->out + params->cp_q_offset * params->num_heads * params->size_per_head, params->partial_O, - params->cp_k_ML, - (float2*)ctx->cp_ML, + ctx->cp_ML + params->cp_rank * params->token_num * params->num_heads * params->max_split_k * 2, h_divmod, split_cnt > 1 ? params->split_cnt : nullptr, params->max_split_k, @@ -458,16 +226,105 @@ void invokeReduceOutput(CpPostContext* ctx, AttentionParams* params, int spli } } +template +__global__ void ReduceScale(float* cp_ML, // cp, q, h, k, 2 + int num_tokens, + cutlass::FastDivmod num_heads, + int* split_cnt_, + int max_split_cnt, + int cp_size, + int cp_rank, + float exp_scale) +{ + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + + int qh = blockIdx.x * WarpCnt + warp_id; + int q = num_heads.div(qh); + + if (q >= num_tokens) { + return; + } + + float frag_M0 = -std::numeric_limits::infinity(); + float frag_L0 = 0.0f; + + const int split_per_rank = (split_cnt_ == nullptr) ? 
1 : max(split_cnt_[q], 1); + const int split_all_rank = split_per_rank * cp_size; + + int split_i, split_k; + for (int i = lane_id; i < split_all_rank; i += WARP_SIZE) { + split_i = i / split_per_rank; + split_k = i % split_per_rank; + int index = (split_i * num_tokens * num_heads + qh) * max_split_cnt + split_k; + + float frag_M1 = cp_ML[index * 2]; + float frag_L1 = cp_ML[index * 2 + 1]; + float frag_M = fmaxf(frag_M0, frag_M1); + + frag_L1 = (frag_M1 == -std::numeric_limits::infinity()) ? + 0.0f : + exp2f((frag_M1 - frag_M) * exp_scale) * frag_L1; + frag_L0 = (frag_M0 == -std::numeric_limits::infinity()) ? + 0.0f : + exp2f((frag_M0 - frag_M) * exp_scale) * frag_L0; + + frag_L0 = frag_L1 + frag_L0; + frag_M0 = frag_M; + } + + float block_M = frag_M0; + PRAGMA_UNROLL + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask)); + } + + float block_L = + (frag_M0 == -std::numeric_limits::infinity()) ? 0.0f : exp2f((frag_M0 - block_M) * exp_scale) * frag_L0; + + PRAGMA_UNROLL + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + block_L += __shfl_xor_sync(uint32_t(-1), block_L, mask); + } + + for (int i = lane_id; i < split_per_rank; i += WARP_SIZE) { + split_k = i % split_per_rank; + int index = (cp_rank * num_tokens * num_heads + qh) * max_split_cnt + split_k; + + float frag_M1 = cp_ML[index * 2]; + float scale = (frag_M1 == -std::numeric_limits::infinity()) ? + 0.0f : + exp2f((frag_M1 - block_M) * exp_scale) / block_L; + cp_ML[index * 2] = scale; // save to M + } +} + +template +void invokeReduceScale(CpPostContext* ctx, AttentionParams* params, int split_cnt) +{ + constexpr int WarpCnt = 4; // each warp process one token + const dim3 block(WarpCnt * WARP_SIZE); + const dim3 grid((params->token_num * params->num_heads + WarpCnt - 1) / WarpCnt); + + ReduceScale<<stream>>>( // + ctx->cp_ML, + params->token_num, + cutlass::FastDivmod(params->num_heads), + split_cnt > 1 ? 
params->split_cnt : nullptr, + params->max_split_k, + params->cp_size, + params->cp_rank, + params->inv_sqrt_dh); + + sync_check_cuda_error(); +} + template void CpReduce(CpPostContext* ctx, AttentionParams* params, int split_cnt) { NvtxScope scope("CpReduce"); - if (split_cnt > 1) { - invokeReduceK(ctx, params, split_cnt); - } - - const int count = params->token_num * params->num_heads * 2; + const int count = params->token_num * params->num_heads * params->max_split_k * 2; ctx->d_comm->AllGather(ctx->cp_ML + params->cp_rank * count, // ctx->cp_ML, count, @@ -476,7 +333,8 @@ void CpReduce(CpPostContext* ctx, AttentionParams* params, int split_cnt) params->stream); sync_check_cuda_error(); - invokeReduceCP(ctx, params); + invokeReduceScale(ctx, params, split_cnt); + invokeReduceOutput(ctx, params, split_cnt); } diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index a1afbc9ebe..e58584ab35 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -100,11 +100,6 @@ UnifiedAttentionLayer::UnifiedAttentionLayer(const ModelParam& model, split_cnt_ = Tensor_({kMaxWorkspaceTokens}, kDEVICE); barriers_ = Tensor_({kMaxWorkspaceTokens, local_head_num_}, kDEVICE); - if (engine_param_.attn_cp_size > 1) { - const int cp_workspace_tokens = kMaxWorkspaceTokens + engine_param_.max_forward_token_num; - cp_k_ML_ = Tensor_({cp_workspace_tokens, local_head_num_, 2}, kDEVICE); - } - Clear(split_cnt_.buffer()); Clear(barriers_.buffer()); @@ -346,16 +341,18 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, if (params.cp_size > 1) { params.cp_divmod = cutlass::FastDivmod(params.cp_size); - const int offset_ML = engine_param_.attn_cp_size * offset * local_head_num_ * 2; - params.cp_ML = cp_ML_.data() + offset_ML + params.cp_rank * params.token_num * local_head_num_ * 2; - params.cp_k_ML = cp_k_ML_.data() + (offset ? kMaxWorkspaceTokens * local_head_num_ * 2 : 0); - params.cp_q_offset = offset; + const int offset_stage = + engine_param_.attn_cp_size * (offset ? 
kMaxWorkspaceTokens * local_head_num_ * 2 : 0); + const int offset_rank = params.cp_rank * params.token_num * local_head_num_ * params.max_split_k * 2; + + params.cp_ML = cp_ML_.data() + offset_stage + offset_rank; // (cp, q, h, k, 2) + params.cp_q_offset = offset; // postprocess func params.cp_fn = CpPost; params.cp_fn_ctx = (void*)&cp_fn_ctx_; - cp_fn_ctx_.cp_ML = cp_ML_.data() + offset_ML; + cp_fn_ctx_.cp_ML = cp_ML_.data() + offset_stage; cp_fn_ctx_.attn_param = (void*)¶ms; cp_fn_ctx_.attn_type = attn.dtype(); } diff --git a/src/turbomind/models/llama/unified_attention_layer.h b/src/turbomind/models/llama/unified_attention_layer.h index 5f805cbf13..35d76655d2 100644 --- a/src/turbomind/models/llama/unified_attention_layer.h +++ b/src/turbomind/models/llama/unified_attention_layer.h @@ -122,8 +122,7 @@ class UnifiedAttentionLayer { Tensor_ barriers_; // always zero // context parallel - Tensor_ cp_ML_; // cp, (d+p), h, 2 - Tensor_ cp_k_ML_; // (d+p), h, k, 2 + Tensor_ cp_ML_; // cp, (d+p), h, k, 2 CpPostContext cp_fn_ctx_; Event event_; From 4005547d6d38bf002ce2f81bd5e9707576267b37 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 5 Nov 2025 06:46:43 +0000 Subject: [PATCH 22/31] add cp configuration parameter --- benchmark/profile_throughput.py | 4 ++-- lmdeploy/cli/cli.py | 2 +- lmdeploy/cli/serve.py | 10 ++++++---- lmdeploy/cli/utils.py | 6 +++--- lmdeploy/messages.py | 1 + lmdeploy/turbomind/turbomind.py | 5 ++--- 6 files changed, 15 insertions(+), 13 deletions(-) diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index b172dac79e..50f69921e9 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -327,7 +327,7 @@ def parse_args(): tb_group._group_actions.append(dtype_act) ArgumentHelper.dp(tb_group) - ArgumentHelper.attn_cp_size(tb_group) + ArgumentHelper.cp(tb_group) ArgumentHelper.model_format(tb_group, default='hf') ArgumentHelper.num_tokens_per_iter(tb_group) ArgumentHelper.max_prefill_iters(tb_group) @@ -345,7 +345,7 @@ def main(): max_batch_size=args.concurrency // args.dp, tp=args.tp, dp=args.dp, - attn_cp_size=args.attn_cp_size, + cp=args.cp, cache_max_entry_count=args.cache_max_entry_count, cache_block_seq_len=args.cache_block_seq_len, model_format=args.model_format, diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py index d6982982d6..9c0a19138e 100644 --- a/lmdeploy/cli/cli.py +++ b/lmdeploy/cli/cli.py @@ -76,7 +76,7 @@ def add_parser_chat(): ArgumentHelper.model_format(tb_group) ArgumentHelper.rope_scaling_factor(tb_group) ArgumentHelper.communicator(tb_group) - ArgumentHelper.attn_cp_size(tb_group) + ArgumentHelper.cp(tb_group) @staticmethod def add_parser_checkenv(): diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index a3e58713dc..cb6a2a32fa 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -110,7 +110,7 @@ def add_parser_api_server(): model_format = ArgumentHelper.model_format(pt_group) hf_overrides = ArgumentHelper.hf_overrides(pt_group) enable_metrics = ArgumentHelper.enable_metrics(pt_group) - ArgumentHelper.dp(pt_group) + dp = ArgumentHelper.dp(pt_group) ArgumentHelper.ep(pt_group) ArgumentHelper.enable_microbatch(pt_group) ArgumentHelper.enable_eplb(pt_group) @@ -135,7 +135,8 @@ def add_parser_api_server(): tb_group._group_actions.append(model_format) tb_group._group_actions.append(hf_overrides) tb_group._group_actions.append(enable_metrics) - ArgumentHelper.attn_cp_size(tb_group) + tb_group._group_actions.append(dp) + ArgumentHelper.cp(tb_group) 
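    # The `--cp` option registered above is the attention context-parallel size for the
    # TurboMind backend; per ArgumentHelper.cp in lmdeploy/cli/utils.py it must divide tp.
    # update_parallel_config() in lmdeploy/turbomind/turbomind.py derives
    # attn_tp_size = inner_tp_size // cp and attn_cp_size = cp, so a launch such as
    # `--tp 8 --cp 2` shards attention 4-way with a 2-way cp group per shard, letting the
    # cp ranks hold disjoint token slices of the kv_cache instead of duplicating it.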
ArgumentHelper.rope_scaling_factor(tb_group) ArgumentHelper.num_tokens_per_iter(tb_group) ArgumentHelper.max_prefill_iters(tb_group) @@ -233,7 +234,8 @@ def api_server(args): from lmdeploy.messages import TurbomindEngineConfig backend_config = TurbomindEngineConfig(dtype=args.dtype, tp=args.tp, - attn_cp_size=args.attn_cp_size, + dp=args.dp, + cp=args.cp, max_batch_size=max_batch_size, session_len=args.session_len, model_format=args.model_format, @@ -252,7 +254,7 @@ def api_server(args): from lmdeploy.messages import VisionConfig vision_config = VisionConfig(args.vision_max_batch_size) - if args.dp == 1: + if args.dp == 1 or backend == 'turbomind': from lmdeploy.serve.openai.api_server import serve as run_api_server run_api_server(args.model_path, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 929938a95b..9fced32128 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -189,11 +189,11 @@ def ep(parser): help='expert parallelism. dp is required when pytorch engine is used.') @staticmethod - def attn_cp_size(parser): - """Add argument attn_cp_size to parser.""" + def cp(parser): + """Add argument cp to parser.""" return parser.add_argument( - '--attn-cp-size', + '--cp', type=int, default=1, help='context parallelism size in attention for turbomind backend. Should divide tp.') diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 51ce42cc3b..0ecef06e60 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -235,6 +235,7 @@ class TurbomindEngineConfig: model_format: Optional[str] = None tp: int = 1 dp: int = 1 + cp: int = 1 device_num: int = None attn_tp_size: int = None attn_cp_size: int = None diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index ca9c36f0d5..877b1489b4 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -87,7 +87,6 @@ def complete_parallel_config(cfg: TurbomindEngineConfig): def update_parallel_config(cfg: TurbomindEngineConfig): if not complete_parallel_config(cfg): - attn_cp_size = cfg.attn_cp_size or 1 total = cfg.dp * cfg.tp if not cfg.device_num: count = torch.cuda.device_count() @@ -101,8 +100,8 @@ def update_parallel_config(cfg: TurbomindEngineConfig): inner_tp_size = cfg.tp // mlp_tp_size cfg.outer_dp_size = cfg.dp // attn_dp_size cfg.attn_dp_size = attn_dp_size - cfg.attn_tp_size = inner_tp_size // attn_cp_size - cfg.attn_cp_size = attn_cp_size + cfg.attn_tp_size = inner_tp_size // cfg.cp + cfg.attn_cp_size = cfg.cp cfg.mlp_dp_size = 1 cfg.mlp_tp_size = mlp_tp_size * inner_tp_size assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size == cfg.mlp_dp_size * cfg.mlp_tp_size From 1d2b0983c6db229806dd9db2e0fb65f29660a116 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 5 Nov 2025 07:03:30 +0000 Subject: [PATCH 23/31] remove redundant parameters --- .../kernels/attention/attention_params.h | 8 +++---- .../kernels/attention/attention_universal.h | 4 ++-- .../kernels/attention/kv_cache_utils_v2.cu | 22 +++++++++---------- .../kernels/attention/kv_cache_utils_v2.h | 8 +++---- .../models/llama/unified_attention_layer.cc | 2 +- 5 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h index 151a3ed9eb..ed98c673d5 100644 --- a/src/turbomind/kernels/attention/attention_params.h +++ b/src/turbomind/kernels/attention/attention_params.h @@ -84,11 +84,9 @@ struct AttentionParams { // context parallel int cp_rank{0}; - int cp_size{1}; - cutlass::FastDivmod 
cp_divmod{1}; - int cp_q_offset{0}; // decode offset - float* cp_ML{nullptr}; // cp, q, h, 2 - float* cp_k_ML{nullptr}; // q, h, k, 2 + cutlass::FastDivmod cp_size{1}; + int cp_q_offset{0}; // decode offset + float* cp_ML{nullptr}; // cp, q, h, k, 2 cp_post_fn cp_fn{nullptr}; void* cp_fn_ctx{nullptr}; diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index d8757f6ab1..9ed5168b37 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -257,7 +257,7 @@ struct AttentionUniversal { const int ti = history_len; int cp_quo, cp_rem; - params.cp_divmod(cp_quo, cp_rem, ti); + cp_quo = params.cp_size.divmod(cp_rem, ti); Array param_K[1]; Array param_V[1]; @@ -379,7 +379,7 @@ struct AttentionUniversal { auto get_cp_len = [&](int length, int rank) -> int { int cp_quo, cp_rem; - params.cp_divmod(cp_quo, cp_rem, length); + cp_quo = params.cp_size.divmod(cp_rem, length); return (cp_quo + (cp_rem > rank ? 1 : 0)); }; diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu index 18457af18d..72395c3808 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu @@ -31,7 +31,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, int64_t stride_s, int layer_id, int cp_rank, - FastDivmod cp_divmod, + FastDivmod cp_size, BlockLayout block_layout) { @@ -166,7 +166,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, for (int s = 0; s < ITER_S; ++s) { const int qi = offset.y + s * Map::kDeltaS + token_idx; // local offset into `input_length` const int ti = history_len + qi; // timestep - cp_divmod(cp_quo, cp_rem, ti); + cp_quo = cp_size.divmod(cp_rem, ti); if (qi < q_len && cp_rem == cp_rank) { block_head.with((char**)blocks, cp_quo, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { PRAGMA_UNROLL @@ -206,7 +206,7 @@ void invokeProcessKV_v2(char** blocks, int block_seq_len, int layer_id, int cp_rank, - FastDivmod cp_divmod, + FastDivmod cp_size, int max_q_len, int head_num, int head_dim, @@ -243,7 +243,7 @@ void invokeProcessKV_v2(char** blocks, stride_s, layer_id, cp_rank, - cp_divmod, + cp_size, block_layout); }; @@ -288,7 +288,7 @@ void invokeProcessKV_v2(char** blocks, int block_seq_len, \ int layer_id, \ int cp_rank, \ - FastDivmod cp_divmod, \ + FastDivmod cp_size, \ int max_q_len, \ int head_num, \ int head_dim, \ @@ -314,7 +314,7 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, int64_t stride_s, int layer_id, int cp_rank, - FastDivmod cp_divmod, + FastDivmod cp_size, BlockLayout block_layout) { constexpr int kVecSize = sizeof(uint4) / sizeof(T); @@ -361,7 +361,7 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { const int si = offset.y + s * Map::kDeltaS + token_idx; - cp_divmod(cp_quo, cp_rem, si); + cp_quo = cp_size.divmod(cp_rem, si); if (si < seq_len && cp_rem == cp_rank) { block_head.with((char**)blocks, cp_quo, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) { PRAGMA_UNROLL @@ -409,7 +409,7 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, for (int c = 0; c < ITER_C; ++c) { const int si = offset.y + s * Map::kDeltaS + token_idx; const int di = offset.x + c * Map::kDeltaC; - cp_divmod(cp_quo, cp_rem, si); + cp_quo = cp_size.divmod(cp_rem, si); if (si < seq_len && cp_rem == cp_rank) { const int64_t 
index = (batch_idx * stride_b + ti_beg * stride_c + cp_quo * stride_s + head_idx * stride_h) * HeadDim + di; @@ -434,7 +434,7 @@ void invokeFlattenKV_v2(T* k, int block_seq_len, int layer_id, int cp_rank, - FastDivmod cp_divmod, + FastDivmod cp_size, int max_seq_len, int head_num, int head_dim, @@ -468,7 +468,7 @@ void invokeFlattenKV_v2(T* k, stride_s, layer_id, cp_rank, - cp_divmod, + cp_size, block_layout); }; @@ -510,7 +510,7 @@ void invokeFlattenKV_v2(T* k, int block_seq_len, \ int layer_id, \ int cp_rank, \ - FastDivmod cp_divmod, \ + FastDivmod cp_size, \ int max_seq_len, \ int head_num, \ int head_dim, \ diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.h b/src/turbomind/kernels/attention/kv_cache_utils_v2.h index c959a9c9bf..e06b329e55 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.h +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.h @@ -24,7 +24,7 @@ void invokeProcessKV_v2(char** blocks, int block_seq_len, int layer_id, int cp_rank, - cutlass::FastDivmod cp_divmod, + cutlass::FastDivmod cp_size, int max_q_len, int head_num, int head_dim, @@ -51,7 +51,7 @@ void invokeProcessKV_v2_(const AttentionParams& params) params.block_iter_params.block_len, params.block_iter_params.layer_id, params.cp_rank, - params.cp_divmod, + params.cp_size, params.max_q_len, params.num_kv_heads, params.size_per_head, @@ -74,7 +74,7 @@ void invokeFlattenKV_v2(T* k, int block_seq_len, int layer_id, int cp_rank, - cutlass::FastDivmod cp_divmod, + cutlass::FastDivmod cp_size, int max_seq_len, int head_num, int head_dim, @@ -100,7 +100,7 @@ void invokeFlattenKV_v2_(const AttentionParams& params, int sum_k_len) params.block_iter_params.block_len, params.block_iter_params.layer_id, params.cp_rank, - params.cp_divmod, + params.cp_size, params.max_k_len, params.num_kv_heads, params.size_per_head, diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index e58584ab35..47f42d5e7c 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -339,7 +339,7 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, params.cp_rank = engine_param_.attn_cp_rank; params.cp_size = engine_param_.attn_cp_size; if (params.cp_size > 1) { - params.cp_divmod = cutlass::FastDivmod(params.cp_size); + params.cp_size = cutlass::FastDivmod(params.cp_size); const int offset_stage = engine_param_.attn_cp_size * (offset ? 
kMaxWorkspaceTokens * local_head_num_ * 2 : 0); From 77920f8dfb682aaa885a412c140711425984a743 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 5 Nov 2025 07:15:07 +0000 Subject: [PATCH 24/31] remove redundant parameters --- src/turbomind/models/llama/cp_utils.cu | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/turbomind/models/llama/cp_utils.cu b/src/turbomind/models/llama/cp_utils.cu index 8d4a88c46a..530c211e12 100644 --- a/src/turbomind/models/llama/cp_utils.cu +++ b/src/turbomind/models/llama/cp_utils.cu @@ -23,11 +23,10 @@ template __global__ void ReduceOutput(T* out, // float* partial_O, float* cp_ML, // q, h, k, 2 - cutlass::FastDivmod h_divmod, + cutlass::FastDivmod num_heads, int* split_cnt_, int max_split_cnt, int total, - int num_heads, int stride_k, int offset_k, float exp_scale) @@ -40,7 +39,7 @@ __global__ void ReduceOutput(T* out, // // warp_id, q, h const int qh = blockIdx.x * M + warp_id % M; int q, h; - h_divmod(q, h, qh); + q = num_heads.divmod(h, qh); if (q * num_heads + h >= total) { return; @@ -156,7 +155,7 @@ void invokeReduceOutput(CpPostContext* ctx, AttentionParams* params, int spli int stride_k = 1; int offset_k = 1; - cutlass::FastDivmod h_divmod = cutlass::FastDivmod(params->num_heads); + cutlass::FastDivmod num_heads = cutlass::FastDivmod(params->num_heads); auto invoke = [&](auto n, auto head_dim) { constexpr int WarpCnt = 4; @@ -170,11 +169,10 @@ void invokeReduceOutput(CpPostContext* ctx, AttentionParams* params, int spli params->out + params->cp_q_offset * params->num_heads * params->size_per_head, params->partial_O, ctx->cp_ML + params->cp_rank * params->token_num * params->num_heads * params->max_split_k * 2, - h_divmod, + num_heads, split_cnt > 1 ? params->split_cnt : nullptr, params->max_split_k, total, - params->num_heads, stride_k, offset_k * n, params->inv_sqrt_dh); From f54ca433ef1d6f90765ef42c55f7077f017bac0e Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 5 Nov 2025 08:30:10 +0000 Subject: [PATCH 25/31] fix build --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 699d2862a8..eacf887ad3 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,7 @@ project(TurboMind LANGUAGES CXX CUDA) if (MSVC) # use standard conformant preprocessor add_compile_options($<$:/Zc:preprocessor>) + add_compile_options($<$:/Zc:__cplusplus>) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/Zc:preprocessor -Xcompiler=/Zc:__cplusplus") endif () From 1ac308018cc37c480f684a5017461353b8c02892 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 5 Nov 2025 10:34:58 +0000 Subject: [PATCH 26/31] fix xgrammar build --- CMakeLists.txt | 4 ++++ builder/windows/generate.ps1 | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index eacf887ad3..c33f0bf260 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -102,6 +102,10 @@ if(NOT xgrammar_POPULATED) # Bring the populated content into the build add_subdirectory(${xgrammar_SOURCE_DIR} ${xgrammar_BINARY_DIR}) + if(TARGET xgrammar) + target_compile_options(xgrammar PRIVATE $<$:/utf-8>) + target_compile_options(xgrammar PRIVATE $<$:/utf-8>) + endif() endif() # the environment variable diff --git a/builder/windows/generate.ps1 b/builder/windows/generate.ps1 index 0c133b37d0..e54f8fe742 100644 --- a/builder/windows/generate.ps1 +++ b/builder/windows/generate.ps1 @@ -1,4 +1,4 @@ -cmake .. -A x64 -T "v142,cuda=$env:CUDA_PATH" ` +cmake .. 
-A x64 -T "v143,cuda=$env:CUDA_PATH" ` -DCMAKE_BUILD_TYPE=Release ` -DCMAKE_INSTALL_PREFIX=install ` -DBUILD_PY_FFI=ON ` From 78722251dade3a183c9fbf1aa899020ab8762994 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 5 Nov 2025 13:24:10 +0000 Subject: [PATCH 27/31] update docs --- docs/en/advance/context_parallel.md | 4 ++-- docs/zh_cn/advance/context_parallel.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/en/advance/context_parallel.md b/docs/en/advance/context_parallel.md index f890de3c64..cf0c97f48b 100644 --- a/docs/en/advance/context_parallel.md +++ b/docs/en/advance/context_parallel.md @@ -18,7 +18,7 @@ kv_cache stored on cp_rank1: 1, 3, 5, 7 Taking Intern-S1 / Qwen3-235B-A22B as an example, their `num_key_value_heads` is 4. If you want to deploy with `TP=8` and avoid duplication of kv_cache, you can deploy in the following way: ``` -lmdeploy serve api_server internlm/Intern-S1 --tp 8 --attn-cp-size 2 +lmdeploy serve api_server internlm/Intern-S1 --tp 8 --cp 2 -lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --attn-cp-size 2 +lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --cp 2 ``` diff --git a/docs/zh_cn/advance/context_parallel.md b/docs/zh_cn/advance/context_parallel.md index 68bace2181..faea118505 100644 --- a/docs/zh_cn/advance/context_parallel.md +++ b/docs/zh_cn/advance/context_parallel.md @@ -18,7 +18,7 @@ kv_cache stored on cp_rank1: 1, 3, 5, 7 以 `Intern-S1` / `Qwen3-235B-A22B` 为例,他们的 `num_key_value_heads` 为 4,若要用 `TP=8` 的方式部署,并避免 kv_cache 的拷贝,可以用如下的方式部署 ``` -lmdeploy serve api_server internlm/Intern-S1 --tp 8 --attn-cp-size 2 +lmdeploy serve api_server internlm/Intern-S1 --tp 8 --cp 2 -lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --attn-cp-size 2 +lmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --cp 2 ``` From 0f82ef1c9d022c8ba8b83c5a62ea91177810f727 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 5 Nov 2025 13:31:29 +0000 Subject: [PATCH 28/31] remove unused --- src/turbomind/models/llama/unified_attention_layer.cc | 2 -- src/turbomind/models/llama/unified_attention_layer.h | 3 --- src/turbomind/models/llama/unified_decoder.cc | 1 - src/turbomind/models/llama/unified_decoder.h | 2 -- 4 files changed, 8 deletions(-) diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 47f42d5e7c..d68c2cce22 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -73,9 +73,7 @@ UnifiedAttentionLayer::UnifiedAttentionLayer(const ModelParam& model, param_(attn), model_param_(model), engine_param_(engine), - attn_cp_group_(ctx.comm.d_cp_group), cp_fn_ctx_(ctx.comm.d_comm, ctx.comm.d_cp_group), - d_comm_(ctx.comm.d_comm), lora_param_(lora), context_(ctx), stream_(ctx.stream), diff --git a/src/turbomind/models/llama/unified_attention_layer.h b/src/turbomind/models/llama/unified_attention_layer.h index 35d76655d2..c058e79f28 100644 --- a/src/turbomind/models/llama/unified_attention_layer.h +++ b/src/turbomind/models/llama/unified_attention_layer.h @@ -101,9 +101,6 @@ class UnifiedAttentionLayer { cudaEvent_t qkv_event_; cudaEvent_t aux_event_; - const int attn_cp_group_; - comm::DeviceCommImpl* const d_comm_; - std::array streams_; RNG rng_; diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index 1e42a6025b..b771f0f00d 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ 
-33,7 +33,6 @@ UnifiedDecoder::UnifiedDecoder(const ModelParam& model, rmsnorm_eps_(model.norm_eps), stream_(ctx.stream), d_comm_(ctx.comm.d_comm), - engine_param_(engine), tune_layer_num_(model.tune_layer_num) { attn_layer_ = std::make_unique(model, attn, engine, lora, attn_tp_size_, ctx); diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h index 2d001c9bc3..dd03293744 100644 --- a/src/turbomind/models/llama/unified_decoder.h +++ b/src/turbomind/models/llama/unified_decoder.h @@ -35,8 +35,6 @@ class UnifiedDecoder { const int attn_tp_group_; - const EngineParam engine_param_; - const float rmsnorm_eps_; cudaStream_t const stream_; From 1b3bb9cafd3f8dc08b60cf412bfefcdff35bff8b Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 6 Nov 2025 03:18:44 +0000 Subject: [PATCH 29/31] fix test_attention --- src/turbomind/kernels/attention/CMakeLists.txt | 4 ++++ src/turbomind/kernels/attention/test_attention.cu | 12 +++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/turbomind/kernels/attention/CMakeLists.txt b/src/turbomind/kernels/attention/CMakeLists.txt index d1fee315cc..5ea2d64e3b 100644 --- a/src/turbomind/kernels/attention/CMakeLists.txt +++ b/src/turbomind/kernels/attention/CMakeLists.txt @@ -60,6 +60,7 @@ if (BUILD_TEST) target_link_libraries(test_attention PRIVATE attention # flash_attention + nvidia::cutlass::cutlass Llama unfused_attention_kernels logger @@ -68,4 +69,7 @@ if (BUILD_TEST) add_executable(test_quant test_quant.cu test_utils.cu) target_compile_options(test_quant PRIVATE --generate-line-info -O3 -use_fast_math --expt-relaxed-constexpr) + target_link_libraries(test_quant PRIVATE + nvidia::cutlass::cutlass + ) endif () diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu index 3ab706c2df..235f9a3388 100644 --- a/src/turbomind/kernels/attention/test_attention.cu +++ b/src/turbomind/kernels/attention/test_attention.cu @@ -152,7 +152,9 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, seq_len, 1, block_seq_len, - 0, + 0, // layer_id + 0, // cp_rank + 1, // cp_size seq_len, head_num, head_dim, @@ -176,7 +178,9 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, seq_len, 1, block_seq_len, - 0, + 0, // layer_id + 0, // cp_rank + 1, // cp_size seq_len, head_num, head_dim, @@ -565,7 +569,9 @@ int test_attention() kContextLen, 1, kBlockSz, - 0, + 0, // layer_id + 0, // cp_rank + 1, // cp_size kContextLen, KvHeadNum, kHeadDim, From 56b9e27c50bcd9cc2504cce545d673b7013f1083 Mon Sep 17 00:00:00 2001 From: irexyc Date: Fri, 7 Nov 2025 11:33:24 +0000 Subject: [PATCH 30/31] unify attn split_k reduction w/ w/o cp --- .../kernels/attention/attention_params.h | 9 +- .../kernels/attention/attention_template.h | 31 +- .../kernels/attention/attention_universal.h | 79 +--- .../kernels/attention/decoding_template.h | 31 +- src/turbomind/kernels/attention/reduce.cu | 378 +++++++++++++++--- src/turbomind/kernels/attention/reduce.h | 24 +- .../kernels/attention/reduce_kernel.h | 254 ------------ .../kernels/attention/test_attention.cu | 14 +- src/turbomind/models/llama/LlamaBatch.cc | 21 +- src/turbomind/models/llama/LlamaBatch.h | 2 +- src/turbomind/models/llama/LlamaV2.cc | 4 +- src/turbomind/models/llama/LlamaV2.h | 2 +- src/turbomind/models/llama/cp_utils.cu | 343 +--------------- src/turbomind/models/llama/cp_utils.h | 12 +- .../models/llama/unified_attention_layer.cc | 44 +- .../models/llama/unified_attention_layer.h 
| 7 +- .../triton_backend/llama/LlamaTritonModel.cc | 2 +- 17 files changed, 439 insertions(+), 818 deletions(-) delete mode 100644 src/turbomind/kernels/attention/reduce_kernel.h diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h index ed98c673d5..4ec526d3e2 100644 --- a/src/turbomind/kernels/attention/attention_params.h +++ b/src/turbomind/kernels/attention/attention_params.h @@ -24,7 +24,7 @@ struct BlockIteratorParams { int block_len; }; -typedef void (*cp_post_fn)(void* context, int split_cnt); +typedef void (*cp_post_fn)(void* context); /// TODO: Rename to attention::Param template @@ -78,15 +78,12 @@ struct AttentionParams { int max_split_k; int* split_cnt; float* partial_O; - float* partial_M; - float* partial_L; - int* locks; + float* partial_ML; // context parallel int cp_rank{0}; cutlass::FastDivmod cp_size{1}; - int cp_q_offset{0}; // decode offset - float* cp_ML{nullptr}; // cp, q, h, k, 2 + int offset_q{0}; // decode offset cp_post_fn cp_fn{nullptr}; void* cp_fn_ctx{nullptr}; diff --git a/src/turbomind/kernels/attention/attention_template.h b/src/turbomind/kernels/attention/attention_template.h index ce97585b87..5c8d0ddbb7 100644 --- a/src/turbomind/kernels/attention/attention_template.h +++ b/src/turbomind/kernels/attention/attention_template.h @@ -12,8 +12,7 @@ namespace turbomind { template void invokeAttention(const typename Kernel::ParamType& params) { - static const size_t kSmemSize = - std::max(sizeof(typename Kernel::SharedStorage), sizeof(typename Kernel::ReduceOp::SharedStorage)); + static const size_t kSmemSize = sizeof(typename Kernel::SharedStorage); if constexpr (1) { @@ -82,20 +81,22 @@ void invokeAttention(const typename Kernel::ParamType& params) } if (params.cp_fn) { - params.cp_fn(params.cp_fn_ctx, split_cnt); + params.cp_fn(params.cp_fn_ctx); } - else if (split_cnt > 1 && Kernel::need_separate_reduce(split_cnt)) { - attention::invokeReduce(params.out, - params.partial_M, - params.partial_L, - params.partial_O, - params.split_cnt, - params.max_split_k, - split_cnt, - params.token_num, - params.num_heads, - params.inv_sqrt_dh, - params.stream); + + if (split_cnt > 1 || params.cp_size > 1) { + attention::invokeReduceV2(params.out + params.offset_q * params.num_heads * Kernel::kHeadDim, + params.partial_ML, + params.partial_O, + split_cnt > 1 ? 
params.split_cnt : nullptr, + params.max_split_k, + split_cnt, + params.cp_size, + params.cp_rank, + params.token_num, + params.num_heads, + params.inv_sqrt_dh, + params.stream); } } diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index 9ed5168b37..ce2719aa37 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -3,7 +3,6 @@ #pragma once #include "quantization.h" -#include "src/turbomind/kernels/attention/reduce_kernel.h" #include "src/turbomind/kernels/attention/rotary_embedding.h" #include "src/turbomind/kernels/core/array_ops.h" #include "src/turbomind/kernels/core/layout.h" @@ -46,8 +45,6 @@ struct AttentionUniversal { static constexpr int CTA_Q = Impl::CTA_Q; static constexpr int CTA_S = Impl::CTA_S; - using ReduceOp = attention::Reduce; - using SharedStorage = typename Mainloop::SharedStorage; static constexpr bool kProcessKV = CTA_Q == 1; @@ -505,11 +502,11 @@ struct AttentionUniversal { const bool separate_reduce = need_separate_reduce(cta_map.split_count()); - if ((separate_reduce || (params.cp_size > 1 && split_cnt > 1)) && iter_end == tile_count && head_idx == 0) { + if (split_cnt > 1 && iter_end == tile_count && head_idx == 0) { // Store actual split count, only used by separate reduction kernel for (int ti = threadIdx.x; ti < CTA_Q; ti += kWarpCount * WARP_SIZE) { if (qi_begin + ti < qi_end) { - params.split_cnt[qi_begin + ti] = split_idx ? split_idx + 1 : 0; + params.split_cnt[qi_begin + ti] = split_idx ? split_idx + 1 : (params.cp_size > 1 ? 1 : 0); } } } @@ -519,52 +516,6 @@ struct AttentionUniversal { } else { StorePartial(frag_O, frag_M, frag_L, split_cnt, qi_begin, qi_end, head_idx, split_idx, params, storage); - if (!separate_reduce && params.cp_size == 1) { - Reduce(qi_begin, head_idx, split_idx, iter_end == tile_count, params, cta_map, smem_buf); - } - } - } - - __device__ void Reduce(int qi_begin, - int head_idx, - int split_idx, - bool is_last, - const ParamType& params, - const CtaMap& cta_map, - char* smem_buf) - { - // Note: `head_idx` is cta_map.head_idx() * CTA_H - const auto index = (cta_map.batch_idx() * params.num_heads + cta_map.head_idx()) * params.max_split_k; - const auto locks = params.locks + index; - - if (!is_last) { // all but last split - sem_post(&locks[split_idx], 1, threadIdx.x == 0); - } - else { // only the last split - const int split_count = split_idx + 1; - - sem_wait_many(&locks[threadIdx.x], split_count - 1, threadIdx.x < split_count - 1); - - ReduceOp reduce_op; - reduce_op(params.out, - params.partial_M, - params.partial_L, - params.partial_O, - qi_begin, - head_idx, - params.num_heads, - hi_end_, - split_idx + 1, - params.max_split_k, - params.inv_sqrt_dh, - 1, - 0, - *(typename ReduceOp::SharedStorage*)smem_buf, - std::true_type{}); - - if (threadIdx.x < split_idx) { - locks[threadIdx.x] = 0; - } } } @@ -616,37 +567,21 @@ struct AttentionUniversal { { auto get_index = [&](int hi, int qi) { // [B, H, k, D] - return (qi_begin + qi) * params.num_heads * params.max_split_k + (head_idx + hi) * params.max_split_k - + split_idx; + return (qi_begin + qi - params.offset_q) * params.num_heads * params.max_split_k + + (head_idx + hi) * params.max_split_k + split_idx; }; Impl::StoreO(frag_O, frag_L, storage, [&](int hi, int qi, int di, const auto& vec) { if (qi_begin + qi < qi_end && check_h(hi)) { - if (split_cnt > 1) { // decode - Store(¶ms.partial_O[get_index(hi, qi) * kHeadDim + di], vec); - } 
- if (params.cp_size > 1 && split_cnt == 1) { // prefill - const int index = ((qi_begin + qi) * params.num_heads + (head_idx + hi)) * kHeadDim + di; - Store(¶ms.out[index], cast(vec)); - } + Store(¶ms.partial_O[get_index(hi, qi) * kHeadDim + di], vec); } }); Impl::ForeachML(frag_M, frag_L, [&](int hi, int qi, int ri, float M, float L) { const int index = get_index(hi, qi); if (qi_begin + qi < qi_end && ri == 0 && check_h(hi)) { - if (split_cnt > 1) { // decode - params.partial_M[index] = M; - params.partial_L[index] = L; - } - - if (params.cp_size > 1) { - const int q = qi_begin + qi - params.cp_q_offset; - const int index = (q * params.num_heads + (head_idx + hi)) * params.max_split_k + split_idx; - - params.cp_ML[index * 2] = M; - params.cp_ML[index * 2 + 1] = L; - } + params.partial_ML[index * 2] = M; + params.partial_ML[index * 2 + 1] = L; } }); } diff --git a/src/turbomind/kernels/attention/decoding_template.h b/src/turbomind/kernels/attention/decoding_template.h index 0706d82dde..d35c09f8ff 100644 --- a/src/turbomind/kernels/attention/decoding_template.h +++ b/src/turbomind/kernels/attention/decoding_template.h @@ -12,8 +12,7 @@ namespace turbomind { template bool invokeDecoding(const typename Kernel::ParamType& params) { - static const size_t kSmemSize = - std::max(sizeof(typename Kernel::SharedStorage), sizeof(typename Kernel::ReduceOp::SharedStorage)); + static const size_t kSmemSize = sizeof(typename Kernel::SharedStorage); if constexpr (1) { [[maybe_unused]] static const int _ = [&] { @@ -81,20 +80,22 @@ bool invokeDecoding(const typename Kernel::ParamType& params) } if (params.cp_fn) { - params.cp_fn(params.cp_fn_ctx, split_cnt); + params.cp_fn(params.cp_fn_ctx); } - else if (Kernel::need_separate_reduce(split_cnt)) { - attention::invokeReduce(params.out, - params.partial_M, - params.partial_L, - params.partial_O, - params.split_cnt, - params.max_split_k, - split_cnt, - params.token_num, - params.num_heads, - params.inv_sqrt_dh, - params.stream); + + if (split_cnt > 1 || params.cp_size > 1) { + attention::invokeReduceV2(params.out, + params.partial_ML, + params.partial_O, + split_cnt > 1 ? params.split_cnt : nullptr, + params.max_split_k, + split_cnt, + params.cp_size, + params.cp_rank, + params.token_num, + params.num_heads, + params.inv_sqrt_dh, + params.stream); } return true; diff --git a/src/turbomind/kernels/attention/reduce.cu b/src/turbomind/kernels/attention/reduce.cu index c654f40d05..c8e7f8df14 100644 --- a/src/turbomind/kernels/attention/reduce.cu +++ b/src/turbomind/kernels/attention/reduce.cu @@ -1,79 +1,355 @@ // Copyright (c) OpenMMLab. All rights reserved. 
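// Overview of the reduction path implemented below (it supersedes the Reduce<> op that
// previously lived in reduce_kernel.h, deleted by this patch):
//   * reduce_ML: one warp per (query, head) scans every split -- and, when cp_size > 1,
//     every cp rank's splits -- computes the global max M and denominator L via warp
//     shuffles, then overwrites each partial M slot with its final rescale factor
//     exp2f((M_i - M) * exp_scale) / L.
//   * reduce_output: accumulates the per-split partial_O vectors, applying those
//     precomputed scales on the first pass, in CTA_K-wide rounds (intermediate sums are
//     written back to partial_O) until one vector per (query, head) remains, which is
//     cast and stored to `out`.
//   * invokeReduceV2 chains invokeReduceML and invokeReduceOutput and is called from the
//     attention/decoding launch templates for both split-k and context-parallel cases.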
+#include "cutlass/fast_math.h" #include "src/turbomind/kernels/attention/cta_map.h" -#include "src/turbomind/kernels/attention/reduce_kernel.h" +#include "src/turbomind/kernels/core/array_ops.h" +#include "src/turbomind/kernels/core/thread_map.h" +#include "src/turbomind/utils/cuda_utils.h" #include namespace turbomind::attention { +int next_power_of_two(int v) +{ + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} + +template +__global__ void reduce_output(T* out, + const float* partial_ML, + float* partial_O, + const int* split_cnt_, + int max_split_cnt, + int query_num, + int head_num, + float exp_scale, + int stride_k, + int offset_k) +{ + __shared__ float s_out[WarpCnt][HeadDim]; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + + const int head_idx = ReduceCtaMap::head_idx(); + const int query_idx = ReduceCtaMap::query_idx(); + const int chunk_idx = ReduceCtaMap::split_idx(); + + offset_k *= chunk_idx; + const int split_cnt = (split_cnt_ != nullptr) ? split_cnt_[query_idx] : 1; + if (offset_k >= split_cnt) { // out of bound + return; + } + + // HeadDim / WARP_SIZE + // 128 -> 4 + // 64, 192 -> 2 + constexpr int kVecSize = HeadDim % 128 == 0 ? 4 : 2; + + using Map = RakedThreadMap; + static_assert(Map::kIterS == 1); + + constexpr int C = Map::kIterC; + + using Vec = Array; + + Vec accu_O[C]{}; + Vec frag_O[C]; + + const int2 d = Map::get_offset(warp_id, lane_id); + + auto for_each = [&](auto fn) { + const int ki = d.y; + PRAGMA_UNROLL + for (int c = 0; c < C; ++c) { + const int di = d.x + c * Map::kDeltaC; + fn(c, ki, di); + } + }; + + PRAGMA_UNROLL + for (int k = 0; k < CTA_K; k += WarpCnt) { + for_each([&](int c, int ki, int di) { + using namespace ops; + ki += k; + const int split_idx = offset_k + stride_k * ki; + const bool mask = split_idx < split_cnt; + const int index = (query_idx * head_num + head_idx) * max_split_cnt + split_idx; + const int offset = index * HeadDim + di; + if (mask) { + Load(frag_O[c], &partial_O[offset]); + accu_O[c] = accu_O[c] + frag_O[c] * (First ? 
partial_ML[index * 2] : 1.0f); + } + }); + } + + for_each([&](int c, int ki, int di) { + Store(&s_out[ki][di], accu_O[c]); // + }); + + PRAGMA_UNROLL + for (int w = WarpCnt / 2; w > 0; w /= 2) { + __syncthreads(); + for_each([&](int c, int ki, int di) { + using namespace ops; + if (ki < w) { + (Vec&)s_out[ki][di] = (Vec&)s_out[ki][di] + (Vec&)s_out[w + ki][di]; + } + }); + } + + for_each([&](int c, int ki, int di) { + if (ki == 0) { + if (gridDim.z == 1) { + const int offset = (query_idx * head_num + head_idx) * HeadDim + di; + Store(&out[offset], cast((Vec&)s_out[ki][di])); + } + else { + const int offset = ((query_idx * head_num + head_idx) * max_split_cnt + offset_k) * HeadDim + di; + Store(&partial_O[offset], (Vec&)s_out[ki][di]); + } + } + }); +} + template -void invokeReduce(T* out, - float* partial_M, - float* partial_L, - float* partial_O, - const int* split_cnt, - int partial_len, - int max_split_cnt, - int query_num, - int head_num, - float exp_scale, - cudaStream_t stream) +void invokeReduceOutput(T* out, + const float* partial_ML, // scale + float* partial_O, + const int* split_cnt, + int partial_len, + int max_split_cnt, + int query_num, + int head_num, + float exp_scale, + cudaStream_t stream) { constexpr int CTA_K = 32; // warp size - using Reduce = attention::Reduce; - - static constexpr size_t kSmemSize = sizeof(typename Reduce::SharedStorage); - static_assert(kSmemSize < (48 << 10)); - - auto invoke = [&](auto is_final, int stride_k) { - const dim3 block = Reduce::kWarpCnt * 32; - const dim3 grid = ReduceCtaMap::get_grid_shape(query_num, head_num, max_split_cnt, CTA_K); - reduce_kernel<<>>(out, // - partial_M, - partial_L, - partial_O, - nullptr, - split_cnt, - partial_len, - head_num, - exp_scale, - stride_k); + auto invoke = [&](auto is_first, int stride_k) { + constexpr int kWarpCnt = 4; + const dim3 block = kWarpCnt * WARP_SIZE; + const dim3 grid = ReduceCtaMap::get_grid_shape(query_num, head_num, max_split_cnt, CTA_K); + + static constexpr size_t kSmemSize = sizeof(float) * kWarpCnt * HeadDim; + static_assert(kSmemSize < (48 << 10)); + + reduce_output<<>>( // + out, + partial_ML, + partial_O, + split_cnt, + partial_len, + query_num, + head_num, + exp_scale, + stride_k, + stride_k * CTA_K); + + sync_check_cuda_error(); }; int stride_k = 1; + invoke(std::true_type{}, stride_k); while (max_split_cnt > CTA_K) { - invoke(std::false_type{}, stride_k); max_split_cnt = (max_split_cnt + CTA_K - 1) / CTA_K; stride_k *= CTA_K; + invoke(std::false_type{}, stride_k); } +} - invoke(std::true_type{}, stride_k); +template +__global__ void reduce_ML(float* partial_ML, // cp, q, h, k, 2 + const int* split_cnt_, + int max_split_cnt, + int query_num, + cutlass::FastDivmod head_num, + float exp_scale, + int cp_size, + int dim0) +{ + constexpr int kIterWarp = N / WARP_SIZE; + + float frag_M[kIterWarp]; + float frag_L[kIterWarp]; + + int qh = blockIdx.x * blockDim.y + threadIdx.y; + if (qh >= query_num * head_num) { + return; + } + + const int split_k = split_cnt_ != nullptr ? split_cnt_[head_num.div(qh)] : 1; + const int split_cnt = cp_size * split_k; + + float block_M = -std::numeric_limits::infinity(); + float block_L = 0.f; + + PRAGMA_UNROLL + for (int i = 0; i < kIterWarp; ++i) { + int ki = threadIdx.x + i * WARP_SIZE; + int index = (qh * max_split_cnt + ki) * 2; + bool mask = ki < split_cnt; + + if (mask && dim0 > 0) { // handle cp case + int cp_i = ki / split_k; + ki = ki % split_k; + index = cp_i * dim0 + (qh * max_split_cnt + ki) * 2; + } + + frag_M[i] = mask ? 
partial_ML[index] : -std::numeric_limits::infinity(); + frag_L[i] = mask ? partial_ML[index + 1] : 0.f; + block_M = max(block_M, frag_M[i]); + } + + PRAGMA_UNROLL + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask)); + } + + PRAGMA_UNROLL + for (int i = 0; i < kIterWarp; ++i) { + block_L += (frag_M[i] == -std::numeric_limits::infinity()) ? + 0.0f : + exp2f((frag_M[i] - block_M) * exp_scale) * frag_L[i]; + } + + PRAGMA_UNROLL + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + block_L += __shfl_xor_sync(uint32_t(-1), block_L, mask); + } + + PRAGMA_UNROLL + for (int i = 0; i < kIterWarp; ++i) { + int ki = threadIdx.x + i * WARP_SIZE; + int index = (qh * max_split_cnt + ki) * 2; + bool mask = ki < split_cnt; + + if (dim0 > 0) { // handle cp case + int cp_i = ki / split_k; + ki = ki % split_k; + index = cp_i * dim0 + (qh * max_split_cnt + ki) * 2; + } + + float scale = (frag_M[i] == -std::numeric_limits::infinity()) ? + 0.0f : + exp2f((frag_M[i] - block_M) * exp_scale) / block_L; + if (mask) { + partial_ML[index] = scale; // save scale to M + } + } +} + +void invokeReduceML(float* partial_ML, + const int* split_cnt, + int partial_len, + int max_split_cnt, + int cp_size, + int cp_rank, + int query_num, + int head_num, + float exp_scale, + cudaStream_t stream) +{ + max_split_cnt *= cp_size; + TM_CHECK(max_split_cnt > 1); + + const int warp_cnt = 4; + const dim3 block(WARP_SIZE, warp_cnt); + const dim3 grid((query_num * head_num + warp_cnt - 1) / warp_cnt); + + const int dim0 = cp_size > 1 ? query_num * head_num * partial_len * 2 : 0; + partial_ML -= cp_rank * dim0; // begin address of cp_rank0 + + int n = max(next_power_of_two(max_split_cnt), WARP_SIZE); + switch (n) { +#define LAUNCH_REDUCE_ML(n) \ + case n: \ + reduce_ML<<>>( \ + partial_ML, split_cnt, partial_len, query_num, cutlass::FastDivmod(head_num), exp_scale, cp_size, dim0); \ + break; + + LAUNCH_REDUCE_ML(32); + LAUNCH_REDUCE_ML(64); + LAUNCH_REDUCE_ML(128); + LAUNCH_REDUCE_ML(256); + LAUNCH_REDUCE_ML(512); + LAUNCH_REDUCE_ML(1024); + default: + TM_CHECK(false) << "reduce_ML does not support max_split_cnt = " << max_split_cnt; +#undef LAUNCH_REDUCE_ML + } + + sync_check_cuda_error(); +} + +template +void invokeReduceV2(T* out, + float* partial_ML, + float* partial_O, + const int* split_cnt, + int partial_len, + int max_split_cnt, + int cp_size, + int cp_rank, + int query_num, + int head_num, + float exp_scale, + cudaStream_t stream) +{ + invokeReduceML(partial_ML, // + split_cnt, + partial_len, + max_split_cnt, + cp_size, + cp_rank, + query_num, + head_num, + exp_scale, + stream); + + invokeReduceOutput(out, // + partial_ML, + partial_O, + split_cnt, + partial_len, + max_split_cnt, + query_num, + head_num, + exp_scale, + stream); } -#define INSTANTIATE_invokeReduce(dim, type) \ - template void invokeReduce(type * out, \ - float* partial_M, \ - float* partial_L, \ - float* partial_O, \ - const int* split_cnt, \ - int partial_len, \ - int max_split_cnt, \ - int query_num, \ - int head_num, \ - float exp_scale, \ - cudaStream_t stream); - -INSTANTIATE_invokeReduce(64, half); -INSTANTIATE_invokeReduce(128, half); -INSTANTIATE_invokeReduce(192, half); +#define INSTANTIATE_invokeReduceV2(dim, type) \ + template void invokeReduceV2(type * out, \ + float* partial_ML, \ + float* partial_O, \ + const int* split_cnt, \ + int partial_len, \ + int max_split_cnt, \ + int cp_size, \ + int cp_rank, \ + int query_num, \ + int head_num, \ + float exp_scale, \ + 
cudaStream_t stream); + +INSTANTIATE_invokeReduceV2(64, half); +INSTANTIATE_invokeReduceV2(128, half); +INSTANTIATE_invokeReduceV2(192, half); #if ENABLE_BF16 -INSTANTIATE_invokeReduce(64, nv_bfloat16); -INSTANTIATE_invokeReduce(128, nv_bfloat16); -INSTANTIATE_invokeReduce(192, nv_bfloat16); +INSTANTIATE_invokeReduceV2(64, nv_bfloat16); +INSTANTIATE_invokeReduceV2(128, nv_bfloat16); +INSTANTIATE_invokeReduceV2(192, nv_bfloat16); #endif } // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/reduce.h b/src/turbomind/kernels/attention/reduce.h index c078de5958..53f40163e8 100644 --- a/src/turbomind/kernels/attention/reduce.h +++ b/src/turbomind/kernels/attention/reduce.h @@ -12,16 +12,16 @@ namespace turbomind::attention { template -void invokeReduce(T* out, - float* partial_M, - float* partial_L, - float* partial_O, - const int* split_cnt, - int partial_len, - int max_split_cnt, - int query_num, - int head_num, - float exp_scale, - cudaStream_t stream); - +void invokeReduceV2(T* out, + float* partial_ML, + float* partial_O, + const int* split_cnt, + int partial_len, + int max_split_cnt, + int cp_size, + int cp_rank, + int query_num, + int head_num, + float exp_scale, + cudaStream_t stream); } // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/reduce_kernel.h b/src/turbomind/kernels/attention/reduce_kernel.h deleted file mode 100644 index b4c9064cfe..0000000000 --- a/src/turbomind/kernels/attention/reduce_kernel.h +++ /dev/null @@ -1,254 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "src/turbomind/kernels/attention/cta_map.h" -#include "src/turbomind/kernels/core/array_ops.h" -#include "src/turbomind/kernels/core/thread_map.h" -#include - -namespace turbomind::attention { - -template -struct Reduce { - using T = T_; - - static constexpr int CTA_H = CTA_H_; - static constexpr int CTA_K = CTA_K_; - static constexpr int kWarpCnt = WarpCnt; - - static_assert((CTA_K & (CTA_K - 1)) == 0, "must be pow of 2"); - - struct SharedStorage { - float scale[CTA_H][CTA_K]; - float O[CTA_H][WarpCnt][HeadDim]; - }; - - template - __device__ void operator()(T* out, - float* partial_M, - float* partial_L, - float* partial_O, - int query_idx, - int head_idx, - int head_num, - int hi_end, - int split_cnt, - int max_split_cnt, - float exp_scale, - int stride_k, - int offset_k, - SharedStorage& storage, - std::integral_constant) - { - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - - // iterations per warp, K > 1 when CTA_K is multiple of WARP_SIZE - constexpr int K = (CTA_K + WARP_SIZE - 1) / WARP_SIZE; - // heads per warp iteration, M > 1 when WARP_SIZE is multiple of CTA_K - constexpr int M = (WARP_SIZE + CTA_K - 1) / CTA_K; - // lanes per head, a warp is processing M heads in parallel - constexpr int L = WARP_SIZE / M; - - PRAGMA_UNROLL - for (int h = 0; h < CTA_H; h += WarpCnt * M) { - - const int hi = h + warp_id * M + lane_id / L; - - Array frag_M; - Array frag_L; - - fill(frag_M, -std::numeric_limits::infinity()); - fill(frag_L, 0.f); - - PRAGMA_UNROLL - for (int k = 0; k < K; ++k) { - const int si = (lane_id % L + k * L) * stride_k + offset_k; - const int idx = (query_idx * head_num + head_idx + hi) * max_split_cnt + si; - const bool mask = hi < hi_end && si < split_cnt; - if (mask) { - frag_M[k] = partial_M[idx]; - frag_L[k] = partial_L[idx]; - } - } - - float block_M = frag_M[0]; - PRAGMA_UNROLL - for (int k = 1; k < K; ++k) { - block_M = fmaxf(block_M, frag_M[k]); - } - - 
PRAGMA_UNROLL - for (int mask = L / 2; mask >= 1; mask /= 2) { - block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask)); - } - - Array expdiff_M; - PRAGMA_UNROLL - for (int k = 0; k < K; ++k) { - expdiff_M[k] = exp2f((frag_M[k] - block_M) * exp_scale); - } - - float block_L{}; - PRAGMA_UNROLL - for (int k = 0; k < K; ++k) { - block_L += expdiff_M[k] * frag_L[k]; - } - - PRAGMA_UNROLL - for (int mask = L / 2; mask >= 1; mask /= 2) { - block_L += __shfl_xor_sync(uint32_t(-1), block_L, mask); - } - - Array scale; - PRAGMA_UNROLL - for (int k = 0; k < K; ++k) { - scale[k] = IsFinal ? expdiff_M[k] / block_L : expdiff_M[k]; - } - - if (hi < CTA_H) { - PRAGMA_UNROLL - for (int k = 0; k < K; ++k) { - storage.scale[hi][lane_id % L + k * L] = scale[k]; - } - } - - if constexpr (!IsFinal) { - PRAGMA_UNROLL - for (int k = 0; k < K; ++k) { - const int si = (lane_id % L + k * L) * stride_k + offset_k; - const int idx = (query_idx * head_num + head_idx + hi) * max_split_cnt + si; - const bool mask = hi < hi_end && si < split_cnt; - if (mask) { - partial_M[idx] = block_M; - partial_L[idx] = block_L; - } - } - } - } - - __syncthreads(); - - // HeadDim / WARP_SIZE - // 128 -> 4 - // 64, 192 -> 2 - constexpr int kVecSize = HeadDim % 128 == 0 ? 4 : 2; - - using Map = RakedThreadMap; - - static_assert(Map::kIterS == CTA_H); - - constexpr int S = Map::kIterS; - constexpr int C = Map::kIterC; - - using Vec = Array; - - Vec accu_O[S][C]{}; - Vec frag_O[S][C]; - - const int2 d = Map::get_offset(warp_id, lane_id); - - auto for_each = [&](auto fn) { - PRAGMA_UNROLL - for (int s = 0; s < S; ++s) { - const int si = d.y + s * Map::kDeltaS; - const int hi = si % CTA_H; - const int ki = si / CTA_H; - PRAGMA_UNROLL - for (int c = 0; c < C; ++c) { - const int di = d.x + c * Map::kDeltaC; - fn(s, c, ki, hi, di); - } - } - }; - - PRAGMA_UNROLL - for (int k = 0; k < CTA_K; k += WarpCnt) { - for_each([&](int s, int c, int ki, int hi, int di) { - using namespace ops; - ki += k; - const int split_idx = offset_k + stride_k * ki; - const bool mask = split_idx < split_cnt && hi < hi_end; - const int offset = ((query_idx * head_num + head_idx + hi) * max_split_cnt + split_idx) * HeadDim + di; - if (mask) { - Load(frag_O[s][c], &partial_O[offset]); - accu_O[s][c] = accu_O[s][c] + frag_O[s][c] * storage.scale[hi][ki]; - } - }); - } - - for_each([&](int s, int c, int ki, int hi, int di) { - Store(&storage.O[hi][ki][di], accu_O[s][c]); // - }); - - PRAGMA_UNROLL - for (int w = WarpCnt / 2; w > 0; w /= 2) { - __syncthreads(); - for_each([&](int s, int c, int ki, int hi, int di) { - using namespace ops; - if (ki < w) { - (Vec&)storage.O[hi][ki][di] = (Vec&)storage.O[hi][ki][di] + (Vec&)storage.O[hi][w + ki][di]; - } - }); - } - - for_each([&](int s, int c, int ki, int hi, int di) { - if (ki == 0 && hi < hi_end) { - if constexpr (IsFinal) { - const int offset = (query_idx * head_num + head_idx + hi) * HeadDim + di; - Store(&out[offset], cast((Vec&)storage.O[hi][ki][di])); - } - else { - const int offset = - ((query_idx * head_num + head_idx + hi) * max_split_cnt + offset_k) * HeadDim + di; - Store(&partial_O[offset], (Vec&)storage.O[hi][ki][di]); - } - } - }); - } -}; - -template -__global__ void reduce_kernel(typename Reduce::T* out, - float* partial_M, - float* partial_L, - float* partial_O, - int* signals, - const int* split_cnt_, - int max_split_cnt, - int head_num, - float exp_scale, - int stride_k) -{ - extern __shared__ char smem[]; - - const int head_idx = ReduceCtaMap::head_idx(); - const int query_idx = 
ReduceCtaMap::query_idx(); - const int chunk_idx = ReduceCtaMap::split_idx(); - - const int split_cnt = split_cnt_[query_idx]; - - const int chunk_offset = chunk_idx * stride_k * Reduce::CTA_K; - - if (chunk_offset >= split_cnt) { // out of bound - return; - } - - Reduce reduce{}; - reduce(out, - partial_M, - partial_L, - partial_O, - query_idx, - head_idx, - head_num, - 1, // hi_end - split_cnt, - max_split_cnt, - exp_scale, - stride_k, - chunk_offset, - *(typename Reduce::SharedStorage*)smem, - std::integral_constant{}); -} - -} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu index 235f9a3388..f07fe273c5 100644 --- a/src/turbomind/kernels/attention/test_attention.cu +++ b/src/turbomind/kernels/attention/test_attention.cu @@ -317,19 +317,15 @@ int test_attention() thrust::universal_vector cu_seqlens(kBatchSize + 1); thrust::universal_vector cu_kv_lens(kBatchSize + 1); - thrust::device_vector partial_M(kTokenNum * kHeadNum * kMaxSplitK); - thrust::device_vector partial_L(kTokenNum * kHeadNum * kMaxSplitK); + thrust::device_vector partial_ML(kTokenNum * kHeadNum * kMaxSplitK * 2); thrust::device_vector partial_O(kTokenNum * kHeadNum * kMaxSplitK * kHeadDim); thrust::device_vector split_cnt(kTokenNum); - thrust::device_vector semaphores(kTokenNum * kHeadNum * kMaxSplitK); thrust::universal_vector qk_buf((size_t)kDump * kBatchSize * kHeadNum * kInputLen * kContextLen); thrust::universal_vector pr_buf((size_t)kDump * kBatchSize * kHeadNum * kInputLen * kContextLen); thrust::universal_vector sinks(kHeadNum); - thrust::fill(semaphores.begin(), semaphores.end(), 0); - rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f); rng.GenerateNormal(k_cache.data().get(), kBatchSize * KvHeadNum * kContextLen * kHeadDim); @@ -447,11 +443,9 @@ int test_attention() float scale_factor = -std::log2f(kRoPEBase) / kRoPEDim; params.rope_param = RopeKernelParam{RopeType::kDefault, nullptr, kRoPEDim, scale_factor, 1.f}; - params.split_cnt = split_cnt.data().get(); - params.partial_L = partial_L.data().get(); - params.partial_M = partial_M.data().get(); - params.partial_O = partial_O.data().get(); - params.locks = semaphores.data().get(); + params.split_cnt = split_cnt.data().get(); + params.partial_ML = partial_ML.data().get(); + params.partial_O = partial_O.data().get(); params.max_split_k = kMaxSplitK; params.arch = getSMVersion(); diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 5cc55af4db..70673511a5 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -823,16 +823,19 @@ void LlamaBatch::AllocSymmBuffers() const ssize_t vocab_size_padded = model_->vocab_size_padded_; // Native comm fuses allreduce & rmsnorm in token granularity - TM_CHECK(max_forward_token_num_ % tp_size_ == 0); + TM_CHECK(max_forward_token_num_ % tp_size_ == 0) << max_forward_token_num_ << " vs " << tp_size_; symm_hidden_states_buf_ = {{max_forward_token_num_ * param_.attn_dp_size, hidden_units}, data_type_, symm_alloc_}; symm_logits_buf_ = {{max_batch_size_, vocab_size_padded}, data_type_, symm_alloc_}; - if (param_.attn_cp_size > 1) { - // prefill(cp, q, h, 1, 2), decode(cp, q, h, k, 2) - const int cp_workspace_tokens = UnifiedAttentionLayer::kMaxWorkspaceTokens + max_forward_token_num_; - symm_cp_ML_ = {{param_.attn_cp_size, cp_workspace_tokens, (int)model_->local_head_num_, 2}, symm_alloc_}; - } + // for context parallel, we use 
symm_alloc_ and both prefill and decode stage have reduce process + // w/o context parallel, we use common alloc and only decode stage has reduce process + // perhaps it would be more appropriate to put this buffer in the unified_attention_layer. + Allocator alloc = param_.attn_cp_size > 1 ? symm_alloc_ : core::Context::alloc(kDEVICE); + const ssize_t attn_ws_tokens = param_.attn_cp_size > 1 ? + UnifiedAttentionLayer::kMaxWorkspaceTokens + max_forward_token_num_ : + UnifiedAttentionLayer::kMaxWorkspaceTokens; + symm_partial_ML_ = {{param_.attn_cp_size, attn_ws_tokens, (int)model_->local_head_num_, 2}, alloc}; } void LlamaBatch::FreeSymmBuffers() @@ -840,7 +843,7 @@ void LlamaBatch::FreeSymmBuffers() symm_hidden_states_buf_ = {}; symm_logits_buf_ = {}; - symm_cp_ML_ = {}; + symm_partial_ML_ = {}; } LlamaBatch::~LlamaBatch() @@ -1581,7 +1584,7 @@ bool LlamaBatch::Forward(GenerationState& g) state_->h_context_length.slice(first, mini_batch_size), rope_theta_.slice(first, mini_batch_size), &mrope, - symm_cp_ML_, + symm_partial_ML_, finished_buf_.slice(first, mini_batch_size), Buffer(local_token_nums.data(), local_token_nums.size(), kCPU), lora_mask_buf_, @@ -1774,7 +1777,7 @@ void LlamaBatch::Warmup() Buffer{&input_length, 1, kCPU}, rope_theta_.slice(0, bsz), nullptr, // mrope - symm_cp_ML_, + symm_partial_ML_, finished_buf_.slice(0, bsz), Buffer{local_token_nums.data(), (int)local_token_nums.size(), kCPU}, Buffer{}, diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 23bfdfc7a9..55386c9aff 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -246,7 +246,7 @@ class LlamaBatch { Tensor symm_logits_buf_; // context parallel - Tensor_ symm_cp_ML_; + Tensor_ symm_partial_ML_; Tensor decoder_output_buf_; diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 247313fa06..1adc6c8409 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -163,7 +163,7 @@ void LlamaV2::Forward(Buffer_ input_ids, Buffer_ h_context_length, Buffer rope_base, MropeRope* mrope, - Tensor cp_ML, + Tensor partial_ML, Buffer finished, Buffer local_token_nums, Buffer lora_mask, @@ -259,7 +259,7 @@ void LlamaV2::Forward(Buffer_ input_ids, {"decode_num", Buffer{&decode_num, 1, kCPU}}, {"prefil_num", Buffer{&prefil_num, 1, kCPU}}, {"rope_base", rope_base}, - {"cp_ML", cp_ML}, + {"partial_ML", partial_ML}, {"cu_block_nums", cu_block_nums}, {"kv_block_ptrs", kv_block_ptrs}, {"local_token_nums", local_token_nums}}; diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index 9283d5967b..7d77db0812 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -69,7 +69,7 @@ class LlamaV2 { Buffer_ h_context_length, Buffer rope_base, MropeRope* mrope, - Tensor cp_ML, + Tensor partial_ML, Buffer finished, Buffer local_token_nums, Buffer lora_mask, diff --git a/src/turbomind/models/llama/cp_utils.cu b/src/turbomind/models/llama/cp_utils.cu index 530c211e12..6b56e7f10f 100644 --- a/src/turbomind/models/llama/cp_utils.cu +++ b/src/turbomind/models/llama/cp_utils.cu @@ -1,351 +1,20 @@ // Copyright (c) OpenMMLab. All rights reserved. 
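// Layout sketch for the fused (M, L) buffer that this patch threads through the
// context-parallel attention path (assumed helper name, for illustration only).
// The kernels address the flat {attn_cp_size, workspace_tokens, head_num, 2}
// pool allocated in AllocSymmBuffers above as (cp_rank, query, head, split, {M, L}):
//
//   inline size_t ml_index(int cp, int q, int h, int k,
//                          int q_num, int head_num, int max_split_k)
//   {
//       return ((((size_t)cp * q_num + q) * head_num + h) * max_split_k + k) * 2;
//       // +0 -> running max M, +1 -> running sum L
//   }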
-#include "src/turbomind/kernels/core/array.h" -#include "src/turbomind/kernels/core/array_ops.h" #include "src/turbomind/models/llama/cp_utils.h" -#include "src/turbomind/models/llama/llama_utils.h" namespace turbomind { -int next_power_of_two(int v) +void CpPost(void* context) { - v--; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - v++; - return v; -} - -template -__global__ void ReduceOutput(T* out, // - float* partial_O, - float* cp_ML, // q, h, k, 2 - cutlass::FastDivmod num_heads, - int* split_cnt_, - int max_split_cnt, - int total, - int stride_k, - int offset_k, - float exp_scale) -{ - __shared__ float s_out[WarpCnt][HeadDim]; - - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - - // warp_id, q, h - const int qh = blockIdx.x * M + warp_id % M; - int q, h; - q = num_heads.divmod(h, qh); - - if (q * num_heads + h >= total) { - return; - } - - offset_k *= blockIdx.y; - const int split_cnt = (split_cnt_ != nullptr) ? max(split_cnt_[q], 1) : 1; - if (offset_k >= split_cnt) { - return; - } - - auto get_scale = [&](int q, int h, int ki) { // q, h, k, 2 - int index = ((q * num_heads + h) * max_split_cnt + ki) * 2; - return cp_ML[index]; - }; - - // HeadDim / WARP_SIZE - // 128 -> 4 - // 64, 192 -> 2 - constexpr int kVecSize = HeadDim % 128 == 0 ? 4 : 2; - constexpr int iterC = HeadDim / (WARP_SIZE * kVecSize); - - using namespace ops; - using VecF = Array; - using VecT = Array; - - // in most cases,no split_k - if constexpr (N == 1) { - VecT frag_O; - float scale = get_scale(q, h, 0); - - PRAGMA_UNROLL - for (int c = 0; c < iterC; ++c) { - Load(frag_O, &out[(q * num_heads + h) * HeadDim + lane_id * kVecSize + c * WARP_SIZE * kVecSize]); - frag_O = cast(cast(frag_O) * scale); - Store(&out[(q * num_heads + h) * HeadDim + lane_id * kVecSize + c * WARP_SIZE * kVecSize], frag_O); - } - - return; - } - - VecF accu_O[iterC]{}; - VecF frag_O[iterC]; - - PRAGMA_UNROLL - for (int k = 0; k < N; k += WarpCnt / M) { - const int ki = (warp_id / M + k) * stride_k + offset_k; - const int base = (((q * num_heads + h) * max_split_cnt + ki) * HeadDim); // q, h, k, d - - if (ki < split_cnt) { - float scale = (stride_k == 1) ? get_scale(q, h, ki) : 1.0f; - - PRAGMA_UNROLL - for (int c = 0; c < iterC; ++c) { - const int index = base + lane_id * kVecSize + c * WARP_SIZE * kVecSize; - Load(frag_O[c], &partial_O[index]); - accu_O[c] = accu_O[c] + frag_O[c] * scale; - } - } - } - - PRAGMA_UNROLL - for (int c = 0; c < iterC; ++c) { - Store(&s_out[warp_id][c * WARP_SIZE * kVecSize + lane_id * kVecSize], accu_O[c]); - } - - // PRAGMA_UNROLL - // for (int w = WarpCnt / 2 / M; w > 0; w /= 2) { - // const int ki = warp_id / M; - // __syncthreads(); - // if (ki < w) { - // PRAGMA_UNROLL - // for (int c = 0; c < iterC; ++c) { - // const int index = c * WARP_SIZE * kVecSize + lane_id * kVecSize; - // (VecF&)s_out[warp_id][index] = (VecF&)s_out[warp_id][index] + (VecF&)s_out[warp_id + w * M][index]; - // } - // } - // } - - __syncthreads(); - if (warp_id / M == 0) { - PRAGMA_UNROLL - for (int k = 1; k < WarpCnt / M; ++k) { - for (int c = 0; c < iterC; ++c) { - const int index = c * WARP_SIZE * kVecSize + lane_id * kVecSize; - (VecF&)s_out[warp_id][index] = (VecF&)s_out[warp_id][index] + (VecF&)s_out[warp_id + k * M][index]; - } - } - } - - if (warp_id / M == 0) { - const int base = gridDim.y == 1 ? 
(q * num_heads + h) * HeadDim : - (((q * num_heads + h) * max_split_cnt + offset_k) * HeadDim); - PRAGMA_UNROLL - for (int c = 0; c < iterC; ++c) { - const int off = c * WARP_SIZE * kVecSize + lane_id * kVecSize; - if (gridDim.y == 1) { - Store(&out[base + off], cast((VecF&)s_out[warp_id][off])); - } - else { - Store(&partial_O[base + off], (VecF&)s_out[warp_id][off]); - } - } - } -} - -template -void invokeReduceOutput(CpPostContext* ctx, AttentionParams* params, int split_cnt) -{ - constexpr int MaxN = 32; - - int split_k = split_cnt; - int stride_k = 1; - int offset_k = 1; - - cutlass::FastDivmod num_heads = cutlass::FastDivmod(params->num_heads); - - auto invoke = [&](auto n, auto head_dim) { - constexpr int WarpCnt = 4; - constexpr int M = (WarpCnt + n - 1) / n; // item per block, 1, 2, 4 - const int total = params->token_num * params->num_heads; - - const dim3 block(WarpCnt * WARP_SIZE); - const dim3 grid((total + M - 1) / M, (split_k + n - 1) / n); - const int shm_size = WarpCnt * sizeof(float) * head_dim; - ReduceOutput<<stream>>>( // - params->out + params->cp_q_offset * params->num_heads * params->size_per_head, - params->partial_O, - ctx->cp_ML + params->cp_rank * params->token_num * params->num_heads * params->max_split_k * 2, - num_heads, - split_cnt > 1 ? params->split_cnt : nullptr, - params->max_split_k, - total, - stride_k, - offset_k * n, - params->inv_sqrt_dh); - - sync_check_cuda_error(); - - stride_k *= n; - offset_k *= n; - split_k = (split_k + n - 1) / n; - }; - - auto dispatch_n = [&](int split_k, auto head_dim) { - int n = min(next_power_of_two(split_k), MaxN); - - switch (n) { - case 1: - return invoke(std::integral_constant{}, head_dim); - case 2: - return invoke(std::integral_constant{}, head_dim); - case 4: - return invoke(std::integral_constant{}, head_dim); - case 8: - return invoke(std::integral_constant{}, head_dim); - case 16: - return invoke(std::integral_constant{}, head_dim); - case 32: - return invoke(std::integral_constant{}, head_dim); - default: - TM_CHECK(0); - } - }; - - auto dispatch_head_dim = [&](int split_k) { - switch (params->size_per_head) { - case 64: - return dispatch_n(split_k, std::integral_constant{}); - case 128: - return dispatch_n(split_k, std::integral_constant{}); - case 192: - return dispatch_n(split_k, std::integral_constant{}); - default: - TM_CHECK(0); - } - }; - - dispatch_head_dim(split_k); - while (split_k > 1) { - dispatch_head_dim(split_k); - } -} - -template -__global__ void ReduceScale(float* cp_ML, // cp, q, h, k, 2 - int num_tokens, - cutlass::FastDivmod num_heads, - int* split_cnt_, - int max_split_cnt, - int cp_size, - int cp_rank, - float exp_scale) -{ - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - - int qh = blockIdx.x * WarpCnt + warp_id; - int q = num_heads.div(qh); - - if (q >= num_tokens) { - return; - } - - float frag_M0 = -std::numeric_limits::infinity(); - float frag_L0 = 0.0f; - - const int split_per_rank = (split_cnt_ == nullptr) ? 1 : max(split_cnt_[q], 1); - const int split_all_rank = split_per_rank * cp_size; - - int split_i, split_k; - for (int i = lane_id; i < split_all_rank; i += WARP_SIZE) { - split_i = i / split_per_rank; - split_k = i % split_per_rank; - int index = (split_i * num_tokens * num_heads + qh) * max_split_cnt + split_k; - - float frag_M1 = cp_ML[index * 2]; - float frag_L1 = cp_ML[index * 2 + 1]; - float frag_M = fmaxf(frag_M0, frag_M1); - - frag_L1 = (frag_M1 == -std::numeric_limits::infinity()) ? 
- 0.0f : - exp2f((frag_M1 - frag_M) * exp_scale) * frag_L1; - frag_L0 = (frag_M0 == -std::numeric_limits::infinity()) ? - 0.0f : - exp2f((frag_M0 - frag_M) * exp_scale) * frag_L0; - - frag_L0 = frag_L1 + frag_L0; - frag_M0 = frag_M; - } - - float block_M = frag_M0; - PRAGMA_UNROLL - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask)); - } - - float block_L = - (frag_M0 == -std::numeric_limits::infinity()) ? 0.0f : exp2f((frag_M0 - block_M) * exp_scale) * frag_L0; - - PRAGMA_UNROLL - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - block_L += __shfl_xor_sync(uint32_t(-1), block_L, mask); - } - - for (int i = lane_id; i < split_per_rank; i += WARP_SIZE) { - split_k = i % split_per_rank; - int index = (cp_rank * num_tokens * num_heads + qh) * max_split_cnt + split_k; - - float frag_M1 = cp_ML[index * 2]; - float scale = (frag_M1 == -std::numeric_limits::infinity()) ? - 0.0f : - exp2f((frag_M1 - block_M) * exp_scale) / block_L; - cp_ML[index * 2] = scale; // save to M - } -} - -template -void invokeReduceScale(CpPostContext* ctx, AttentionParams* params, int split_cnt) -{ - constexpr int WarpCnt = 4; // each warp process one token - const dim3 block(WarpCnt * WARP_SIZE); - const dim3 grid((params->token_num * params->num_heads + WarpCnt - 1) / WarpCnt); - - ReduceScale<<stream>>>( // - ctx->cp_ML, - params->token_num, - cutlass::FastDivmod(params->num_heads), - split_cnt > 1 ? params->split_cnt : nullptr, - params->max_split_k, - params->cp_size, - params->cp_rank, - params->inv_sqrt_dh); - - sync_check_cuda_error(); -} - -template -void CpReduce(CpPostContext* ctx, AttentionParams* params, int split_cnt) -{ - NvtxScope scope("CpReduce"); + auto ctx = reinterpret_cast(context); - const int count = params->token_num * params->num_heads * params->max_split_k * 2; - ctx->d_comm->AllGather(ctx->cp_ML + params->cp_rank * count, // - ctx->cp_ML, - count, + ctx->d_comm->AllGather(ctx->partial_ML + ctx->cp_rank * ctx->count, // + ctx->partial_ML, + ctx->count, DataType::kFloat, ctx->attn_cp_group, - params->stream); + ctx->stream); sync_check_cuda_error(); - - invokeReduceScale(ctx, params, split_cnt); - - invokeReduceOutput(ctx, params, split_cnt); -} - -void CpPost(void* context, int split_cnt) -{ - auto ctx = reinterpret_cast(context); - - auto invoke = [&](auto t) { - using T = decltype(t); - CpReduce(ctx, static_cast*>(ctx->attn_param), split_cnt); - }; - - TM_DISPATCH_PRIMARY_DTYPES(ctx->attn_type, invoke); } } // namespace turbomind diff --git a/src/turbomind/models/llama/cp_utils.h b/src/turbomind/models/llama/cp_utils.h index f1389089d7..ae94112ada 100644 --- a/src/turbomind/models/llama/cp_utils.h +++ b/src/turbomind/models/llama/cp_utils.h @@ -1,8 +1,7 @@ // Copyright (c) OpenMMLab. All rights reserved. 
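// Division of labor after this patch, as far as the diff shows: CpPost above is
// reduced to an AllGather of each rank's (M, L) slab -- ctx->count floats per
// rank, written into the (cp, q, h, split, {M, L}) buffer -- while the cross-rank
// rescale and the weighted sum of partial outputs that ReduceScale/ReduceOutput
// used to perform are handled by the attention reduce path (invokeReduceV2 now
// takes cp_size and cp_rank).  The header below shrinks CpPostContext to the
// fields that the AllGather actually needs.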
#include "src/turbomind/comm/device_comm.h" -#include "src/turbomind/core/core.h" -#include "src/turbomind/kernels/attention/attention_params.h" +#include "src/turbomind/utils/cuda_utils.h" namespace turbomind { @@ -13,11 +12,12 @@ struct CpPostContext { comm::DeviceCommImpl* d_comm; int attn_cp_group; - float* cp_ML; - void* attn_param; - DataType attn_type; + int cp_rank; + int count; + float* partial_ML; + cudaStream_t stream; }; -void CpPost(void* context, int split_cnt); +void CpPost(void* context); } // namespace turbomind diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index d68c2cce22..c987f242e5 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -92,14 +92,17 @@ UnifiedAttentionLayer::UnifiedAttentionLayer(const ModelParam& model, init_rope_kernel_param(param_.rope, rope_param_); - partial_M_ = Tensor_({kMaxWorkspaceTokens, local_head_num_}, kDEVICE); - partial_L_ = Tensor_({kMaxWorkspaceTokens, local_head_num_}, kDEVICE); - partial_O_ = Tensor_({kMaxWorkspaceTokens, local_head_num_, size_per_head_}, kDEVICE); + // partial_O layout: + // w/ cp, decode(q, h, k, 2) + prefill(q, h, 1, 2) + // w/o cp, decode(q, h, k, 2) + const ssize_t attn_ws_tokens = engine_param_.attn_cp_size > 1 ? + kMaxWorkspaceTokens + engine_param_.max_forward_token_num : + kMaxWorkspaceTokens; + + partial_O_ = Tensor_({attn_ws_tokens, local_head_num_, size_per_head_}, kDEVICE); split_cnt_ = Tensor_({kMaxWorkspaceTokens}, kDEVICE); - barriers_ = Tensor_({kMaxWorkspaceTokens, local_head_num_}, kDEVICE); Clear(split_cnt_.buffer()); - Clear(barriers_.buffer()); const auto max_batch_size = engine.max_batch_size; @@ -138,9 +141,7 @@ void UnifiedAttentionLayer::Initialize(TensorMap& args) cu_block_nums_ = args.at("cu_block_nums").buffer(); kv_block_ptrs_ = args.at("kv_block_ptrs").buffer(); - if (engine_param_.attn_cp_size > 1) { - cp_ML_ = args.at("cp_ML").borrow(); - } + partial_ML_ = args.at("partial_ML").borrow(); // rotary embedding, add offest when forward if (rope_param_.type == RopeType::kDynamic) { @@ -327,10 +328,8 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, // Decoding use only for now params.split_cnt = split_cnt_.data(); - params.partial_L = partial_L_.data(); - params.partial_M = partial_M_.data(); + params.partial_ML = partial_ML_.data(); params.partial_O = partial_O_.data(); - params.locks = barriers_.data(); params.max_split_k = std::min(std::max(1, kMaxWorkspaceTokens / params.token_num), max_kv_splits); // context parallel @@ -339,20 +338,23 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, if (params.cp_size > 1) { params.cp_size = cutlass::FastDivmod(params.cp_size); - const int offset_stage = + // update ML,O offset if both prefill and decode present + const int offset_ML_stage = engine_param_.attn_cp_size * (offset ? kMaxWorkspaceTokens * local_head_num_ * 2 : 0); - const int offset_rank = params.cp_rank * params.token_num * local_head_num_ * params.max_split_k * 2; + const int offset_ML_rank = params.cp_rank * params.token_num * local_head_num_ * params.max_split_k * 2; + const int offset_O = offset ? 
kMaxWorkspaceTokens * local_head_num_ * size_per_head_ : 0; - params.cp_ML = cp_ML_.data() + offset_stage + offset_rank; // (cp, q, h, k, 2) - params.cp_q_offset = offset; + params.partial_ML = partial_ML_.data() + offset_ML_stage + offset_ML_rank; + params.partial_O = partial_O_.data() + offset_O; + params.offset_q = offset; // postprocess func - params.cp_fn = CpPost; - params.cp_fn_ctx = (void*)&cp_fn_ctx_; - - cp_fn_ctx_.cp_ML = cp_ML_.data() + offset_stage; - cp_fn_ctx_.attn_param = (void*)¶ms; - cp_fn_ctx_.attn_type = attn.dtype(); + params.cp_fn = CpPost; + params.cp_fn_ctx = (void*)&cp_fn_ctx_; + cp_fn_ctx_.cp_rank = params.cp_rank; + cp_fn_ctx_.count = params.token_num * local_head_num_ * params.max_split_k * 2; + cp_fn_ctx_.partial_ML = partial_ML_.data() + offset_ML_stage; + cp_fn_ctx_.stream = stream; } params.arch = arch_; diff --git a/src/turbomind/models/llama/unified_attention_layer.h b/src/turbomind/models/llama/unified_attention_layer.h index c058e79f28..06b0c02531 100644 --- a/src/turbomind/models/llama/unified_attention_layer.h +++ b/src/turbomind/models/llama/unified_attention_layer.h @@ -112,15 +112,12 @@ class UnifiedAttentionLayer { int decode_num_; int prefil_num_; - Tensor_ partial_M_; - Tensor_ partial_L_; + Tensor_ partial_ML_; Tensor_ partial_O_; Tensor_ split_cnt_; - Tensor_ barriers_; // always zero // context parallel - Tensor_ cp_ML_; // cp, (d+p), h, k, 2 - CpPostContext cp_fn_ctx_; + CpPostContext cp_fn_ctx_; Event event_; diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 18e245966c..0d88626088 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -387,7 +387,7 @@ LlamaTritonModel::LlamaTritonModel(std::string model_ engine_param_.devices = engine_reader["devices"].as>(); { - auto tp = engine_param_.attn_tp_size; + auto tp = engine_param_.attn_tp_size * engine_param_.attn_cp_size; engine_param_.max_forward_token_num = ((size_t)max_forward_token_num + tp - 1) / tp * tp; } From 4211c3c0372a55d91c94b4608e9cacc0c5b4a86d Mon Sep 17 00:00:00 2001 From: irexyc Date: Fri, 7 Nov 2025 17:58:33 +0000 Subject: [PATCH 31/31] fix nccl found --- src/turbomind/comm/nccl/nccl.cu | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/turbomind/comm/nccl/nccl.cu b/src/turbomind/comm/nccl/nccl.cu index 557c63e9a9..af4faf29aa 100644 --- a/src/turbomind/comm/nccl/nccl.cu +++ b/src/turbomind/comm/nccl/nccl.cu @@ -65,7 +65,7 @@ static NcclApis& nccl_apis() static auto value = [] { int version{}; ncclGetVersion(&version); - auto handle = dlopen(nullptr, RTLD_LAZY); + auto handle = dlopen("libnccl.so.2", RTLD_LAZY); NcclApis apis{}; if (!handle) { return apis; @@ -222,11 +222,10 @@ public: int Split(int color, int key, int group) override { - // auto split_fn = TM_CHECK_NOTNULL(nccl_apis().ncclCommSplit); + auto split_fn = TM_CHECK_NOTNULL(nccl_apis().ncclCommSplit); ncclComm_t comm{}; - // NCCLCHECK(split_fn(groups_.at(group), color, key, &comm, nullptr)); - NCCLCHECK(ncclCommSplit(groups_.at(group), color, key, &comm, nullptr)); + NCCLCHECK(split_fn(groups_.at(group), color, key, &comm, nullptr)); int index = groups_.size(); groups_.push_back(comm);
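// Note on the lookup change above (reasoning inferred from the diff): with
// dlopen(nullptr) the subsequent dlsym only sees symbols in the process's
// global scope, so ncclCommSplit could resolve to null when libnccl.so.2 was
// loaded with local visibility (e.g. as a dependency of a Python extension).
// Opening the library by name keeps the lookup independent of how NCCL was
// linked or loaded.  Minimal pattern, with the NCCL 2.18+ signature:
//
//   #include <dlfcn.h>
//   #include <nccl.h>
//   using ncclCommSplit_t = ncclResult_t (*)(ncclComm_t, int, int, ncclComm_t*, ncclConfig_t*);
//   void* handle = dlopen("libnccl.so.2", RTLD_LAZY);
//   auto  split  = handle ? reinterpret_cast<ncclCommSplit_t>(dlsym(handle, "ncclCommSplit")) : nullptr;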