
Commit d4b6ddb

fix: fix potential processor leakage, move session-related fields to SamplingInputs and refactor the code
1 parent 443db85 commit d4b6ddb

6 files changed: +138 −177 lines changed

lmdeploy/pytorch/engine/engine.py

Lines changed: 1 addition & 13 deletions
@@ -382,8 +382,6 @@ def __init__(self,
                               dtype=engine_config.dtype)
         self.executor.init()
 
-        self.session_to_cleanup = []
-
         # strategies
         self.strategy_factory = build_strategy_factory(self.model_config, self.executor.misc_config)
         self.sampling_strategy = self.strategy_factory.build_sampling_strategy()
@@ -915,14 +913,6 @@ def __need_logits(seqs: SeqList):
 
         sync_long_context = inputs.input_ids.numel() > self.cache_config.max_prefill_token_num
 
-        session_ctx = [{
-            'session_id': seq.session.session_id,
-            'seq_id': seq.seq_id,
-        } for seq in running]
-
-        session_to_cleanup = self.session_to_cleanup
-        self.session_to_cleanup = []
-
         return dict(
             running=running,
             inputs=inputs,
@@ -935,8 +925,6 @@ def __need_logits(seqs: SeqList):
             is_dummy=False,
             sync_long_context=sync_long_context,
             extra_inputs=extra_inputs,
-            session_ctx=session_ctx,
-            session_to_cleanup=session_to_cleanup,
         )
 
     async def _await_forward_event(self, forward_event: asyncio.Event):
@@ -1250,7 +1238,7 @@ def start_loop(self):
     def end_session(self, session_id: int):
         """End session."""
         if session_id in self.scheduler.sessions:
-            self.session_to_cleanup.append(session_id)
+            self.sampling_strategy.on_session_end(session_id)
             self.scheduler.end_session(session_id)
             return True
         return False
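
Before this change, ended sessions were queued on the engine itself and threaded through the forward inputs; now Engine.end_session notifies the sampling strategy directly. Below is a minimal sketch of how a strategy could service that hook. Only the on_session_end(session_id) call is confirmed by this diff; the SamplingStrategy body shown here is a hypothetical illustration (the real strategy comes from build_sampling_strategy() and is not part of this commit).

# Hypothetical sketch: a strategy that buffers ended sessions and drains
# them into SamplingInputs.session_to_cleanup when the next batch is built.
from typing import List


class SamplingStrategy:

    def __init__(self) -> None:
        self.session_to_cleanup: List[int] = []

    def on_session_end(self, session_id: int) -> None:
        # Invoked by Engine.end_session(); the actual removal happens later,
        # when FusedLogitsProcessor.cleanup_sessions() runs on these ids.
        self.session_to_cleanup.append(session_id)

    def drain_cleanup(self) -> List[int]:
        # Handed to SamplingInputs.session_to_cleanup, then reset.
        ids, self.session_to_cleanup = self.session_to_cleanup, []
        return ids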
lmdeploy/pytorch/engine/guided_process.py

Lines changed: 94 additions & 102 deletions
@@ -1,8 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import copy
 import json
 import logging
-from typing import Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import xgrammar as xgr
@@ -11,103 +10,96 @@
 logger = logging.getLogger('lmdeploy')
 
 
-class BaseLogitsProcessor:
-    """Base logits processor that uses xgrammar matcher for guided decoding."""
-
-    def __init__(self, compiled_grammar: xgr.CompiledGrammar, tokenizer_info: xgr.TokenizerInfo):
-        self.matcher = xgr.GrammarMatcher(compiled_grammar, terminate_without_stop_token=True)
-
-    def fill_bitmap(self, guided_bitmask: torch.Tensor, index: int) -> None:
-        """Fill the bitmask for the next token prediction at given index."""
-        self.matcher.fill_next_token_bitmask(guided_bitmask, index)
-
-    def accept(self, token_id: int) -> bool:
-        """Update matcher state after a token is generated."""
-        return self.matcher.accept_token(token_id)
-
-    def reset(self):
-        """Reset matcher state for next generation."""
-        self.matcher.reset()
-
-
-class RegexLogitsProcessor(BaseLogitsProcessor):
-    """Regex-guided logits processor using xgrammar."""
-
-    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase, vocab_size_padded: Optional[int] = None):
-        tokenizer = copy.deepcopy(tokenizer)
-        if vocab_size_padded is None:
-            vocab_size_padded = tokenizer.vocab_size
-
-        tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size_padded)
-
-        compiler = xgr.GrammarCompiler(tokenizer_info)
-        compiled = compiler.compile_regex_grammar(regex_string)
-
-        super().__init__(compiled, tokenizer_info)
-
-
-class JSONLogitsProcessor(BaseLogitsProcessor):
-    """JSON-schema guided logits processor using xgrammar."""
-
-    def __init__(self, schema: str, tokenizer: PreTrainedTokenizerBase, vocab_size_padded: Optional[int] = None):
-        tokenizer = copy.deepcopy(tokenizer)
-        tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size_padded)
-        if vocab_size_padded is None:
-            vocab_size_padded = tokenizer.vocab_size
-
-        compiler = xgr.GrammarCompiler(tokenizer_info)
-        if isinstance(schema, str):
-            schema = json.loads(schema)
-
-        assert isinstance(schema, dict)
-        compiled = compiler.compile_json_schema(schema)
-
-        super().__init__(compiled, tokenizer_info)
-
-
-_guided_processors = {}
-
-
-def _get_guided_logits_processor(session_id: int,
-                                 seq_id: int,
-                                 guide: str,
-                                 tokenizer: PreTrainedTokenizerBase,
-                                 type: str,
-                                 vocab_size_padded: Optional[int] = None):
-    if session_id in _guided_processors:
-        session_dict = _guided_processors[session_id]
-        if seq_id in session_dict:
-            processor = session_dict[seq_id]
-            return processor
-
-    if type == 'json_schema':
-        processor = JSONLogitsProcessor(guide, tokenizer, vocab_size_padded)
-    elif type == 'regex_schema':
-        processor = RegexLogitsProcessor(guide, tokenizer, vocab_size_padded)
-    else:
-        assert False, f'Do not support schema type {type}'
-
-    _guided_processors.setdefault(session_id, {})[seq_id] = processor
-    return processor
-
-
-def _remove_guided_logtis_processor(session_id: int):
-    if session_id in _guided_processors:
-        del _guided_processors[session_id]
-
-
-def _allocate_batched_bitmap(batch_size: int, vocab_size: int):
-    return xgr.allocate_token_bitmask(batch_size, vocab_size)
-
-
-def _apply_batched_bitmap(logits: torch.Tensor, guided_bitmask: torch.Tensor) -> None:
-    device = logits.device
-    dtype = logits.dtype
-
-    if device.type in {'cpu', 'cuda'}:
-        xgr.apply_token_bitmask_inplace(logits, guided_bitmask.to(device))
-    else:
-        cpu_logits = logits.cpu().float()
-        cpu_mask = guided_bitmask.cpu()
-        xgr.apply_token_bitmask_inplace(cpu_logits, cpu_mask)
-        logits.copy_(cpu_logits.to(device, dtype))
+class GuidedDecodingMangager:
+    processors = {}
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, vocab_size: Optional[int]):
+        if vocab_size is None:
+            vocab_size = tokenizer.vocab_size
+
+        tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
+        self.compiler = xgr.GrammarCompiler(tokenizer_info)
+        self.vocab_size = vocab_size
+
+    def get_processors(self, session_ctx: List[Dict[str, Any]],
+                       response_formats: Tuple[Dict]) -> Dict[int, xgr.GrammarMatcher]:
+        processors = {}
+        for i, _format in enumerate(response_formats):
+            if isinstance(_format, Dict) and _format.get('type', 'text') != 'text':
+                if _format['type'] == 'json_schema':
+                    schema = _format['json_schema']
+                    if isinstance(schema, Dict):
+                        for key in ['json_schema', 'schema']:
+                            if key in schema:
+                                schema = json.dumps(schema[key], ensure_ascii=False)
+
+                    if not isinstance(schema, str):
+                        raise ValueError(f'Cannot parse schema {schema}. The schema must be '
+                                         'either a dictionary or a string that contains the'
+                                         ' JSON Schema specification')
+                elif _format['type'] == 'regex_schema':
+                    schema = _format.get('regex_schema', '')
+                else:
+                    raise ValueError(f"unsupported format type: {_format['type']}")
+
+                session_id = session_ctx[i]['session_id']
+                seq_id = session_ctx[i]['seq_id']
+
+                processors[i] = self.get_processor(session_id, seq_id, schema, _format['type'])
+
+        return processors
+
+    def get_processor(self, session_id: int, seq_id: int, schema: str, type: str) -> xgr.GrammarMatcher:
+        if session_id in self.processors:
+            session_dict = self.processors[session_id]
+            if seq_id in session_dict:
+                processor = session_dict[seq_id]
+                return processor
+
+        if type == 'json_schema':
+            if isinstance(schema, str):
+                schema = json.loads(schema)
+
+            assert isinstance(schema, dict)
+            compiled = self.compiler.compile_json_schema(schema)
+        elif type == 'regex_schema':
+            compiled = self.compiler.compile_regex_grammar(schema)
+        else:
+            assert False, f'Do not support schema type {type}'
+
+        processor = xgr.GrammarMatcher(compiled, terminate_without_stop_token=True)
+        self.processors.setdefault(session_id, {})[seq_id] = processor
+        logger.info(f'create guided processor for session_id={session_id}, seq_id={seq_id}, and '
+                    f'total_processors={len(self.processors)}')
+        return processor
+
+    def remove_processor(self, session_id: int):
+        if session_id in self.processors:
+            del self.processors[session_id]
+            logger.info(
+                f'delete guided processor for session_id={session_id}, and total_processors={len(self.processors)}')
+
+    def allocate_batched_bitmap(self, batch_size: int) -> torch.Tensor:
+        return xgr.allocate_token_bitmask(batch_size, self.vocab_size)
+
+    def fill_bitmap(self, processor: xgr.GrammarMatcher, guided_bitmask: torch.Tensor, index: int) -> None:
+        processor.fill_next_token_bitmask(guided_bitmask, index)
+
+    def accept_token(self, processor: xgr.GrammarMatcher, token: int) -> None:
+        processor.accept_token(token)
+
+    def apply_batched_bitmap(self, logits: torch.Tensor, guided_bitmask: torch.Tensor) -> None:
+        device = logits.device
+        dtype = logits.dtype
+
+        if device.type in {'cpu', 'cuda'}:
+            xgr.apply_token_bitmask_inplace(logits, guided_bitmask.to(device))
+        else:
+            cpu_logits = logits.cpu().float()
+            cpu_mask = guided_bitmask.cpu()
+            xgr.apply_token_bitmask_inplace(cpu_logits, cpu_mask)
+            logits.copy_(cpu_logits.to(device, dtype))
+
+    def clear(self) -> None:
+        self.processors.clear()
+        logger.info(f'clear guided processors, total_processors={len(self.processors)}')
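
To make the new API concrete, here is a minimal usage sketch of one guided decode step with GuidedDecodingMangager. It is not part of the commit: the 'gpt2' tokenizer, session ids, and schema below are made-up example values, and the batch size is 1 for brevity.

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # example tokenizer
manager = GuidedDecodingMangager(tokenizer, vocab_size=None)

session_ctx = [{'session_id': 0, 'seq_id': 0}]
response_formats = ({'type': 'json_schema', 'json_schema': {'schema': {'type': 'object'}}}, )

# One cached GrammarMatcher per guided sequence, keyed by (session_id, seq_id).
processors = manager.get_processors(session_ctx, response_formats)

logits = torch.randn(1, manager.vocab_size)
bitmask = manager.allocate_batched_bitmap(batch_size=1)
for i, processor in processors.items():
    manager.fill_bitmap(processor, bitmask, i)
manager.apply_batched_bitmap(logits, bitmask)  # disallowed tokens -> -inf

token = int(logits.argmax(dim=-1)[0])
for i, processor in processors.items():
    manager.accept_token(processor, token)  # advance the grammar state

manager.remove_processor(session_id=0)  # release the cache when the session ends

Note that processors is a class-level dict, so cached matchers survive until remove_processor or clear is called; that retention is exactly the leak surface the engine-side on_session_end hook now covers.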

lmdeploy/pytorch/engine/logits_process.py

Lines changed: 16 additions & 48 deletions
@@ -1,15 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import asyncio
-import json
 from dataclasses import dataclass, fields
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 
 from lmdeploy.messages import LogitsProcessor
-from lmdeploy.tokenizer import Tokenizer
 
 from ..messages import SchedulerSequence
+from .guided_process import GuidedDecodingMangager
 
 
 def _process_temperature_(scores: torch.Tensor, temperature: torch.Tensor):
@@ -78,37 +77,6 @@ def _multinomial_sampling(scores: torch.Tensor,
 
     return multinomial_sampling(scores, seeds, offsets, indices)
 
 
-def _get_guided_processors(response_formats: Tuple[Dict], tokenizer: object, vocab_size_padded: int,
-                           session_ctx: List[Dict[str, Any]]):
-    processors = {}
-    for i, _format in enumerate(response_formats):
-        if isinstance(_format, Dict) and _format.get('type', 'text') != 'text':
-            if _format['type'] == 'json_schema':
-                schema = _format['json_schema']
-                if isinstance(schema, Dict):
-                    for key in ['json_schema', 'schema']:
-                        if key in schema:
-                            schema = json.dumps(schema[key], ensure_ascii=False)
-
-                if not isinstance(schema, str):
-                    raise ValueError(f'Cannot parse schema {schema}. The schema must be '
-                                     'either a dictionary or a string that contains the'
-                                     ' JSON Schema specification')
-            elif _format['type'] == 'regex_schema':
-                schema = _format.get('regex_schema', '')
-            else:
-                raise ValueError(f"unsupported format type: {_format['type']}")
-
-            session_id = session_ctx[i]['session_id']
-            seq_id = session_ctx[i]['seq_id']
-
-            from .guided_process import _get_guided_logits_processor
-            processors[i] = _get_guided_logits_processor(session_id, seq_id, schema, tokenizer, _format['type'],
-                                                         vocab_size_padded)
-
-    return processors
-
-
 SeqList = List[SchedulerSequence]
 
 
@@ -133,6 +101,8 @@ class SamplingInputs:
     all_ids: Optional[torch.Tensor] = None
     num_ignore_eos: torch.Tensor = None
     batch_size: int = 0
+    session_ctx: Optional[List[Dict[str, Any]]] = None
+    session_to_cleanup: Optional[List[int]] = None
 
     def to_device(self, device: str, non_blocking: bool = False):
         """To device."""
@@ -162,17 +132,19 @@ class FusedLogitsProcessor:
     def __init__(
         self,
        sampling_inputs: SamplingInputs,
-        tokenizer: Optional[Tokenizer] = None,
        sampling_vocab_size: Optional[int] = None,
        logprobs_mode: Optional[str] = None,
-        session_ctx: Optional[List[Dict[str, Any]]] = None,
+        guided_decoding_manager: Optional[GuidedDecodingMangager] = None,
     ):
        self.sampling_inputs: SamplingInputs = sampling_inputs
-        self.tokenizer = tokenizer
        self.sampling_vocab_size = sampling_vocab_size
        self.logprobs_mode = logprobs_mode
-        self.guided_processors = _get_guided_processors(sampling_inputs.response_formats, tokenizer,
-                                                        sampling_vocab_size, session_ctx)
+        self.guided_decoding_manager = guided_decoding_manager
+        if sampling_inputs.session_to_cleanup:
+            self.cleanup_sessions(sampling_inputs.session_to_cleanup)
+
+        self.guided_processors = self.guided_decoding_manager.get_processors(sampling_inputs.session_ctx,
+                                                                             sampling_inputs.response_formats)
 
     async def _wait_stream_once(self):
         """Wait stream once."""
@@ -211,19 +183,17 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
         all_ids = sampling_inputs.all_ids
         custom_logits_processors = self.sampling_inputs.logits_processors
         if self.guided_processors:
-            from .guided_process import _allocate_batched_bitmap, _apply_batched_bitmap
-
             if not hasattr(self, 'guided_bitmask'):
-                self.guided_bitmask = _allocate_batched_bitmap(len(scores), self.sampling_vocab_size)
+                self.guided_bitmask = self.guided_decoding_manager.allocate_batched_bitmap(len(scores))
 
             assert self.guided_bitmask is not None
             guided_bitmask = self.guided_bitmask
 
             await self._wait_stream_once()
             for i, processor in self.guided_processors.items():
-                processor.fill_bitmap(guided_bitmask, i)
+                self.guided_decoding_manager.fill_bitmap(processor, guided_bitmask, i)
 
-            _apply_batched_bitmap(scores, guided_bitmask)
+            self.guided_decoding_manager.apply_batched_bitmap(scores, guided_bitmask)
 
         if any(custom_logits_processors):
             await self._wait_stream_once()
@@ -298,7 +268,7 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
 
         if self.guided_processors:
             for i, processor in self.guided_processors.items():
-                processor.accept(result[i])
+                self.guided_decoding_manager.accept_token(processor, result[i])
 
         return result
 
@@ -318,8 +288,6 @@ def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTens
 
         return logprobs, indices.to(torch.int32)
 
-    @staticmethod
-    def cleanup_sessions(session_ids: List[int]):
-        from .guided_process import _remove_guided_logtis_processor
+    def cleanup_sessions(self, session_ids: List[int]):
         for session_id in session_ids:
-            _remove_guided_logtis_processor(session_id)
+            self.guided_decoding_manager.remove_processor(session_id)
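
Finally, a sketch of the refactored construction path: session bookkeeping now rides on SamplingInputs rather than extra constructor arguments, and the manager replaces the raw tokenizer. The values below are examples only, and it assumes the remaining SamplingInputs dataclass fields keep their defaults (the manager variable reuses the one from the earlier sketch).

# Illustrative wiring only; values are examples, not taken from the commit.
sampling_inputs = SamplingInputs(
    response_formats=({'type': 'regex_schema', 'regex_schema': r'\d+'}, ),
    batch_size=1,
    session_ctx=[{'session_id': 0, 'seq_id': 0}],  # moved off the engine
    session_to_cleanup=[],  # session ids to release before this step
)

logits_processor = FusedLogitsProcessor(
    sampling_inputs,
    sampling_vocab_size=manager.vocab_size,
    logprobs_mode=None,
    guided_decoding_manager=manager,  # replaces the old tokenizer/session_ctx args
)
# Any ids in session_to_cleanup are released in __init__ via cleanup_sessions(),
# which forwards each one to manager.remove_processor(session_id).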
