Commit 9a39920

draft for prealloc cpp kvcache transceiver
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent ce7a42f commit 9a39920

File tree

13 files changed, +142 -26 lines changed


cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 4 additions & 0 deletions
@@ -209,6 +209,8 @@ class BaseCacheTransceiver
     [[nodiscard]] virtual bool checkGenTransferComplete() const = 0;

     virtual bool cancelRequest(LlmRequest* llmRequest) = 0;
+
+    virtual void prepareContextRequest(LlmRequest* llmRequest) = 0;
 };

 class CacheTransceiver : public BaseCacheTransceiver

@@ -251,6 +253,8 @@ class CacheTransceiver : public BaseCacheTransceiver

     virtual bool cancelRequest(LlmRequest* llmRequest) override;

+    void prepareContextRequest(LlmRequest* llmRequest) override;
+
 private:
     void initializeCommState();

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,7 @@ enum class LlmRequestState : int32_t
     kUNKNOWN = 0,                             ///< Unknown state
     kENCODER_INIT = 1,                        ///< Encoder phase starts (for encoder-decoder models)

+    kDISAGG_CONTEXT_WAIT_SCHEDULE = 7,        ///< Context-only request waiting for scheduling
     kDISAGG_GENERATION_INIT = 8,              ///< New Generation request arrived at generation model
     kDISAGG_GENERATION_TRANS_IN_PROGRESS = 9, ///< Transmitting the kv cache

@@ -65,6 +66,7 @@ enum class LlmRequestState : int32_t
     kDISAGG_CONTEXT_TRANS_IN_PROGRESS = 21,   ///< Waiting context-only request transmitting the kv cache,
                                               /// after computation finished
     kDISAGG_CONTEXT_COMPLETE = 22,            ///< Context-only request finished kv cache transmission.
+    kDISAGG_GENERATION_WAIT_TOKENS = 23,      ///< Generation-only request waiting for context tokens to be received.

     // error states
     kDISAGG_TRANS_ERROR = -1,                 ///< Error occurred during kv cache transmission
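
The two new states bracket the pre-allocation handshake: a context-only request can now be admitted on the context rank before it is actually scheduled (kDISAGG_CONTEXT_WAIT_SCHEDULE), and a generation-only request can sit on the generation rank waiting for the context's tokens to arrive (kDISAGG_GENERATION_WAIT_TOKENS). A minimal, self-contained sketch of the context-side transition this enables follows; the Phase enum and advanceContextPhase helper are illustrative stand-ins, not code from this commit. The real states live in LlmRequestState and the readiness check is CacheSender::checkContextRequestReady().

// Illustrative stand-ins only; numeric values are omitted on purpose.
#include <iostream>

enum class Phase
{
    ContextWaitSchedule, // admitted on the context rank, not yet schedulable
    ContextInit,         // normal context scheduling can proceed
};

// A waiting context-only request only becomes schedulable once the generation rank has
// already announced itself, which is what prepareContextRequest() checks via the sender.
Phase advanceContextPhase(Phase current, bool counterpartReady)
{
    if (current == Phase::ContextWaitSchedule && counterpartReady)
    {
        return Phase::ContextInit;
    }
    return current;
}

int main()
{
    Phase phase = Phase::ContextWaitSchedule;
    phase = advanceContextPhase(phase, /*counterpartReady=*/false); // still waiting
    phase = advanceContextPhase(phase, /*counterpartReady=*/true);  // now schedulable
    std::cout << (phase == Phase::ContextInit) << '\n';             // prints 1
    return 0;
}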

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 9 additions & 0 deletions
@@ -695,4 +695,13 @@ bool CacheTransceiver::cancelRequest(LlmRequest* llmRequest)
     return false;
 }

+void CacheTransceiver::prepareContextRequest(LlmRequest* llmRequest)
+{
+    if (llmRequest->isContextOnlyRequest() && llmRequest->getState() == LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULE
+        && mCacheSender->checkContextRequestReady(*llmRequest))
+    {
+        llmRequest->setState(LlmRequestState::kCONTEXT_INIT);
+    }
+}
+
 } // namespace tensorrt_llm::batch_manager
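
prepareContextRequest() only promotes a request that is context-only, still in kDISAGG_CONTEXT_WAIT_SCHEDULE, and already asked for by its generation counterpart. The call site is not part of this diff; a plausible pattern is a per-iteration sweep over the active requests before scheduling, sketched below. The function name promoteWaitingContextRequests and the shared_ptr request container are assumptions for illustration, not APIs introduced by this commit.

// Hypothetical call site (not in this commit): give waiting context-only requests a chance
// to become schedulable at the start of each executor iteration.
#include <memory>
#include <vector>

#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"

namespace tb = tensorrt_llm::batch_manager;

void promoteWaitingContextRequests(
    tb::BaseCacheTransceiver& cacheTransceiver, std::vector<std::shared_ptr<tb::LlmRequest>> const& activeRequests)
{
    for (auto const& llmRequest : activeRequests)
    {
        // No-op unless the request is context-only, in kDISAGG_CONTEXT_WAIT_SCHEDULE,
        // and the generation rank's RequestInfo has already arrived at the sender.
        cacheTransceiver.prepareContextRequest(llmRequest.get());
    }
}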

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 50 additions & 20 deletions
@@ -306,6 +306,13 @@ class CacheSender::Impl
             std::scoped_lock lkResp(mSenderMutex);
             mReadyResponses.emplace(
                 llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)});
+            // if the request is already in the pending queue, submit a send request to ready queue
+            auto it = mPendingRequests.find(llmRequest.mRequestId);
+            if (it != mPendingRequests.end())
+            {
+                mReadyPendingRequests.push(std::move(it->second));
+                mPendingRequests.erase(it);
+            }
         }
         std::unique_lock lkCond(mCondMutex);
         mAnyReady = true;

@@ -353,6 +360,17 @@ class CacheSender::Impl

     [[nodiscard]] RequestInfo recvRequestInfo()
     {
+        // if there is a pending request in the ready queue, respond to it first
+        {
+            std::scoped_lock lk(mSenderMutex);
+            if (!mReadyPendingRequests.empty())
+            {
+                auto info = std::move(mReadyPendingRequests.front());
+                mReadyPendingRequests.pop();
+                return info;
+            }
+        }
+
         auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
         bool isAgent = agentConnectionManager != nullptr;

@@ -619,14 +637,14 @@ class CacheSender::Impl
             {
                 break;
             }
+            auto const& requestInfo = recvRequestInfo();
+            auto reqId = requestInfo.getRequestId();
             if (!mReadyResponses.empty())
             {
-                auto const& requestInfo = recvRequestInfo();
                 if (mTerminate || !mManager->isRunning())
                 {
                     return;
                 }
-                auto reqId = requestInfo.getRequestId();

                 {
                     std::scoped_lock lk(mSenderMutex);
@@ -638,26 +656,11 @@ class CacheSender::Impl
                         mRemainSendCount[reqId] = getCounterpartsCount(reqId);
                     }
                 }
-                auto it = getCurrentResponse();
+                auto it = getReadyResponse(requestInfo);
                 if (it != mReadyResponses.end())
                 {
                     sendResponse(it);
                 }
-                else
-                {
-                    auto it = getCurrentResponse();
-                    while (it == mReadyResponses.end())
-                    {
-                        std::unique_lock lk(mCondMutex);
-                        mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
-                        if (mTerminate)
-                        {
-                            break;
-                        }
-                        it = getCurrentResponse();
-                    }
-                    sendResponse(it);
-                }
             }
         }
         catch (std::exception const& err)
@@ -692,6 +695,7 @@ class CacheSender::Impl
         {
             std::scoped_lock lkResp(mSenderMutex);
+            mPendingRequests.erase(it->first);
             mReadyResponses.erase(it);
         }
         if (mReadyResponses.empty())
         {
@@ -705,10 +709,29 @@ class CacheSender::Impl
         return mCurrentRequest.value();
     }

-    [[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse()
+    [[nodiscard]] std::map<RequestIdType, Response>::iterator getReadyResponse(RequestInfo const& requestInfo)
     {
         std::scoped_lock lk(mSenderMutex);
-        return mReadyResponses.find(getCurrentRequestId());
+        auto reqId = requestInfo.getRequestId();
+        auto it = mReadyResponses.find(reqId);
+        if (it != mReadyResponses.end())
+        {
+            return it;
+        }
+        else
+        {
+            // If a request is received but response is not ready, stash it in the pending map to send it later
+            TLLM_LOG_INFO("No ready response found for request %zu", reqId);
+            mPendingRequests[reqId] = requestInfo;
+        }
+        return mReadyResponses.end();
+    }
+
+    bool checkContextRequestReady(LlmRequest const& llmRequest)
+    {
+        std::scoped_lock lk(mSenderMutex);
+        auto it = mPendingRequests.find(llmRequest.mRequestId);
+        return it != mPendingRequests.end();
     }

 private:

@@ -723,6 +746,8 @@ class CacheSender::Impl
     AsyncSendResource mAsyncSendResource;
     std::vector<std::future<void>> mAsyncSendFutures;
     int mDeviceId{-1};
+    std::unordered_map<LlmRequest::RequestIdType, RequestInfo> mPendingRequests;
+    std::queue<RequestInfo> mReadyPendingRequests;

     executor::kv_cache::ConnectionManager* mManager;
     std::map<LlmRequest::RequestIdType, TransferSession> mRequestToSession;
@@ -1189,6 +1214,11 @@ void CacheSender::sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady)
     mImpl->sendReadySignal(requestId, isReady);
 }

+bool CacheSender::checkContextRequestReady(LlmRequest const& llmRequest)
+{
+    return mImpl->checkContextRequestReady(llmRequest);
+}
+
 CacheReceiver::CacheReceiver(executor::kv_cache::ConnectionManager* manager,
     executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
     : mImpl{std::unique_ptr<Impl, ImplDeleter>(new Impl(manager, selfCacheState, selfIndex, std::move(formatter)))}
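
The sender-side change reduces to one piece of bookkeeping: a RequestInfo that arrives before its response exists is stashed in mPendingRequests, and as soon as the response is registered it is moved onto mReadyPendingRequests so that recvRequestInfo() serves it before touching the network again; checkContextRequestReady() just reports whether a stashed entry exists. Below is a self-contained model of that stash-and-drain pattern; the class and method names are invented for the sketch, and the single-threaded main() stands in for what the sender thread and the request threads do concurrently in the real code.

// Simplified model (not code from this commit) of the mPendingRequests / mReadyPendingRequests
// bookkeeping added above. All names here are invented for the sketch.
#include <cstdint>
#include <iostream>
#include <mutex>
#include <optional>
#include <queue>
#include <unordered_map>

using RequestId = std::uint64_t;

struct RequestInfo
{
    RequestId id;
};

class PendingRequestBook
{
public:
    // A RequestInfo arrived but no response for it is ready yet: remember it.
    void stash(RequestInfo info)
    {
        std::scoped_lock lk(mMutex);
        mPending[info.id] = info;
    }

    // The local response for `id` just became ready: if the counterpart already asked for it,
    // move the stashed info onto the ready queue so it is served next.
    void markResponseReady(RequestId id)
    {
        std::scoped_lock lk(mMutex);
        auto it = mPending.find(id);
        if (it != mPending.end())
        {
            mReady.push(it->second);
            mPending.erase(it);
        }
    }

    // Called at the top of recvRequestInfo(): serve a previously stashed request first.
    std::optional<RequestInfo> popReady()
    {
        std::scoped_lock lk(mMutex);
        if (mReady.empty())
        {
            return std::nullopt;
        }
        auto info = mReady.front();
        mReady.pop();
        return info;
    }

    // What checkContextRequestReady() reports: has the counterpart already asked for this id?
    bool isPending(RequestId id) const
    {
        std::scoped_lock lk(mMutex);
        return mPending.count(id) != 0;
    }

private:
    mutable std::mutex mMutex;
    std::unordered_map<RequestId, RequestInfo> mPending;
    std::queue<RequestInfo> mReady;
};

int main()
{
    PendingRequestBook book;
    book.stash(RequestInfo{42});                      // generation rank asked before the context ran
    std::cout << book.isPending(42) << '\n';          // 1: the waiting context request may be scheduled
    book.markResponseReady(42);                       // context finished, response registered
    std::cout << book.popReady().has_value() << '\n'; // 1: recvRequestInfo() serves it immediately
    return 0;
}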

cpp/tensorrt_llm/batch_manager/dataTransceiver.h

Lines changed: 3 additions & 0 deletions
@@ -272,6 +272,9 @@ class CacheSender
     /// @param isReady Whether the request is ready to be received.
     virtual void sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady);

+    /// @brief Check if the context request is ready to be received.
+    virtual bool checkContextRequestReady(LlmRequest const& llmRequest);
+
     /// @brief Destructor.
     virtual ~CacheSender();

cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp

Lines changed: 7 additions & 1 deletion
@@ -79,6 +79,11 @@ class PyCacheTransceiver : public tb::BaseCacheTransceiver
     {
         NB_OVERRIDE_PURE(cancelRequest, llmRequest);
     }
+
+    void prepareContextRequest(tb::LlmRequest* llmRequest) override
+    {
+        NB_OVERRIDE_PURE(prepareContextRequest, llmRequest);
+    }
 };
 } // namespace

@@ -93,7 +98,8 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
         .def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
             nb::call_guard<nb::gil_scoped_release>())
         .def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
-        .def("cancel_request", &BaseCacheTransceiver::cancelRequest);
+        .def("cancel_request", &BaseCacheTransceiver::cancelRequest)
+        .def("prepare_context_request", &BaseCacheTransceiver::prepareContextRequest);

     nb::enum_<executor::kv_cache::CacheState::AttentionType>(m, "AttentionType")
         .value("DEFAULT", executor::kv_cache::CacheState::AttentionType::kDEFAULT)

cpp/tensorrt_llm/nanobind/bindings.cpp

Lines changed: 2 additions & 0 deletions
@@ -481,6 +481,8 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
         .value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS)
         .value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE)
         .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS)
+        .value("DISAGG_CONTEXT_WAIT_SCHEDULE", tb::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULE)
+        .value("DISAGG_GENERATION_WAIT_TOKENS", tb::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS)
         .value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR);

     nb::class_<tr::MemoryCounters>(m, "MemoryCounters")

cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp

Lines changed: 7 additions & 1 deletion
@@ -75,6 +75,11 @@ class PyCacheTransceiver : public tb::BaseCacheTransceiver
     {
         PYBIND11_OVERLOAD_PURE(bool, tb::BaseCacheTransceiver, cancelRequest, llmRequest);
     }
+
+    void prepareContextRequest(tb::LlmRequest* llmRequest) override
+    {
+        PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, prepareContextRequest, llmRequest);
+    }
 };
 } // namespace

@@ -89,7 +94,8 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m)
         .def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
             py::call_guard<py::gil_scoped_release>())
         .def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
-        .def("cancel_request", &BaseCacheTransceiver::cancelRequest);
+        .def("cancel_request", &BaseCacheTransceiver::cancelRequest)
+        .def("prepare_context_request", &BaseCacheTransceiver::prepareContextRequest);

     py::enum_<executor::kv_cache::CacheState::AttentionType>(m, "AttentionType")
         .value("DEFAULT", executor::kv_cache::CacheState::AttentionType::kDEFAULT)

cpp/tensorrt_llm/pybind/bindings.cpp

Lines changed: 3 additions & 1 deletion

@@ -470,7 +470,9 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
         .value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS)
         .value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR)
         .value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE)
-        .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);
+        .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS)
+        .value("DISAGG_CONTEXT_WAIT_SCHEDULE", tb::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULE)
+        .value("DISAGG_GENERATION_WAIT_TOKENS", tb::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS);

     py::class_<tr::MemoryCounters>(m, "MemoryCounters")
         .def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference)

docker/Makefile

Lines changed: 1 addition & 2 deletions
@@ -149,7 +149,7 @@ CCACHE_DIR ?= $(CODE_DIR)/cpp/.ccache
 CONAN_DIR ?= $(CODE_DIR)/cpp/.conan
 USER_CACHE_DIR ?= $(shell readlink -f "${HOME_DIR}/.cache")
 RUN_CMD ?=
-CONTAINER_NAME ?= tensorrt_llm
+CONTAINER_NAME ?= tensorrt_llm_$(shell git branch --show-current)
 WORK_DIR ?= $(CODE_DIR)
 DOCKER_PULL ?= 0

@@ -167,7 +167,6 @@ endif
 	$(GPU_OPTS) \
 	--volume $(SOURCE_DIR):$(CODE_DIR) \
 	$(EXTRA_VOLUMES) \
-	$(if $(and $(filter 1,$(LOCAL_USER)),$(shell [ -w "$(USER_CACHE_DIR)" ] && echo 1)),--volume $(USER_CACHE_DIR):/home/$(USER_NAME)/.cache:rw) \
 	--env "CCACHE_DIR=$(CCACHE_DIR)" \
 	--env "CCACHE_BASEDIR=$(CODE_DIR)" \
 	--env "CONAN_HOME=$(CONAN_DIR)" \