diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml
index 1f17235522..f79ed234e6 100644
--- a/.github/workflows/unit-test.yml
+++ b/.github/workflows/unit-test.yml
@@ -64,7 +64,6 @@ jobs:
           python3 -m pip install /root/packages/cu118/flash_attn-*.whl
           python3 -m pip install -r requirements_cuda.txt -r requirements/test.txt
           python3 -m pip install -e .
-          python3 -m pip install -U 'numpy<2.0'
       - name: Check env
         run: |
           python3 -m pip list
diff --git a/docker/prepare_wheel.sh b/docker/prepare_wheel.sh
index 1ffbbcf06b..4250c8820a 100755
--- a/docker/prepare_wheel.sh
+++ b/docker/prepare_wheel.sh
@@ -17,7 +17,6 @@ if [[ ${PYTHON_VERSION} = "3.13" ]]; then
     pip install setuptools_rust
     pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/google/sentencepiece.git@v0.2.0#subdirectory=python"
-    pip wheel -v --no-build-isolation --no-deps -w /wheels --use-deprecated=legacy-resolver outlines_core==0.1.26
 fi
 
 if [[ "${CUDA_VERSION_SHORT}" != "cu118" ]]; then
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index 2179b5a99f..bbc03e085c 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -551,7 +551,7 @@ def _on_end_session(self, reqs: List[Request], **kwargs):
                 if len(msgs) > 0 and msgs[0].preserve_cache:
                     self.scheduler._set_message_status(msgs[0], MessageStatus.TO_BE_MIGRATED)
                 else:
-                    self.scheduler.end_session(session_id)
+                    self.end_session(session_id)
                 resp_type = ResponseType.SUCCESS
             if resp:
                 self._response(req.resp, resp_type)
@@ -912,6 +912,7 @@ def __need_logits(seqs: SeqList):
         stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running)
 
         sync_long_context = inputs.input_ids.numel() > self.cache_config.max_prefill_token_num
+
         return dict(
             running=running,
             inputs=inputs,
@@ -1237,6 +1238,7 @@ def start_loop(self):
     def end_session(self, session_id: int):
         """End session."""
         if session_id in self.scheduler.sessions:
+            self.sampling_strategy.on_session_end(session_id)
             self.scheduler.end_session(session_id)
             return True
         return False
diff --git a/lmdeploy/pytorch/engine/guided_process.py b/lmdeploy/pytorch/engine/guided_process.py
index cc25906f60..414f95351b 100644
--- a/lmdeploy/pytorch/engine/guided_process.py
+++ b/lmdeploy/pytorch/engine/guided_process.py
@@ -1,161 +1,105 @@
-# Copyright 2024- the Outlines developers
-# This file is adapted from
-# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-import copy
-import math
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from collections import defaultdict
-from functools import lru_cache
-from typing import DefaultDict, Dict, List, Union
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import logging
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
-from outlines.fsm.guide import CFGGuide, Generate, RegexGuide, Write
-from outlines.fsm.json_schema import build_regex_from_schema
-from pydantic import BaseModel
+import xgrammar as xgr
 from transformers import PreTrainedTokenizerBase
 
-
-class BaseLogitsProcessor:
-
-    def init_state(self):
-        """Initialize the FSM states."""
-        self.fsm_state: DefaultDict[int, int] = defaultdict(int)
-
-    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
-        """Use the FSM to bias the logits before sampling the next token."""
-
-        seq_id = hash(tuple(input_ids))
-
-        if len(input_ids) == 0:
-            self.init_state()
-        else:
-            last_token = input_ids[-1]
-            last_seq_id = hash(tuple(input_ids[:-1]))
-            self.fsm_state[seq_id] = self.fsm.get_next_state(state=self.fsm_state[last_seq_id], token_id=last_token)
-
-        instruction = self.fsm.get_next_instruction(self.fsm_state[seq_id])
-
-        if type(instruction) == Generate:
-            allowed_tokens = instruction.tokens
-        elif type(instruction) == Write:
-            # TODO: support fast forward tokens
-            allowed_tokens = [instruction.tokens[0]]
+logger = logging.getLogger('lmdeploy')
+
+
+class GuidedDecodingMangager:
+    processors = {}
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, vocab_size: Optional[int]):
+        if vocab_size is None:
+            vocab_size = tokenizer.vocab_size
+
+        tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
+        self.compiler = xgr.GrammarCompiler(tokenizer_info)
+        self.vocab_size = vocab_size
+
+    def get_processors(self, session_ctx: List[Dict[str, Any]],
+                       response_formats: Tuple[Dict]) -> Dict[int, xgr.GrammarMatcher]:
+        processors = {}
+        for i, _format in enumerate(response_formats):
+            if isinstance(_format, Dict) and _format.get('type', 'text') != 'text':
+                if _format['type'] == 'json_schema':
+                    schema = _format['json_schema']
+                    if isinstance(schema, Dict):
+                        for key in ['json_schema', 'schema']:
+                            if key in schema:
+                                schema = json.dumps(schema[key], ensure_ascii=False)
+
+                    if not isinstance(schema, str):
+                        raise ValueError(f'Cannot parse schema {schema}. The schema must be '
+                                         'either a dictionary or a string that contains the'
+                                         ' JSON Schema specification')
+                elif _format['type'] == 'regex_schema':
+                    schema = _format.get('regex_schema', '')
+                else:
+                    raise ValueError(f"unsupported format type: {_format['type']}")
+
+                session_id = session_ctx[i]['session_id']
+                seq_id = session_ctx[i]['seq_id']
+
+                processors[i] = self.get_processor(session_id, seq_id, schema, _format['type'])
+
+        return processors
+
+    def get_processor(self, session_id: int, seq_id: int, schema: str, type: str) -> xgr.GrammarMatcher:
+        if session_id in self.processors:
+            session_dict = self.processors[session_id]
+            if seq_id in session_dict:
+                processor = session_dict[seq_id]
+                return processor
+
+        if type == 'json_schema':
+            if isinstance(schema, str):
+                schema = json.loads(schema)
+
+            assert isinstance(schema, dict)
+            compiled = self.compiler.compile_json_schema(schema)
+        elif type == 'regex_schema':
+            compiled = self.compiler.compile_regex_grammar(schema)
         else:
