
Commit fc0fbce

fix: add a session status synchronization to help model_agent manage guided processors
1 parent 0b2c106 commit fc0fbce

4 files changed: +83 −30 lines changed

lmdeploy/pytorch/engine/engine.py

Lines changed: 15 additions & 1 deletion
@@ -382,6 +382,8 @@ def __init__(self,
                                        dtype=engine_config.dtype)
         self.executor.init()
 
+        self.session_to_cleanup = []
+
         # strategies
         self.strategy_factory = build_strategy_factory(self.model_config, self.executor.misc_config)
         self.sampling_strategy = self.strategy_factory.build_sampling_strategy()
@@ -551,7 +553,7 @@ def _on_end_session(self, reqs: List[Request], **kwargs):
             if len(msgs) > 0 and msgs[0].preserve_cache:
                 self.scheduler._set_message_status(msgs[0], MessageStatus.TO_BE_MIGRATED)
             else:
-                self.scheduler.end_session(session_id)
+                self.end_session(session_id)
             resp_type = ResponseType.SUCCESS
         if resp:
             self._response(req.resp, resp_type)
@@ -912,6 +914,15 @@ def __need_logits(seqs: SeqList):
         stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running)
 
         sync_long_context = inputs.input_ids.numel() > self.cache_config.max_prefill_token_num
+
+        session_ctx = [{
+            'session_id': seq.session.session_id,
+            'seq_id': seq.seq_id,
+        } for seq in running]
+
+        session_to_cleanup = self.session_to_cleanup
+        self.session_to_cleanup = []
+
         return dict(
             running=running,
             inputs=inputs,
@@ -924,6 +935,8 @@ def __need_logits(seqs: SeqList):
             is_dummy=False,
             sync_long_context=sync_long_context,
             extra_inputs=extra_inputs,
+            session_ctx=session_ctx,
+            session_to_cleanup=session_to_cleanup,
         )
 
     async def _await_forward_event(self, forward_event: asyncio.Event):
@@ -1237,6 +1250,7 @@ def start_loop(self):
     def end_session(self, session_id: int):
         """End session."""
         if session_id in self.scheduler.sessions:
+            self.session_to_cleanup.append(session_id)
             self.scheduler.end_session(session_id)
             return True
         return False
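
The engine-side half of the fix is a deferred-cleanup handshake: end_session only records the session id, and the next batch of forward inputs carries both the per-sequence identities (session_ctx) and the drained cleanup list (session_to_cleanup). Below is a minimal, self-contained sketch of that handshake; the ToyEngine class and every name other than session_to_cleanup/session_ctx are illustrative stand-ins, not lmdeploy APIs.

# Deferred cleanup: ending a session only parks its id; the next forward batch
# drains the list so the model agent can drop per-session state on its side.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyEngine:
    sessions: Dict[int, List[int]] = field(default_factory=dict)  # session_id -> seq_ids
    session_to_cleanup: List[int] = field(default_factory=list)

    def end_session(self, session_id: int) -> bool:
        if session_id in self.sessions:
            self.session_to_cleanup.append(session_id)  # remember for the next step
            del self.sessions[session_id]
            return True
        return False

    def make_forward_inputs(self) -> dict:
        # One (session_id, seq_id) entry per running sequence, same order as the batch.
        session_ctx = [{'session_id': sid, 'seq_id': seq}
                       for sid, seqs in self.sessions.items() for seq in seqs]
        # Drain the pending list exactly once; the receiver performs the cleanup.
        session_to_cleanup, self.session_to_cleanup = self.session_to_cleanup, []
        return dict(session_ctx=session_ctx, session_to_cleanup=session_to_cleanup)


engine = ToyEngine(sessions={1: [10], 2: [20]})
engine.end_session(1)
print(engine.make_forward_inputs())
# {'session_ctx': [{'session_id': 2, 'seq_id': 20}], 'session_to_cleanup': [1]}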

lmdeploy/pytorch/engine/guided_process.py

Lines changed: 26 additions & 14 deletions
@@ -2,14 +2,13 @@
 import copy
 import json
 import logging
-from functools import lru_cache
 from typing import Optional
 
 import torch
 import xgrammar as xgr
 from transformers import PreTrainedTokenizerBase
 
-logger = logging.getLogger('guided_process')
+logger = logging.getLogger('lmdeploy')
 
 
 class BaseLogitsProcessor:
@@ -70,18 +69,31 @@ def __init__(self, schema: str, tokenizer: PreTrainedTokenizerBase, vocab_size_p
         super().__init__(compiled, tokenizer_info)
 
 
-@lru_cache(maxsize=32)
-def _get_guided_logits_processor(guide: str,
+_guided_processors = {}
+
+
+def _get_guided_logits_processor(session_id: int,
+                                 seq_id: int,
+                                 guide: str,
                                  tokenizer: PreTrainedTokenizerBase,
                                  type: str,
                                  vocab_size_padded: Optional[int] = None):
-    try:
-        if type == 'json_schema':
-            return JSONLogitsProcessor(guide, tokenizer, vocab_size_padded)
-        elif type == 'regex_schema':
-            return RegexLogitsProcessor(guide, tokenizer, vocab_size_padded)
-        else:
-            return None
-    except Exception as e:
-        logger.error(e)
-        raise
+    if session_id in _guided_processors:
+        session_dict = _guided_processors[session_id]
+        if seq_id in session_dict:
+            processor = session_dict[seq_id]
+            return processor
+
+    if type == 'json_schema':
+        processor = JSONLogitsProcessor(guide, tokenizer, vocab_size_padded)
+    elif type == 'regex_schema':
+        processor = RegexLogitsProcessor(guide, tokenizer, vocab_size_padded)
+    else:
+        assert False, f'Do not support schema type {type}'
+
+    _guided_processors.setdefault(session_id, {})[seq_id] = processor
+    return processor
+
+
+def _remove_guided_logtis_processor(session_id: int):
+    del _guided_processors[session_id]
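
Functionally, this hunk swaps the old lru_cache(maxsize=32) for an explicit two-level dict keyed by session id and then sequence id, so the stateful xgrammar-backed processor built for a sequence is reused on every decode step and can be dropped explicitly when its session ends. A self-contained illustration of that cache behavior follows; StubProcessor and the helper names are placeholders, not the real lmdeploy classes.

# (session_id -> seq_id -> processor) cache, mirroring the structure added above.
_guided_processors = {}


class StubProcessor:
    """Stands in for the JSON/regex logits processors, which hold matcher state."""

    def __init__(self, guide: str):
        self.guide = guide


def get_processor(session_id: int, seq_id: int, guide: str) -> StubProcessor:
    cached = _guided_processors.get(session_id, {}).get(seq_id)
    if cached is not None:
        return cached  # reuse the same stateful processor across decode steps
    processor = StubProcessor(guide)
    _guided_processors.setdefault(session_id, {})[seq_id] = processor
    return processor


def remove_processors(session_id: int) -> None:
    _guided_processors.pop(session_id, None)  # drop every sequence of the session


p1 = get_processor(7, 1, '{"type": "object"}')
p2 = get_processor(7, 1, '{"type": "object"}')
assert p1 is p2              # same sequence -> same processor instance
remove_processors(7)
assert 7 not in _guided_processors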

lmdeploy/pytorch/engine/logits_process.py

Lines changed: 23 additions & 9 deletions
@@ -2,7 +2,7 @@
 import asyncio
 import json
 from dataclasses import dataclass, fields
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 
@@ -78,7 +78,8 @@ def _multinomial_sampling(scores: torch.Tensor,
     return multinomial_sampling(scores, seeds, offsets, indices)
 
 
-def _get_guided_processors(response_formats: Tuple[Dict], tokenizer: object, vocab_size_padded: int):
+def _get_guided_processors(response_formats: Tuple[Dict], tokenizer: object, vocab_size_padded: int,
+                           session_ctx: List[Dict[str, Any]]):
     processors = {}
     for i, _format in enumerate(response_formats):
         if isinstance(_format, Dict) and _format.get('type', 'text') != 'text':
@@ -98,8 +99,12 @@ def _get_guided_processors(response_formats: Tuple[Dict], tokenizer: object, voc
             else:
                 raise ValueError(f"unsupported format type: {_format['type']}")
 
+            session_id = session_ctx[i]['session_id']
+            seq_id = session_ctx[i]['seq_id']
+
             from .guided_process import _get_guided_logits_processor
-            processors[i] = _get_guided_logits_processor(schema, tokenizer, _format['type'], vocab_size_padded)
+            processors[i] = _get_guided_logits_processor(session_id, seq_id, schema, tokenizer, _format['type'],
+                                                         vocab_size_padded)
 
     return processors
 
@@ -154,17 +159,20 @@ def _apply_custom_logits_processors(batched_logits_processors, all_ids, logits):
 class FusedLogitsProcessor:
     """Custom logits processor."""
 
-    def __init__(self,
-                 sampling_inputs: SamplingInputs,
-                 tokenizer: Optional[Tokenizer] = None,
-                 sampling_vocab_size: Optional[int] = None,
-                 logprobs_mode: Optional[str] = None):
+    def __init__(
+        self,
+        sampling_inputs: SamplingInputs,
+        tokenizer: Optional[Tokenizer] = None,
+        sampling_vocab_size: Optional[int] = None,
+        logprobs_mode: Optional[str] = None,
+        session_ctx: Optional[List[Dict[str, Any]]] = None,
+    ):
         self.sampling_inputs: SamplingInputs = sampling_inputs
         self.tokenizer = tokenizer
         self.sampling_vocab_size = sampling_vocab_size
         self.logprobs_mode = logprobs_mode
         self.guided_processors = _get_guided_processors(sampling_inputs.response_formats, tokenizer,
-                                                         sampling_vocab_size)
+                                                         sampling_vocab_size, session_ctx)
 
     async def _wait_stream_once(self):
         """Wait stream once."""
@@ -299,3 +307,9 @@ def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTens
         indices = torch.cat([indices, topk_indices], dim=-1)
 
         return logprobs, indices.to(torch.int32)
+
+    @staticmethod
+    def cleanup_sessions(session_ids: List[int]):
+        from .guided_process import _remove_guided_logtis_processor
+        for session_id in session_ids:
+            _remove_guided_logtis_processor(session_id)
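
The reason session_ctx can simply be indexed here is positional: it is built from the same running sequence list that produced the sampling inputs, so slot i in response_formats and entry i in session_ctx describe the same sequence. A small sketch of that alignment, with a stub standing in for _get_guided_logits_processor (all helper names here are illustrative):

from typing import Any, Dict, List, Tuple


def make_processor(session_id: int, seq_id: int, schema: str) -> str:
    # Stub: a real implementation would build or fetch a cached guided processor.
    return f'processor({session_id}, {seq_id}, {schema})'


def get_guided_processors(response_formats: Tuple[Dict, ...],
                          session_ctx: List[Dict[str, Any]]) -> Dict[int, str]:
    processors = {}
    for i, fmt in enumerate(response_formats):
        if not isinstance(fmt, dict) or fmt.get('type', 'text') == 'text':
            continue  # plain-text slots need no guided processor
        ctx = session_ctx[i]  # positional alignment: slot i belongs to sequence i
        processors[i] = make_processor(ctx['session_id'], ctx['seq_id'], fmt['schema'])
    return processors


formats = ({'type': 'text'}, {'type': 'json_schema', 'schema': '{}'})
ctx = [{'session_id': 3, 'seq_id': 30}, {'session_id': 4, 'seq_id': 40}]
print(get_guided_processors(formats, ctx))  # {1: 'processor(4, 40, {})'}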

lmdeploy/pytorch/engine/model_agent.py

Lines changed: 19 additions & 6 deletions
@@ -533,16 +533,20 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int):
         ret['logits'] = logits
         return ret
 
-    async def async_sampling_logits(self, logits: torch.Tensor, sampling_inputs: SamplingInputs, inputs: ModelInputs):
+    async def async_sampling_logits(self, logits: torch.Tensor, sampling_inputs: SamplingInputs, inputs: ModelInputs,
+                                    session_ctx: List[Dict[str, Any]]):
         """Sampling logits."""
 
         # record function does not support async function
         # so we can not decorate it on async_sampling_logits
         with record_function('sampling_logits'):
-            logits_processor = FusedLogitsProcessor(sampling_inputs,
-                                                    self.tokenizer,
-                                                    sampling_vocab_size=self.sampling_vocab_size,
-                                                    logprobs_mode=self.misc_config.logprobs_mode)
+            logits_processor = FusedLogitsProcessor(
+                sampling_inputs,
+                self.tokenizer,
+                sampling_vocab_size=self.sampling_vocab_size,
+                logprobs_mode=self.misc_config.logprobs_mode,
+                session_ctx=session_ctx,
+            )
             origin_logits = logits
             logits, raw_logprobs = await logits_processor(origin_logits)
             next_token_ids = logits_processor.sampling(logits)
@@ -586,6 +590,8 @@ async def _async_step_background(
         is_dummy: bool = False,
         sync_long_context: bool = False,
         extra_inputs: ExtraInputs = None,
+        session_ctx: List[Dict[str, Any]] = None,
+        session_to_cleanup: List[int] = None,
     ):
         """Asyc forward task."""
         dist_ctx = get_dist_manager().current_context()
@@ -678,6 +684,9 @@ async def __prepare_dp():
 
         need_output = dp > 1 or rank % tp == 0
 
+        if session_to_cleanup:
+            self.cleanup_sessions(session_to_cleanup)
+
         # skip dummy forward.
         if is_all_dummy:
             logger.debug(f'<ForwardTask> rank[{rank}]: all inputs are dummy, skip forward.')
@@ -709,7 +718,8 @@ async def __prepare_dp():
         if need_output:
             logger.debug(f'<ForwardTask> rank[{rank}]: Sampling [{idx}].')
             # sampling
-            next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs, inputs)
+            next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs, inputs,
+                                                                        session_ctx)
 
             with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next):
                 logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids [{idx}]')
@@ -1062,6 +1072,9 @@ def release(self):
         self.cache_engine = None
         torch.cuda.empty_cache()
 
+    def cleanup_sessions(self, session_ids: List[int]):
+        FusedLogitsProcessor.cleanup_sessions(session_ids)
+
 
 class DefaultForwardInputsMaker:
     """Default forward inputs maker."""
