4 changes: 4 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
@@ -209,6 +209,8 @@ class BaseCacheTransceiver
[[nodiscard]] virtual bool checkGenTransferComplete() const = 0;

virtual bool cancelRequest(LlmRequest* llmRequest) = 0;

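/// @brief If a context-only request is waiting in kDISAGG_CONTEXT_WAIT_SCHEDULE and the
/// generation side has already sent its request info, move the request to kCONTEXT_INIT so it
/// can be scheduled.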
virtual void prepareContextRequest(LlmRequest* llmRequest) = 0;
};

class CacheTransceiver : public BaseCacheTransceiver
@@ -251,6 +253,8 @@ class CacheTransceiver : public BaseCacheTransceiver

virtual bool cancelRequest(LlmRequest* llmRequest) override;

void prepareContextRequest(LlmRequest* llmRequest) override;

private:
void initializeCommState();

2 changes: 2 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -49,6 +49,7 @@ enum class LlmRequestState : int32_t
kUNKNOWN = 0, ///< Unknown state
kENCODER_INIT = 1, ///< Encoder phase starts (for encoder-decoder models)

kDISAGG_CONTEXT_WAIT_SCHEDULE = 7, ///< Context-only request waiting for scheduling
kDISAGG_GENERATION_INIT = 8, ///< New Generation request arrived at generation model
kDISAGG_GENERATION_TRANS_IN_PROGRESS = 9, ///< Transmitting the kv cache

@@ -65,6 +66,7 @@ enum class LlmRequestState : int32_t
kDISAGG_CONTEXT_TRANS_IN_PROGRESS = 21, ///< Waiting context-only request transmitting the kv cache,
/// after computation finished
kDISAGG_CONTEXT_COMPLETE = 22, ///< Context-only request finished kv cache transmission.
kDISAGG_GENERATION_WAIT_TOKENS = 23, ///< Generation-only request waiting for context tokens to be received.

// error states
kDISAGG_TRANS_ERROR = -1, ///< Error occurred during kv cache transmission
9 changes: 9 additions & 0 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
@@ -695,4 +695,13 @@ bool CacheTransceiver::cancelRequest(LlmRequest* llmRequest)
return false;
}

void CacheTransceiver::prepareContextRequest(LlmRequest* llmRequest)
{
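// Context-only requests park in kDISAGG_CONTEXT_WAIT_SCHEDULE until the sender has
// received the matching request info from the generation side.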
if (llmRequest->isContextOnlyRequest() && llmRequest->getState() == LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULE
&& mCacheSender->checkContextRequestReady(*llmRequest))
{
llmRequest->setState(LlmRequestState::kCONTEXT_INIT);
}
}

} // namespace tensorrt_llm::batch_manager
70 changes: 50 additions & 20 deletions cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
@@ -306,6 +306,13 @@ class CacheSender::Impl
std::scoped_lock lkResp(mSenderMutex);
mReadyResponses.emplace(
llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)});
// if the matching request info has already arrived (pending), move it to the ready queue so it is served next
auto it = mPendingRequests.find(llmRequest.mRequestId);
if (it != mPendingRequests.end())
{
mReadyPendingRequests.push(std::move(it->second));
mPendingRequests.erase(it);
}
}
std::unique_lock lkCond(mCondMutex);
mAnyReady = true;
@@ -353,6 +360,17 @@ class CacheSender::Impl

[[nodiscard]] RequestInfo recvRequestInfo()
{
// if there is a pending request in the ready queue, respond to it first
{
std::scoped_lock lk(mSenderMutex);
if (!mReadyPendingRequests.empty())
{
auto info = std::move(mReadyPendingRequests.front());
mReadyPendingRequests.pop();
return info;
}
}

auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
bool isAgent = agentConnectionManager != nullptr;

@@ -619,14 +637,14 @@ class CacheSender::Impl
{
break;
}
auto const& requestInfo = recvRequestInfo();
auto reqId = requestInfo.getRequestId();
if (!mReadyResponses.empty())
{
auto const& requestInfo = recvRequestInfo();
if (mTerminate || !mManager->isRunning())
{
return;
}
auto reqId = requestInfo.getRequestId();

{
std::scoped_lock lk(mSenderMutex);
@@ -638,26 +656,11 @@ class CacheSender::Impl
mRemainSendCount[reqId] = getCounterpartsCount(reqId);
}
}
auto it = getCurrentResponse();
auto it = getReadyResponse(requestInfo);
if (it != mReadyResponses.end())
{
sendResponse(it);
}
else
{
auto it = getCurrentResponse();
while (it == mReadyResponses.end())
{
std::unique_lock lk(mCondMutex);
mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
if (mTerminate)
{
break;
}
it = getCurrentResponse();
}
sendResponse(it);
}
}
}
catch (std::exception const& err)
@@ -692,6 +695,7 @@ class CacheSender::Impl
{
std::scoped_lock lkResp(mSenderMutex);
// erase the pending entry first: erasing from mReadyResponses invalidates `it`
mPendingRequests.erase(it->first);
mReadyResponses.erase(it);
}
if (mReadyResponses.empty())
{
@@ -705,10 +709,29 @@ class CacheSender::Impl
return mCurrentRequest.value();
}

[[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse()
[[nodiscard]] std::map<RequestIdType, Response>::iterator getReadyResponse(RequestInfo const& requestInfo)
{
std::scoped_lock lk(mSenderMutex);
return mReadyResponses.find(getCurrentRequestId());
auto reqId = requestInfo.getRequestId();
auto it = mReadyResponses.find(reqId);
if (it != mReadyResponses.end())
{
return it;
}
else
{
// If a request is received but response is not ready, stash it in the pending map to send it later
TLLM_LOG_INFO("No ready response found for request %zu", reqId);
mPendingRequests[reqId] = requestInfo;
}
return mReadyResponses.end();
}

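// Returns true if the generation side's request info for this request has already arrived,
// i.e. it is sitting in mPendingRequests waiting for the corresponding response to become ready.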
bool checkContextRequestReady(LlmRequest const& llmRequest)
{
std::scoped_lock lk(mSenderMutex);
auto it = mPendingRequests.find(llmRequest.mRequestId);
return it != mPendingRequests.end();
}

private:
@@ -723,6 +746,8 @@ class CacheSender::Impl
AsyncSendResource mAsyncSendResource;
std::vector<std::future<void>> mAsyncSendFutures;
int mDeviceId{-1};
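// Request infos received before their responses were ready, keyed by request id.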
std::unordered_map<LlmRequest::RequestIdType, RequestInfo> mPendingRequests;
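// Stashed request infos whose responses have since become ready; served before receiving new ones.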
std::queue<RequestInfo> mReadyPendingRequests;

executor::kv_cache::ConnectionManager* mManager;
std::map<LlmRequest::RequestIdType, TransferSession> mRequestToSession;
@@ -1189,6 +1214,11 @@ void CacheSender::sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady)
mImpl->sendReadySignal(requestId, isReady);
}

bool CacheSender::checkContextRequestReady(LlmRequest const& llmRequest)
{
return mImpl->checkContextRequestReady(llmRequest);
}

CacheReceiver::CacheReceiver(executor::kv_cache::ConnectionManager* manager,
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
: mImpl{std::unique_ptr<Impl, ImplDeleter>(new Impl(manager, selfCacheState, selfIndex, std::move(formatter)))}
3 changes: 3 additions & 0 deletions cpp/tensorrt_llm/batch_manager/dataTransceiver.h
@@ -272,6 +272,9 @@ class CacheSender
/// @param isReady Whether the request is ready to be received.
virtual void sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady);

/// @brief Check whether the request info for this context-only request has already been received from the generation side.
virtual bool checkContextRequestReady(LlmRequest const& llmRequest);

/// @brief Destructor.
virtual ~CacheSender();

8 changes: 7 additions & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp
@@ -79,6 +79,11 @@ class PyCacheTransceiver : public tb::BaseCacheTransceiver
{
NB_OVERRIDE_PURE(cancelRequest, llmRequest);
}

void prepareContextRequest(tb::LlmRequest* llmRequest) override
{
NB_OVERRIDE_PURE(prepareContextRequest, llmRequest);
}
};
} // namespace

@@ -93,7 +98,8 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
.def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
nb::call_guard<nb::gil_scoped_release>())
.def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
.def("cancel_request", &BaseCacheTransceiver::cancelRequest);
.def("cancel_request", &BaseCacheTransceiver::cancelRequest)
.def("prepare_context_request", &BaseCacheTransceiver::prepareContextRequest);

nb::enum_<executor::kv_cache::CacheState::AttentionType>(m, "AttentionType")
.value("DEFAULT", executor::kv_cache::CacheState::AttentionType::kDEFAULT)
2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/nanobind/bindings.cpp
@@ -481,6 +481,8 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
.value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS)
.value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE)
.value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS)
.value("DISAGG_CONTEXT_WAIT_SCHEDULE", tb::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULE)
.value("DISAGG_GENERATION_WAIT_TOKENS", tb::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS)
.value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR);

nb::class_<tr::MemoryCounters>(m, "MemoryCounters")
8 changes: 7 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp
@@ -75,6 +75,11 @@ class PyCacheTransceiver : public tb::BaseCacheTransceiver
{
PYBIND11_OVERLOAD_PURE(bool, tb::BaseCacheTransceiver, cancelRequest, llmRequest);
}

void prepareContextRequest(tb::LlmRequest* llmRequest) override
{
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, prepareContextRequest, llmRequest);
}
};
} // namespace

@@ -89,7 +94,8 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m)
.def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
py::call_guard<py::gil_scoped_release>())
.def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
.def("cancel_request", &BaseCacheTransceiver::cancelRequest);
.def("cancel_request", &BaseCacheTransceiver::cancelRequest)
.def("prepare_context_request", &BaseCacheTransceiver::prepareContextRequest);

py::enum_<executor::kv_cache::CacheState::AttentionType>(m, "AttentionType")
.value("DEFAULT", executor::kv_cache::CacheState::AttentionType::kDEFAULT)
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/pybind/bindings.cpp
@@ -470,7 +470,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS)
.value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR)
.value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE)
.value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);
.value("DISAGG_CONTEXT_WAIT_SCHEDULE", tb::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULE)
.value("DISAGG_GENERATION_WAIT_TOKENS", tb::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS);

py::class_<tr::MemoryCounters>(m, "MemoryCounters")
.def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference)
3 changes: 1 addition & 2 deletions docker/Makefile
@@ -149,7 +149,7 @@ CCACHE_DIR ?= $(CODE_DIR)/cpp/.ccache
CONAN_DIR ?= $(CODE_DIR)/cpp/.conan
USER_CACHE_DIR ?= $(shell readlink -f "${HOME_DIR}/.cache")
RUN_CMD ?=
CONTAINER_NAME ?= tensorrt_llm
CONTAINER_NAME ?= tensorrt_llm_$(shell git branch --show-current)
WORK_DIR ?= $(CODE_DIR)
DOCKER_PULL ?= 0

@@ -167,7 +167,6 @@ endif
$(GPU_OPTS) \
--volume $(SOURCE_DIR):$(CODE_DIR) \
$(EXTRA_VOLUMES) \
$(if $(and $(filter 1,$(LOCAL_USER)),$(shell [ -w "$(USER_CACHE_DIR)" ] && echo 1)),--volume $(USER_CACHE_DIR):/home/$(USER_NAME)/.cache:rw) \
--env "CCACHE_DIR=$(CCACHE_DIR)" \
--env "CCACHE_BASEDIR=$(CODE_DIR)" \
--env "CONAN_HOME=$(CONAN_DIR)" \
8 changes: 8 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1146,6 +1146,7 @@ def _prepare_and_schedule_batch(self):
if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()
self._check_kv_transfer_timeout()
self._prepare_disagg_ctx_init(new_requests)

iter_stats = None
if self.enable_iter_perf_stats:
@@ -1932,6 +1933,13 @@ def _check_disagg_gen_transfer_status(self):

return

@nvtx_range("_prepare_disagg_ctx_init")
def _prepare_disagg_ctx_init(self, new_requests):
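"""Ask the cache transceiver to prepare newly arrived context-only requests, moving them out of the wait-for-schedule state once the generation side has attached."""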
for req in new_requests:
if req.is_context_only_request:
self.kv_cache_transceiver.prepare_context_request(req)
return

@nvtx_range("_check_kv_transfer_timeout")
def _check_kv_transfer_timeout(self):
if not self.kv_cache_transceiver:
1 change: 1 addition & 0 deletions tensorrt_llm/llmapi/disagg_utils.py
@@ -76,6 +76,7 @@ class DisaggServerConfig():
max_retries: int = 1
perf_metrics_max_requests: int = 0
disagg_cluster_config: Optional[DisaggClusterConfig] = None
enable_prealloc: bool = False


@dataclass
47 changes: 46 additions & 1 deletion tensorrt_llm/serve/openai_disagg_service.py
@@ -15,6 +15,7 @@
import asyncio
import copy
import os
import uuid
from typing import Any, Callable, Dict, Optional

from tensorrt_llm.llmapi.disagg_utils import (
@@ -76,6 +77,7 @@ def __init__(
self._ctx_client = None
self._gen_client = None
self._disagg_cluster_manager = None
self._prealloc_mode = config.enable_prealloc

async def openai_completion(
self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
@@ -134,6 +136,36 @@ async def _send_disagg_request(
return done_generator()
return ctx_response

async def _send_ctx_request_prealloc(
self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
) -> UCompletionResponse:
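"""Prealloc path: dispatch the context-only request and a generation_only request concurrently and return the generation response."""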
if hooks:
hooks.on_req_begin(request)
# empty server means client decides which server to use
ctx_server, gen_server = None, None
# reserve a gen_server if conditional disagg is needed
gen_server, need_ctx = await self._check_conditional_disagg(request)
need_ctx = need_ctx and not await self._check_gen_only_disagg(request)
tasks = []
ctx_req, gen_req = None, None
if need_ctx:
ctx_req = self._get_ctx_request(request)
tasks.append(
asyncio.create_task(
self._ctx_client.send_request(ctx_req, server=ctx_server, hooks=hooks)
)
)
gen_req = self._get_gen_request_prealloc(request)
tasks.append(
asyncio.create_task(
self._gen_client.send_request(gen_req, server=gen_server, hooks=hooks)
)
)
responses = await asyncio.gather(*tasks)
# TODO: handle non-streaming requests
return responses[-1]

def _need_gen(self, response: UCompletionResponse) -> bool:
if response and response.choices[0].finish_reason not in ["length", "not_finished"]:
del response.choices[0].disaggregated_params
@@ -142,11 +174,21 @@ def _need_gen(self, response: UCompletionResponse) -> bool:

def _get_ctx_request(self, request: UCompletionRequest) -> UCompletionRequest:
ctx_request = copy.deepcopy(request)
ctx_request.disaggregated_params = DisaggregatedParams(request_type="context_only")
ctx_request.disaggregated_params = DisaggregatedParams(
request_type="context_only", disagg_id=str(uuid.uuid4())
)
ctx_request.stream = False
ctx_request.stream_options = None
return ctx_request

def _get_gen_request_prealloc(
self,
request: UCompletionRequest,
) -> UCompletionRequest:
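"""Build the generation_only counterpart of the request for the prealloc path."""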
gen_request = copy.deepcopy(request)
gen_request.disaggregated_params = DisaggregatedParams(request_type="generation_only")
return gen_request

def _get_gen_request(
self,
request: UCompletionRequest,
@@ -176,6 +218,9 @@ async def _check_conditional_disagg(self, request: UCompletionRequest) -> bool:
):
return gen_server, True
return gen_server, False
if self._prealloc_mode:
gen_server, _ = await self._gen_router.get_next_server(request)
return gen_server, True
return None, True

async def _check_gen_only_disagg(self, request: UCompletionRequest) -> bool:
Expand Down
Loading