-            raise TypeError(f'Unsupported instruction type {type(instruction)}')
-
-        mask = torch.full((scores.shape[-1], ), -math.inf, device=scores.device)
-        mask[allowed_tokens] = 0
-        scores.add_(mask)
-
-        return scores
-
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of `transformers`. In addition we need to handle
-        the missing spaces to Llama's tokenizer to be able to compile FSMs for this model.
-        """
-        from outlines.integrations.utils import adapt_tokenizer
-        tokenizer = adapt_tokenizer(tokenizer)
-        # vocab size greater than logits shape because of '[UNUSED_TOKEN_...]'
-        if hasattr(tokenizer, '_tokenizer'):
-            tokenizer.vocabulary = tokenizer._tokenizer.get_vocab(with_added_tokens=False)
-        return tokenizer
-
+            assert False, f'Do not support schema type {type}'
 
-class RegexLogitsProcessor(BaseLogitsProcessor):
+        processor = xgr.GrammarMatcher(compiled, terminate_without_stop_token=True)
+        self.processors.setdefault(session_id, {})[seq_id] = processor
+        logger.info(f'create guided processor for session_id={session_id}, seq_id={seq_id}, and '
+                    f'total_processors={len(self.processors)}')
+        return processor
 
-    def __init__(self, regex_string: str, tokenizer):
-        """Compile the FSM that drives the regex-structured generation.
+    def remove_processor(self, session_id: int):
+        if session_id in self.processors:
+            del self.processors[session_id]
+            logger.info(
+                f'delete guided processor for session_id={session_id}, and total_processors={len(self.processors)}')
 
