@@ -1,8 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import copy
 import json
 import logging
-from typing import Optional
+from typing import Any, Dict, List, Optional, Tuple

 import torch
 import xgrammar as xgr
@@ -11,103 +10,113 @@
 logger = logging.getLogger('lmdeploy')


-class BaseLogitsProcessor:
-    """Base logits processor that uses xgrammar matcher for guided decoding."""
-
-    def __init__(self, compiled_grammar: xgr.CompiledGrammar, tokenizer_info: xgr.TokenizerInfo):
-        self.matcher = xgr.GrammarMatcher(compiled_grammar, terminate_without_stop_token=True)
-
-    def fill_bitmap(self, guided_bitmask: torch.Tensor, index: int) -> None:
-        """Fill the bitmask for the next token prediction at given index."""
-        self.matcher.fill_next_token_bitmask(guided_bitmask, index)
-
-    def accept(self, token_id: int) -> bool:
-        """Update matcher state after a token is generated."""
-        return self.matcher.accept_token(token_id)
-
-    def reset(self):
-        """Reset matcher state for next generation."""
-        self.matcher.reset()
-
-
-class RegexLogitsProcessor(BaseLogitsProcessor):
-    """Regex-guided logits processor using xgrammar."""
-
-    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase, vocab_size_padded: Optional[int] = None):
-        tokenizer = copy.deepcopy(tokenizer)
-        if vocab_size_padded is None:
-            vocab_size_padded = tokenizer.vocab_size
-
-        tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size_padded)
-
-        compiler = xgr.GrammarCompiler(tokenizer_info)
-        compiled = compiler.compile_regex_grammar(regex_string)
-
-        super().__init__(compiled, tokenizer_info)
-
-
-class JSONLogitsProcessor(BaseLogitsProcessor):
-    """JSON-schema guided logits processor using xgrammar."""
-
-    def __init__(self, schema: str, tokenizer: PreTrainedTokenizerBase, vocab_size_padded: Optional[int] = None):
-        tokenizer = copy.deepcopy(tokenizer)
-        tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size_padded)
-        if vocab_size_padded is None:
-            vocab_size_padded = tokenizer.vocab_size
-
-        compiler = xgr.GrammarCompiler(tokenizer_info)
-        if isinstance(schema, str):
-            schema = json.loads(schema)
-
-        assert isinstance(schema, dict)
-        compiled = compiler.compile_json_schema(schema)
-
-        super().__init__(compiled, tokenizer_info)
-
-
-_guided_processors = {}
-
-
-def _get_guided_logits_processor(session_id: int,
-                                 seq_id: int,
-                                 guide: str,
-                                 tokenizer: PreTrainedTokenizerBase,
-                                 type: str,
-                                 vocab_size_padded: Optional[int] = None):
-    if session_id in _guided_processors:
-        session_dict = _guided_processors[session_id]
-        if seq_id in session_dict:
-            processor = session_dict[seq_id]
-            return processor
-
-    if type == 'json_schema':
-        processor = JSONLogitsProcessor(guide, tokenizer, vocab_size_padded)
-    elif type == 'regex_schema':
-        processor = RegexLogitsProcessor(guide, tokenizer, vocab_size_padded)
-    else:
-        assert False, f'Do not support schema type {type}'
-
-    _guided_processors.setdefault(session_id, {})[seq_id] = processor
-    return processor
-
-
-def _remove_guided_logtis_processor(session_id: int):
-    if session_id in _guided_processors:
-        del _guided_processors[session_id]
-
-
-def _allocate_batched_bitmap(batch_size: int, vocab_size: int):
-    return xgr.allocate_token_bitmask(batch_size, vocab_size)
-
-
-def _apply_batched_bitmap(logits: torch.Tensor, guided_bitmask: torch.Tensor) -> None:
-    device = logits.device
-    dtype = logits.dtype
-
-    if device.type in {'cpu', 'cuda'}:
-        xgr.apply_token_bitmask_inplace(logits, guided_bitmask.to(device))
-    else:
-        cpu_logits = logits.cpu().float()
-        cpu_mask = guided_bitmask.cpu()
-        xgr.apply_token_bitmask_inplace(cpu_logits, cpu_mask)
-        logits.copy_(cpu_logits.to(device, dtype))
+class GuidedDecodingManager:
+    """Manages per-session xgrammar matchers for guided (constrained) decoding."""
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, vocab_size: Optional[int]):
+        if vocab_size is None:
+            vocab_size = tokenizer.vocab_size
+
+        tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
+        self.compiler = xgr.GrammarCompiler(tokenizer_info)
+        self.vocab_size = vocab_size
+        # session_id -> {seq_id: GrammarMatcher}; an instance attribute so that
+        # separate managers do not share state via a mutable class attribute.
+        self.processors = {}
+    def get_processors(self, session_ctx: List[Dict[str, Any]],
+                       response_formats: Tuple[Dict]) -> Dict[int, xgr.GrammarMatcher]:
+        """Build or reuse one grammar matcher per request whose response_format
+        asks for guided decoding, keyed by batch index."""
+        processors = {}
+        for i, _format in enumerate(response_formats):
+            if isinstance(_format, dict) and _format.get('type', 'text') != 'text':
+                if _format['type'] == 'json_schema':
+                    schema = _format['json_schema']
+                    if isinstance(schema, dict):
+                        # The actual schema may be nested under 'json_schema' or
+                        # 'schema' (OpenAI-style payloads); unwrap the first hit
+                        # and stop, since the result is now a string.
+                        for key in ['json_schema', 'schema']:
+                            if key in schema:
+                                schema = json.dumps(schema[key], ensure_ascii=False)
+                                break
+
+                    if not isinstance(schema, str):
+                        raise ValueError(f'Cannot parse schema {schema}. The schema must be '
+                                         'either a dictionary or a string that contains the'
+                                         ' JSON Schema specification')
+                elif _format['type'] == 'regex_schema':
+                    schema = _format.get('regex_schema', '')
+                else:
+                    raise ValueError(f"unsupported format type: {_format['type']}")
+
+                session_id = session_ctx[i]['session_id']
+                seq_id = session_ctx[i]['seq_id']
+
+                processors[i] = self.get_processor(session_id, seq_id, schema, _format['type'])
+
+        return processors
+
+    def get_processor(self, session_id: int, seq_id: int, schema: str, type: str) -> xgr.GrammarMatcher:
+        """Return the cached matcher for (session_id, seq_id), compiling the
+        grammar on first use."""
+        if session_id in self.processors:
+            session_dict = self.processors[session_id]
+            if seq_id in session_dict:
+                processor = session_dict[seq_id]
+                return processor
+
+        if type == 'json_schema':
+            if isinstance(schema, str):
+                schema = json.loads(schema)
+
+            assert isinstance(schema, dict)
+            compiled = self.compiler.compile_json_schema(schema)
+        elif type == 'regex_schema':
+            compiled = self.compiler.compile_regex_grammar(schema)
+        else:
+            raise ValueError(f'unsupported schema type: {type}')
+
+        processor = xgr.GrammarMatcher(compiled, terminate_without_stop_token=True)
+        self.processors.setdefault(session_id, {})[seq_id] = processor
+        logger.info(f'created guided processor for session_id={session_id}, seq_id={seq_id}, '
+                    f'total_processors={len(self.processors)}')
+        return processor
+
+    def remove_processor(self, session_id: int):
+        """Drop all cached matchers that belong to a session."""
+        if session_id in self.processors:
+            del self.processors[session_id]
+            logger.info(
+                f'deleted guided processors for session_id={session_id}, total_processors={len(self.processors)}')
+
+    def allocate_batched_bitmap(self, batch_size: int) -> torch.Tensor:
+        """Allocate a token bitmask with one row per request in the batch."""
+        return xgr.allocate_token_bitmask(batch_size, self.vocab_size)
+
+    def fill_bitmap(self, processor: xgr.GrammarMatcher, guided_bitmask: torch.Tensor, index: int) -> None:
+        """Fill row `index` of the bitmask with the currently legal tokens."""
+        processor.fill_next_token_bitmask(guided_bitmask, index)
+
+    def accept_token(self, processor: xgr.GrammarMatcher, token: int) -> bool:
+        """Advance the matcher with a generated token; False means rejected."""
+        return processor.accept_token(token)
+
+    def apply_batched_bitmap(self, logits: torch.Tensor, guided_bitmask: torch.Tensor) -> None:
+        """Mask disallowed tokens in-place by setting their logits to -inf."""
+        device = logits.device
+        dtype = logits.dtype
+
+        if device.type in {'cpu', 'cuda'}:
+            xgr.apply_token_bitmask_inplace(logits, guided_bitmask.to(device))
+        else:
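+            # xgr.apply_token_bitmask_inplace only handles CPU and CUDA tensors,
+            # so other devices round-trip the logits through CPU and copy back.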
+            cpu_logits = logits.cpu().float()
+            cpu_mask = guided_bitmask.cpu()
+            xgr.apply_token_bitmask_inplace(cpu_logits, cpu_mask)
+            logits.copy_(cpu_logits.to(device, dtype))
+
+    def clear(self) -> None:
+        """Release every cached matcher."""
+        logger.info(f'clearing guided processors, total_processors={len(self.processors)}')
+        self.processors.clear()
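
For orientation, here is a minimal sketch of how this manager could be driven from a decoding loop. It is illustrative only: the checkpoint name, the random logits, and the single-request session bookkeeping are assumptions made for the example, not part of this change.

import torch
from transformers import AutoTokenizer

# Any HuggingFace tokenizer works; the checkpoint name is only an example.
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')
manager = GuidedDecodingManager(tokenizer, vocab_size=None)

# One request at batch index 0 asking for JSON-constrained output.
session_ctx = [{'session_id': 1, 'seq_id': 0}]
response_formats = ({'type': 'json_schema', 'json_schema': {'schema': {'type': 'object'}}}, )
processors = manager.get_processors(session_ctx, response_formats)

bitmask = manager.allocate_batched_bitmap(batch_size=1)
logits = torch.randn(1, manager.vocab_size)  # stand-in for one step of model output

# One decoding step: restrict the logits to grammar-legal tokens, pick, advance.
for index, matcher in processors.items():
    manager.fill_bitmap(matcher, bitmask, index)
manager.apply_batched_bitmap(logits, bitmask)
token = int(logits[0].argmax())  # greedy pick, for illustration
manager.accept_token(processors[0], token)

manager.remove_processor(session_id=1)  # once the session finishes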