
Commit 322b133

Add vision id for Qwen3-VL (#4183)

* add vision id
* add test case
* format
* fix typo
* add for tm, simplify qwen3 vl proc_msg
* tiny rename req_state

1 parent 51cbd2c · commit 322b133

25 files changed: +160 −55 lines

lmdeploy/messages.py

Lines changed: 1 addition & 1 deletion

@@ -586,7 +586,7 @@ class EngineOutput:
 
 
 @dataclass
 class VisionConfig:
-    """Vison model configs.
+    """Vision model configs.
 
     Args:
         max_batch_size (int): the max image size passed to the model, since

lmdeploy/model.py

Lines changed: 3 additions & 0 deletions

@@ -773,11 +773,13 @@ def messages2prompt(self, messages, sequence_start=True, **kwargs):
             kwargs.pop('enable_thinking')
         if 'reasoning_effort' in kwargs and kwargs.get('reasoning_effort', None) is None:
             kwargs.pop('reasoning_effort')
+        add_vision_id = kwargs.pop('add_vision_id', False)
         add_generation_prompt = messages[-1]['role'] != 'assistant'
         if sequence_start:
             prompt = self.tokenizer.apply_chat_template(messages,
                                                         tokenize=False,
                                                         add_generation_prompt=add_generation_prompt,
+                                                        add_vision_id=add_vision_id,
                                                         **kwargs)
         else:
             # Use a sentinel position to avoid the influence of default system role in the tokenizer's chat template

@@ -788,6 +790,7 @@ def messages2prompt(self, messages, sequence_start=True, **kwargs):
             prompt = self.tokenizer.apply_chat_template(sentinel_messages + messages,
                                                         tokenize=False,
                                                         add_generation_prompt=add_generation_prompt,
+                                                        add_vision_id=add_vision_id,
                                                         **kwargs)
             # remove the sentinel part
             prompt = prompt[len(sentinel_prompt):]
lmdeploy/serve/async_engine.py

Lines changed: 4 additions & 2 deletions

@@ -768,6 +768,7 @@ async def generate(
             rewind_stop_tokens: bool = False,
             input_ids: Optional[List] = None,
             enable_thinking: Optional[bool] = None,
+            add_vision_id: Optional[bool] = False,
             **kwargs):
         """Generate responses.
 

@@ -819,6 +820,7 @@ async def generate(
            tools=tools,
            reasoning_effort=reasoning_effort,
            enable_thinking=enable_thinking,
+           add_vision_id=add_vision_id,
            **kwargs)
        prompt = prompt_input['prompt']
        input_ids = prompt_input['input_ids']

@@ -889,12 +891,12 @@ def is_error(status):
                sequence_end=sequence_end,
                step=history_len) as gen:
            hit_stop_token = 0
-           req_state = RequestStats(prompt_tokens=input_len)  # per-request stats
+           req_stats = RequestStats(prompt_tokens=input_len)  # per-request stats
            async for outputs in gen:
                iteration_stats = IterationStats()  # per-iteration stats
                specdecode_stats = SpeculativeDecodingStats(
                    self.num_spec_token) if self.num_spec_token > 0 else None
-               metrics_processor.queue_update((outputs, req_state, iteration_stats, specdecode_stats))
+               metrics_processor.queue_update((outputs, req_stats, iteration_stats, specdecode_stats))
                # decode res
                if is_error(outputs.status):
                    break
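
The new keyword is accepted directly by the engine's generate coroutine, so callers that drive it themselves can pass the flag per request. A hedged sketch under stated assumptions: the model id is a placeholder, the message layout is illustrative, and streaming/session handling is simplified:

import asyncio

from lmdeploy import GenerationConfig, pipeline


async def main():
    # Placeholder model id; pipeline() returns the (VL)AsyncEngine whose generate() gained the flag.
    pipe = pipeline('Qwen/Qwen3-VL-30B-A3B-Instruct')
    messages = [{
        'role': 'user',
        'content': [
            {'type': 'image_url', 'image_url': {'url': 'https://example.com/a.jpg'}},
            {'type': 'image_url', 'image_url': {'url': 'https://example.com/b.jpg'}},
            {'type': 'text', 'text': 'Compare the two pictures.'},
        ],
    }]
    # add_vision_id is forwarded through _get_prompt_input down to the chat template.
    async for out in pipe.generate(messages,
                                   session_id=0,
                                   gen_config=GenerationConfig(max_new_tokens=256),
                                   add_vision_id=True):
        print(out.response, end='', flush=True)


asyncio.run(main())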

lmdeploy/serve/openai/api_server.py

Lines changed: 1 addition & 0 deletions

@@ -469,6 +469,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
        do_preprocess=do_preprocess,
        adapter_name=adapter_name,
        enable_thinking=request.enable_thinking,
+       add_vision_id=request.add_vision_id,
    )


def create_stream_response_json(index: int,

lmdeploy/serve/openai/protocol.py

Lines changed: 1 addition & 0 deletions

@@ -150,6 +150,7 @@ class ChatCompletionRequest(BaseModel):
    min_new_tokens: Optional[int] = Field(default=None, examples=[None])
    min_p: float = 0.0
    enable_thinking: Optional[bool] = None
+   add_vision_id: Optional[bool] = False
    return_token_ids: Optional[bool] = False
    include_stop_str_in_output: Optional[bool] = False
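
Since add_vision_id is now a field on ChatCompletionRequest and is forwarded by chat_completions_v1, an OpenAI-compatible client can opt in per request. A usage sketch; the endpoint, image URLs, and served model name below are placeholders:

from openai import OpenAI

client = OpenAI(base_url='http://0.0.0.0:23333/v1', api_key='none')  # placeholder endpoint
resp = client.chat.completions.create(
    model='Qwen3-VL',  # placeholder: whatever model name `lmdeploy serve api_server` exposes
    messages=[{
        'role': 'user',
        'content': [
            {'type': 'image_url', 'image_url': {'url': 'https://example.com/a.jpg'}},
            {'type': 'image_url', 'image_url': {'url': 'https://example.com/b.jpg'}},
            {'type': 'text', 'text': 'Describe the difference between the two pictures.'},
        ],
    }],
    extra_body={'add_vision_id': True},  # the field added to ChatCompletionRequest in this commit
)
print(resp.choices[0].message.content)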

lmdeploy/serve/vl_async_engine.py

Lines changed: 7 additions & 2 deletions

@@ -56,6 +56,7 @@ async def _get_prompt_input(self,
                                 adapter_name: str,
                                 tools: Optional[List[object]] = None,
                                 enable_thinking: Optional[bool] = None,
+                                add_vision_id: Optional[bool] = False,
                                 **kwargs):
        """Process messages and return the required data for the inference
        engines.

@@ -70,6 +71,7 @@ async def _get_prompt_input(self,
                adapter_name,
                tools=tools,
                enable_thinking=enable_thinking,
+               add_vision_id=add_vision_id,
                **kwargs)
        elif isinstance(messages, List):
            has_multimodal_input = any(

@@ -82,6 +84,7 @@ async def _get_prompt_input(self,
                adapter_name,
                tools,
                enable_thinking=enable_thinking,
+               add_vision_id=add_vision_id,
                **kwargs)
        else:
            raise RuntimeError(f'unsupported messages {messages}')

@@ -101,7 +104,8 @@ async def _get_prompt_input(self,
                self.tokenizer,
                sequence_start,
                tools=tools,
-               enable_thinking=enable_thinking)
+               enable_thinking=enable_thinking,
+               add_vision_id=add_vision_id)
        elif self.backend == 'pytorch':
            # for pt engine, this module only conduct the image preprocessing
            # It leaves the vision embedding to the pt engine

@@ -110,7 +114,8 @@ async def _get_prompt_input(self,
                self.tokenizer,
                sequence_start,
                tools=tools,
-               enable_thinking=enable_thinking)
+               enable_thinking=enable_thinking,
+               add_vision_id=add_vision_id)
        return results

    @classmethod

lmdeploy/vl/engine.py

Lines changed: 6 additions & 2 deletions

@@ -69,6 +69,7 @@ async def wrap_for_pytorch(
        sequence_start,
        tools: Optional[List[object]] = None,
        enable_thinking: Optional[bool] = None,
+       add_vision_id: Optional[bool] = False,
    ) -> List[Dict]:
        """
        Args:

@@ -93,7 +94,8 @@ async def wrap_for_pytorch(
                tokenizer,
                sequence_start,
                tools=tools,
-               enable_thinking=enable_thinking)
+               enable_thinking=enable_thinking,
+               add_vision_id=add_vision_id)
        else:
            result = self.model.to_pytorch_with_input_ids(messages)
        # clear data

@@ -110,6 +112,7 @@ async def wrap_for_turbomind(
        sequence_start,
        tools: Optional[List[object]] = None,
        enable_thinking: Optional[bool] = None,
+       add_vision_id: Optional[bool] = False,
    ) -> Dict:
        """
        Args:

@@ -130,7 +133,8 @@ async def wrap_for_turbomind(
                tokenizer,
                sequence_start,
                tools=tools,
-               enable_thinking=enable_thinking)
+               enable_thinking=enable_thinking,
+               add_vision_id=add_vision_id)
        # clear data
        for i, message in enumerate(messages):
            if isinstance(message['content'], List):

lmdeploy/vl/model/base.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 VISION_MODELS = Registry('vision_model')
 
 
-class VisonModel(ABC):
+class VisionModel(ABC):
     """Visual model which extract image feature."""
     _arch: Union[str, List[str]] = None
 
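
With the base class renamed from VisonModel to VisionModel, every subclass and import site is updated accordingly (see cogvlm.py and deepseek.py below). A skeleton of how a vision model registers against the corrected name; the class and _arch value here are hypothetical, and the abstract preprocessing/forward hooks real subclasses implement are omitted:

from lmdeploy.vl.model.base import VISION_MODELS, VisionModel


@VISION_MODELS.register_module()
class MyVisionModel(VisionModel):  # hypothetical example, not part of the commit
    """Example vision model registered under the corrected base-class name."""

    _arch = 'MyVLForCausalLM'  # made-up architecture name used to select this class for a checkpoint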

lmdeploy/vl/model/cogvlm.py

Lines changed: 2 additions & 2 deletions

@@ -2,13 +2,13 @@
 from typing import Dict, List
 
 from lmdeploy.utils import get_logger
-from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
+from lmdeploy.vl.model.base import VISION_MODELS, VisionModel
 
 logger = get_logger('lmdeploy')
 
 
 @VISION_MODELS.register_module()
-class CogVLMVisionModel(VisonModel):
+class CogVLMVisionModel(VisionModel):
     """CogVLM vision model."""
 
     _arch = 'CogVLMForCausalLM'
lmdeploy/vl/model/deepseek.py

Lines changed: 2 additions & 2 deletions

@@ -6,7 +6,7 @@
 from transformers import AutoModelForCausalLM
 
 from lmdeploy.utils import get_logger
-from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
+from lmdeploy.vl.model.base import VISION_MODELS, VisionModel
 from lmdeploy.vl.model.utils import disable_logging
 
 logger = get_logger('lmdeploy')

@@ -23,7 +23,7 @@ def check_deepseek_vl_install():
 
 
 @VISION_MODELS.register_module()
-class DeepSeekVisionModel(VisonModel):
+class DeepSeekVisionModel(VisionModel):
     """Qwen vision model."""
 
     _arch = 'MultiModalityCausalLM'