-        Args:
-            regex_string: A string that represents a regular expression
-            tokenizer: The model's tokenizer
-        """
-        tokenizer = self.adapt_tokenizer(copy.deepcopy(tokenizer))
-        fsm = RegexGuide(regex_string, tokenizer)
-        self.fsm = fsm
+    def allocate_batched_bitmap(self, batch_size: int) -> torch.Tensor:
+        return xgr.allocate_token_bitmask(batch_size, self.vocab_size)
 
+    def fill_bitmap(self, processor: xgr.GrammarMatcher, guided_bitmask: torch.Tensor, index: int) -> None:
+        processor.fill_next_token_bitmask(guided_bitmask, index)
 
-class JSONLogitsProcessor(RegexLogitsProcessor):
+    def accept_token(self, processor: xgr.GrammarMatcher, token: int) -> None:
+        processor.accept_token(token)
 
-    def __init__(self, schema: Union[str, Dict, BaseModel], tokenizer):
-        """Compile the FSM that drives the JSON-guided generation.
+    def apply_batched_bitmap(self, logits: torch.Tensor, guided_bitmask: torch.Tensor) -> None:
+        device = logits.device
+        dtype = logits.dtype
 
-        Args:
-            schema: A str schema that encodes the structure we want the model
-                to generate
-            tokenizer: The model's tokenizer
-        """
-        regex_string = build_regex_from_schema(schema)
-        super().__init__(regex_string, tokenizer)
-
-
-class CFGLogitsProcessor(BaseLogitsProcessor):
-
-    def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
-        """Compile the FSM that drives the context free grammar generation.
-
-        Parameters
-        ----------
-        cfg
-            A string that represents a context-free grammar
-        tokenizer
-            The model's tokenizer
-        """
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        fsm = CFGGuide(cfg, tokenizer)
-        self.fsm = fsm
-
-
-# copied from https://github.com/vllm-project/vllm/blob/a7f65c2be93f491771aca31106f790bf381c0bad/vllm/model_executor/guided_decoding/outlines_decoding.py#L31 # noqa
-JSON_GRAMMAR = r"""
-?start: object | array
-
-?value: object
-| array
-| UNESCAPED_STRING
-| SIGNED_NUMBER -> number
-| "true" -> true
-| "false" -> false
-| "null" -> null
-
-array : "[" [value ("," value)*] "]"
-object : "{" [pair ("," pair)*] "}"
-pair : UNESCAPED_STRING ":" value
-
-%import common.UNESCAPED_STRING
-%import common.SIGNED_NUMBER
-%import common.WS
-
-%ignore WS
-"""
-
-
-@lru_cache(maxsize=32)
-def _get_guided_logits_processor(guide: str, tokenizer: PreTrainedTokenizerBase, type: str):
-    try:
-        if type == 'json_object':
-            return CFGLogitsProcessor(guide, tokenizer)
-        elif type == 'json_schema':
-            return JSONLogitsProcessor(guide, tokenizer)
-        elif type == 'regex_schema':
-            return RegexLogitsProcessor(guide, tokenizer)
+        if device.type in {'cpu', 'cuda'}:
+            xgr.apply_token_bitmask_inplace(logits, guided_bitmask.to(device))
         else:
-            return None
-    except Exception as e:
-        from lmdeploy.utils import get_logger
-        logger = get_logger('lmdeploy')
-        logger.error(e)
-        return None
+            cpu_logits = logits.cpu().float()
+            cpu_mask = guided_bitmask.cpu()
+            xgr.apply_token_bitmask_inplace(cpu_logits, cpu_mask)
+            logits.copy_(cpu_logits.to(device, dtype))
+
+    def clear(self) -> None:
+        self.processors.clear()
+        logger.info(f'clear guided processors, total_processors={len(self.processors)}')
diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py
index b30fbb3992..2fe3c6150a 100644
--- a/lmdeploy/pytorch/engine/logits_process.py
+++ b/lmdeploy/pytorch/engine/logits_process.py
@@ -1,15 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import asyncio
-import json
 from dataclasses import dataclass, fields
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 
 from lmdeploy.messages import LogitsProcessor
-from lmdeploy.tokenizer import Tokenizer
 
 from ..messages import SchedulerSequence
+from .guided_process import GuidedDecodingMangager
 
 
 def _process_temperature_(scores: torch.Tensor, temperature: torch.Tensor):
@@ -78,37 +77,6 @@ def _multinomial_sampling(scores: torch.Tensor,
     return multinomial_sampling(scores, seeds, offsets, indices)
 
 
-def _guided_sampling(response_formats: Tuple[Dict], scores: torch.Tensor, guided_input_ids: Optional[torch.Tensor],
-                     tokenizer: object):
-    if guided_input_ids is None:
-        return scores
-    for i in range(len(response_formats)):
-        _format = response_formats[i]
-        if isinstance(_format, Dict) and _format.get('type', 'text') != 'text':
-            if _format['type'] == 'json_schema':
-                schema = _format['json_schema']
-                if isinstance(schema, Dict):
-                    for key in ['json_schema', 'schema']:
-                        if key in schema:
-                            schema = json.dumps(schema[key], ensure_ascii=False)
-                elif schema is None:
-                    from .guided_process import JSON_GRAMMAR
-                    schema = JSON_GRAMMAR
-                elif isinstance(schema, str):
-                    raise ValueError(f'Cannot parse schema {schema}. The schema must be '
-                                     'either a dictionary or a string that contains the'
-                                     ' JSON Schema specification')
-            elif _format['type'] == 'regex_schema':
-                schema = _format.get('regex_schema', '')
-            else:
-                raise ValueError(f"unsupported format type: {_format['type']}")
-            from .guided_process import _get_guided_logits_processor
-            processor = _get_guided_logits_processor(schema, tokenizer, _format['type'])
-            if processor:
-                scores[i] = processor(guided_input_ids[i].tolist(), scores[i])
-    return scores
-
-
 SeqList = List[SchedulerSequence]
 
 
@@ -131,9 +99,10 @@ class SamplingInputs:
     logits_processors: List[List[LogitsProcessor]] = None
     max_num_logprobs: Optional[int] = None
     all_ids: Optional[torch.Tensor] = None
-    guided_input_ids: Optional[torch.Tensor] = None
     num_ignore_eos: torch.Tensor = None
     batch_size: int = 0
+    session_ctx: Optional[List[Dict[str, Any]]] = None
+    session_to_cleanup: Optional[List[int]] = None
 
     def to_device(self, device: str, non_blocking: bool = False):
         """To device."""
@@ -160,15 +129,25 @@ def _apply_custom_logits_processors(batched_logits_processors, all_ids, logits):
 class FusedLogitsProcessor:
     """Custom logits processor."""
 
-    def __init__(self,
-                 sampling_inputs: SamplingInputs,
-                 tokenizer: Optional[Tokenizer] = None,
-                 sampling_vocab_size: Optional[int] = None,
-                 logprobs_mode: Optional[str] = None):
+    def __init__(
+        self,
+        sampling_inputs: SamplingInputs,
+        sampling_vocab_size: Optional[int] = None,
+        logprobs_mode: Optional[str] = None,
+        guided_decoding_manager: Optional[GuidedDecodingMangager] = None,
+    ):
         self.sampling_inputs: SamplingInputs = sampling_inputs
-        self.tokenizer = tokenizer
         self.sampling_vocab_size = sampling_vocab_size
         self.logprobs_mode = logprobs_mode
+        self.guided_decoding_manager = guided_decoding_manager
+        if sampling_inputs.session_to_cleanup:
+            self.cleanup_sessions(sampling_inputs.session_to_cleanup)
+
+        if self.guided_decoding_manager:
+            self.guided_processors = self.guided_decoding_manager.get_processors(sampling_inputs.session_ctx,
+                                                                                 sampling_inputs.response_formats)
+        else:
+            self.guided_processors = {}
 
     async def _wait_stream_once(self):
         """Wait stream once."""
@@ -205,9 +184,20 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
 
         sampling_inputs = self.sampling_inputs
         all_ids = sampling_inputs.all_ids
-        guided_input_ids = sampling_inputs.guided_input_ids
-
         custom_logits_processors = self.sampling_inputs.logits_processors
+        if self.guided_decoding_manager and self.guided_processors:
+            if not hasattr(self, 'guided_bitmask'):
+                self.guided_bitmask = self.guided_decoding_manager.allocate_batched_bitmap(len(scores))
+
+            assert self.guided_bitmask is not None
+            guided_bitmask = self.guided_bitmask
+
+            await self._wait_stream_once()
+            for i, processor in self.guided_processors.items():
+                self.guided_decoding_manager.fill_bitmap(processor, guided_bitmask, i)
+
+            self.guided_decoding_manager.apply_batched_bitmap(scores, guided_bitmask)
+
         if any(custom_logits_processors):
             await self._wait_stream_once()
             scores = _apply_custom_logits_processors(custom_logits_processors, all_ids, scores)
@@ -232,9 +222,6 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
             stop_mask = torch.where(ignore_eos[:, None], stop_mask, False)
             scores = _process_bad_words_(scores, stop_words, stop_mask)
 
-        if guided_input_ids is not None:
-            await self._wait_stream_once()
-            scores = _guided_sampling(sampling_inputs.response_formats, scores, guided_input_ids, self.tokenizer)
         return scores, logprobs
 
     @torch.inference_mode()
@@ -272,7 +259,7 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
            logits = logits[..., :self.sampling_vocab_size]
 
         if sampling_inputs.max_top_k == 1:
-            return logits.argmax(-1)
+            result = logits.argmax(-1)
         else:
             # sort logits is too slow. and we only need topk logits
             max_topk = sampling_inputs.max_top_k
@@ -280,7 +267,13 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
                 scores, indices = logits.sort(1, descending=True)
             else:
                 scores, indices = logits.topk(max_topk, dim=1)
-            return __random_sampling(scores, indices)
+            result = __random_sampling(scores, indices)
+
+        if self.guided_decoding_manager and self.guided_processors:
+            for i, processor in self.guided_processors.items():
+                self.guided_decoding_manager.accept_token(processor, result[i])
+
+        return result
 
     @torch.inference_mode()
     def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTensor):
@@ -297,3 +290,8 @@ def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTens
             indices = torch.cat([indices, topk_indices], dim=-1)
 
         return logprobs, indices.to(torch.int32)
+
+    def cleanup_sessions(self, session_ids: List[int]):
+        if self.guided_decoding_manager:
+            for session_id in session_ids:
+                self.guided_decoding_manager.remove_processor(session_id)
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 43f42c93be..11918b00e7 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -30,6 +30,7 @@
 from ..utils import get_gpu_memory
 from ..weight_loader.model_weight_loader import load_model_weights
 from .cache_engine import CacheEngine
+from .guided_process import GuidedDecodingMangager
 from .logits_process import FusedLogitsProcessor, SamplingInputs
 
 logger = get_logger('lmdeploy')
@@ -352,6 +353,7 @@ def __init__(self,
         self.patched_model = None
         self.cache_engine = None
         self.profiler: AgentProfiler = None
+        self.guided_decoding_manager = GuidedDecodingMangager(self.tokenizer, self.sampling_vocab_size)
 
         # microbatch
         self.enable_microbatch = self.dist_ctx.dist_config.enable_microbatch
@@ -544,10 +546,12 @@ async def async_sampling_logits(self, logits: torch.Tensor, sampling_inputs: Sam
         # record function does not support async function
         # so we can not decorate it on async_sampling_logits
         with record_function('sampling_logits'):
-            logits_processor = FusedLogitsProcessor(sampling_inputs,
-                                                    self.tokenizer,
-                                                    sampling_vocab_size=self.sampling_vocab_size,
-                                                    logprobs_mode=self.misc_config.logprobs_mode)
+            logits_processor = FusedLogitsProcessor(
+                sampling_inputs,
+                sampling_vocab_size=self.sampling_vocab_size,
+                logprobs_mode=self.misc_config.logprobs_mode,
+                guided_decoding_manager=self.guided_decoding_manager,
+            )
             origin_logits = logits
             logits, raw_logprobs = await logits_processor(origin_logits)
             next_token_ids = logits_processor.sampling(logits)
@@ -856,6 +860,8 @@ def stop(self):
         if not self._preprocess_task.done():
             self._preprocess_task.cancel()
 
+        self.guided_decoding_manager.clear()
+
     async def stop_async(self):
         """Stop task."""
         if self.dist_ctx.dp > 1:
@@ -884,6 +890,8 @@ async def stop_async(self):
             except asyncio.CancelledError:
                 logger.debug('ModelAgent preprocess task cancelled.')
 
+        self.guided_decoding_manager.clear()
+
     def set_forward_inputs(self, inputs):
         """Set forward inputs."""
         assert self._pre_in_que is not None, ('Please start backendground task before forward.')
diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py
index 4598e02011..dd8fd0c3e6 100644
--- a/lmdeploy/pytorch/strategies/ar/model_agent.py
+++ b/lmdeploy/pytorch/strategies/ar/model_agent.py
@@ -75,10 +75,6 @@ def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids:
         if all_ids is not None:
             sampling_inputs.all_ids = torch.cat([all_ids, next_token_ids[:, None]], 1)
 
-        guided_input_ids = sampling_inputs.guided_input_ids
-        if guided_input_ids is not None:
-            sampling_inputs.guided_input_ids = torch.cat([guided_input_ids, next_token_ids[:, None]], 1)
-
         return sampling_inputs
 
     def make_stopping_criteria(self, seqs: SeqList) -> ARStoppingCriteria:
diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py
index b2516f091a..50818d75e8 100644
--- a/lmdeploy/pytorch/strategies/ar/sampling.py
+++ b/lmdeploy/pytorch/strategies/ar/sampling.py
@@ -27,22 +27,6 @@ def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs)
     return output
 
 
-def _gather_guided_input_ids(pad_id: int, seqs: SeqList, sampling_inputs: 'SamplingInputs'):
-    """Gather input ids for guided decode."""
-    if not any(sampling_inputs.response_formats or ()):
-        return None
-    batch = len(seqs)
-    max_len = max(seq.num_new_tokens for seq in seqs)
-    output = torch.full((batch, max_len), pad_id, dtype=torch.int64)
-    for idx, seq in enumerate(seqs):
-        h_len = seq.num_new_tokens
-        if h_len == 0:
-            continue
-        h_ids = torch.from_numpy(seq.generated_ids)
-        output[idx, -h_len:] = h_ids
-    return output
-
-
 def _get_num_ignore_eos(seqs: SeqList):
     """Get num ignore eos."""
     ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs]
@@ -55,6 +39,7 @@ class ARSamplingStrategy(SamplingStrategy):
     def __init__(self, pad_token_id: int) -> None:
         pad_token_id = 0 if pad_token_id is None else pad_token_id
         self.pad_token_id = pad_token_id
+        self.session_to_cleanup = []
 
     def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
         """Create sampling inputs from the sequences."""
@@ -71,6 +56,8 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
         response_formats = [None] * batch_size
         logits_processors = [None] * batch_size
         num_logprobs = [None] * batch_size
+        session_to_cleanup = self.session_to_cleanup
+        self.session_to_cleanup = []
 
         def __gather_params():
             """Gather params."""
@@ -164,6 +151,11 @@ def __get_bad_words(bad_words):
 
         max_num_logprobs = max(num_logprobs)
 
+        session_ctx = [{
+            'session_id': seq.session.session_id,
+            'seq_id': seq.seq_id,
+        } for seq in seqs]
+
         sampling_input = SamplingInputs(
             temperature=temperature,
             bad_words=bad_words,
@@ -182,10 +174,14 @@ def __get_bad_words(bad_words):
             logits_processors=logits_processors,
             max_num_logprobs=max_num_logprobs,
             batch_size=batch_size,
+            session_ctx=session_ctx,
+            session_to_cleanup=session_to_cleanup,
         )
 
         pad_token_id = self.pad_token_id
         sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input)
-        sampling_input.guided_input_ids = _gather_guided_input_ids(pad_token_id, seqs, sampling_input)
         sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs)
         return sampling_input
+
+    def on_session_end(self, session_id: int):
+        self.session_to_cleanup.append(session_id)
diff --git a/lmdeploy/pytorch/strategies/base/sampling.py b/lmdeploy/pytorch/strategies/base/sampling.py
index 172454157b..f69e5af17e 100644
--- a/lmdeploy/pytorch/strategies/base/sampling.py
+++ b/lmdeploy/pytorch/strategies/base/sampling.py
@@ -15,3 +15,8 @@ class SamplingStrategy(ABC):
     def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
         """Create sampling inputs from the sequences."""
         pass
+
+    @abstractmethod
+    def on_session_end(self, session_id: int) -> None:
+        """Invoked on session ends."""
+        pass
diff --git a/lmdeploy/pytorch/strategies/dllm/sampling.py b/lmdeploy/pytorch/strategies/dllm/sampling.py
index 2ad5d5ecd7..45048e25a5 100644
--- a/lmdeploy/pytorch/strategies/dllm/sampling.py
+++ b/lmdeploy/pytorch/strategies/dllm/sampling.py
@@ -35,7 +35,6 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
             'random_seeds',
             'random_offsets',
             'all_ids',
-            'guided_input_ids',
             'num_ignore_eos',
         ]
         for name in update_attr_names:
diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt
index 984ebdc166..8e037ef521 100644
--- a/requirements/runtime_ascend.txt
+++ b/requirements/runtime_ascend.txt
@@ -7,7 +7,6 @@ mmengine-lite
 numpy
 openai
 openai_harmony
-outlines<0.1.0
 partial_json_parser
 peft<=0.11.1
 pillow
@@ -24,3 +23,4 @@ torch-npu>=2.3.1,<2.8.0
 torchvision>=0.18.1,<0.23.0
 transformers
 uvicorn
+xgrammar
diff --git a/requirements/runtime_camb.txt b/requirements/runtime_camb.txt
index 4ba6ef8462..5b37b003c0 100644
--- a/requirements/runtime_camb.txt
+++ b/requirements/runtime_camb.txt
@@ -6,7 +6,6 @@ mmengine-lite
 numpy
 openai
 openai_harmony
-outlines<0.1.0
 partial_json_parser
 peft<=0.11.1
 pillow
@@ -21,3 +20,4 @@ torch<=2.6.0,>=2.4.0
 torchvision<=0.21.0,>=0.15.0
 transformers
 uvicorn
+xgrammar
diff --git a/requirements/runtime_cuda.txt b/requirements/runtime_cuda.txt
index 2e0309062d..f7ac027ee5 100644
--- a/requirements/runtime_cuda.txt
+++ b/requirements/runtime_cuda.txt
@@ -7,7 +7,6 @@ mmengine-lite
 numpy
 openai
 openai_harmony
-outlines<0.1.0
 partial_json_parser
 peft<=0.14.0
 pillow
@@ -26,3 +25,4 @@ torchvision<=0.23.0,>=0.15.0
 transformers
 triton<=3.4.0,>=3.0.0; sys_platform == "linux"
 uvicorn
+xgrammar
diff --git a/requirements/runtime_maca.txt b/requirements/runtime_maca.txt
index 19a016cbed..70202d5ce5 100644
--- a/requirements/runtime_maca.txt
+++ b/requirements/runtime_maca.txt
@@ -6,7 +6,6 @@ mmengine-lite
 numpy
 openai
 openai_harmony
-outlines<0.1.0
 partial_json_parser
 peft<=0.11.1
 pillow
@@ -22,3 +21,4 @@ torchvision<=0.21.0,>=0.15.0
 transformers
 triton>=2.1.0; sys_platform == "linux"
 uvicorn
+xgrammar
diff --git a/requirements/runtime_rocm.txt b/requirements/runtime_rocm.txt
index 094ca30314..47d6f66fcd 100644
--- a/requirements/runtime_rocm.txt
+++ b/requirements/runtime_rocm.txt
@@ -6,7 +6,6 @@ mmengine-lite
 numpy
 openai
 openai_harmony
-outlines<0.1.0
 partial_json_parser
 peft<=0.14.0
 pillow
@@ -20,3 +19,4 @@ shortuuid
 tiktoken
 transformers
 uvicorn
+xgrammar
diff --git a/tests/test_lmdeploy/test_grammar.py b/tests/test_lmdeploy/test_grammar.py
index 438d22f6c0..e45b4f1a42 100644
--- a/tests/test_lmdeploy/test_grammar.py
+++ b/tests/test_lmdeploy/test_grammar.py
@@ -4,7 +4,7 @@
 from jsonschema import validate
 
 from lmdeploy import pipeline
-from lmdeploy.messages import GenerationConfig, TurbomindEngineConfig  # , PytorchEngineConfig
+from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig
 
 MODEL_IDS = [
     'Qwen/Qwen3-0.6B',
@@ -13,7 +13,7 @@
 
 BACKEND_FACTORIES = [
     ('tm', lambda: TurbomindEngineConfig(max_batch_size=2, session_len=1024)),
-    # ('pt', lambda: PytorchEngineConfig(max_batch_size=1, session_len=1024)),
+    ('pt', lambda: PytorchEngineConfig(max_batch_size=1, session_len=1024)),
 ]
 
 GUIDE_SCHEMA = {