Commit 51f0b5f

[Bugfix] Clean up and fix multi-modal processors (vllm-project#13012)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent fde7126 commit 51f0b5f

7 files changed: +124 -154 lines changed

docs/source/features/compatibility_matrix.md

+1 -1
@@ -297,7 +297,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardware
 *
 *
 * ?
-* [](gh-issue:7968>)
+* [](gh-issue:7968)
 * ?
 *
 *

tests/models/decoder_only/language/test_models.py

+10
@@ -26,6 +26,9 @@
         "google/gemma-1.1-2b-it",  # gemma
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
+    pytest.param(
+        "THUDM/chatglm3-6b",  # ChatGLM (text-only)
+    ),
     pytest.param(
         "meta-llama/Llama-3.2-1B-Instruct",  # llama
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
@@ -43,6 +46,9 @@
         "microsoft/phi-2",  # phi
         marks=[pytest.mark.core_model],
     ),
+    pytest.param(
+        "Qwen/Qwen-7B",  # qwen (text-only)
+    ),
     pytest.param(
         "Qwen/Qwen2.5-0.5B-Instruct",  # qwen2
         marks=[pytest.mark.core_model],
@@ -68,6 +74,10 @@ def test_models(
 ) -> None:

     with hf_runner(model, dtype=dtype) as hf_model:
+        if model.startswith("THUDM/chatglm3"):
+            hf_model.model.get_output_embeddings = lambda: \
+                hf_model.model.transformer.output_layer
+
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)

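Note on the new ChatGLM branch above: the greedy-logprobs helper appears to read the LM head via get_output_embeddings(), which the remote-code ChatGLM3 implementation does not expose, so the test points that accessor at transformer.output_layer. A minimal standalone sketch of the same patch, assuming the checkpoint loads through AutoModelForCausalLM with trust_remote_code:

from transformers import AutoModelForCausalLM

# Assumption: THUDM/chatglm3-6b keeps its LM head at transformer.output_layer
# instead of exposing it through get_output_embeddings().
model = AutoModelForCausalLM.from_pretrained("THUDM/chatglm3-6b",
                                             trust_remote_code=True)
model.get_output_embeddings = lambda: model.transformer.output_layer

# The patched accessor now returns the projection that produces the logits.
assert model.get_output_embeddings() is model.transformer.output_layer
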
tests/models/multimodal/processing/test_common.py

+1 -1
@@ -89,7 +89,7 @@ def _test_processing_correctness(
         mm_data = {
             k:
             [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
-             for _ in range(rng.randint(limit))]
+             for _ in range(rng.randint(limit + 1))]
             for k, limit in limit_mm_per_prompt.items()
         }

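The `limit + 1` matters because NumPy's randint treats its upper bound as exclusive, so the old code could never draw the configured maximum number of items per modality. A quick sketch (not part of the commit):

import numpy as np

rng = np.random.RandomState(0)
limit = 3  # e.g. limit_mm_per_prompt allows up to 3 items of a modality

draws_old = {rng.randint(limit) for _ in range(1000)}      # can never produce 3
draws_new = {rng.randint(limit + 1) for _ in range(1000)}  # covers 0..3 inclusive

assert max(draws_old) < limit
assert max(draws_new) <= limit
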
tests/multimodal/utils.py

-3
@@ -17,10 +17,7 @@ def random_video(
     min_wh: int,
     max_wh: int,
 ):
-    # Temporary workaround for https://github.com/huggingface/transformers/issues/35412
     num_frames = rng.randint(min_frames, max_frames)
-    num_frames = (num_frames // 2) * 2
-
     w, h = rng.randint(min_wh, max_wh, size=(2, ))
     return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)

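For reference, a small usage sketch of the simplified helper now that the even-frame workaround for the linked transformers issue has been dropped (the import path is an assumption):

import numpy as np

from tests.multimodal.utils import random_video  # assumed import path

rng = np.random.RandomState(0)
video = random_video(rng, min_frames=4, max_frames=16, min_wh=32, max_wh=64)

# rng.randint's upper bounds are exclusive, so 4 <= num_frames < 16 and
# 32 <= width, height < 64; odd frame counts are possible again.
assert video.ndim == 4 and video.shape[-1] == 3 and video.dtype == np.uint8
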
vllm/model_executor/models/chatglm.py

+63 -97
@@ -4,8 +4,8 @@
 # https://github.com/THUDM/CogAgent
 """Inference-only CogAgent model compatible with THUDM weights."""
 from argparse import Namespace
-from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple,
-                    TypedDict, Union)
+from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
+                    Union)

 import torch
 from torch import nn
@@ -19,7 +19,6 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -37,12 +36,10 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
-from vllm.multimodal.parse import ImageSize, MultiModalDataItems
+from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, BatchFeature,
-                                        BoundPromptReplacement,
                                         MultiModalFieldConfig,
-                                        PlaceholderFeaturesInfo,
                                         PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -53,39 +50,6 @@
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)

-logger = init_logger(__name__)
-
-IMAGE_TOKEN_ID = 151329
-
-
-def build_normalization_transform(image_size: int) -> transforms.Compose:
-    """
-    Build a normalization transform which can be applied to one or
-    more input images from which we want to extract visual features.
-
-    Args:
-        image_size: size of the image to be processed for visual embeddings.
-
-    Returns:
-        Callable transform for normalizing and resizing one RGB image.
-    """
-
-    return transforms.Compose([
-        transforms.Resize(
-            (image_size, image_size),
-            interpolation=InterpolationMode.BICUBIC,
-        ),
-        transforms.ToTensor(),
-        transforms.Normalize(
-            (0.48145466, 0.4578275, 0.40821073),
-            (0.26862954, 0.26130258, 0.27577711),
-        ),
-    ])
-
-
-def calculate_image_placeholder(vision_config):
-    return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2
-

 class GLMImagePixelInputs(TypedDict):
     pixel_values: torch.Tensor
@@ -109,9 +73,20 @@ def __init__(
         self.config = config
         self.tokenizer = tokenizer

-        if hasattr(self.config, "vision_config"):
-            self.image_transform = build_normalization_transform(
-                config.vision_config["image_size"])
+        if vision_config := getattr(config, "vision_config", None):
+            image_size = vision_config["image_size"]
+
+            self.image_transform = transforms.Compose([
+                transforms.Resize(
+                    (image_size, image_size),
+                    interpolation=InterpolationMode.BICUBIC,
+                ),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=(0.48145466, 0.4578275, 0.40821073),
+                    std=(0.26862954, 0.26130258, 0.27577711),
+                ),
+            ])
         else:
             self.image_transform = None

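The inlined transform is standard torchvision preprocessing: bicubic resize to a square, conversion to a CHW float tensor, then CLIP-style normalization. A minimal sketch of what it produces, using a hypothetical 224-pixel size in place of config.vision_config["image_size"]:

from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode

image_size = 224  # assumption; the real value comes from the vision config
image_transform = transforms.Compose([
    transforms.Resize((image_size, image_size),
                      interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])

# Any RGB image is resized and normalized to a fixed (3, image_size, image_size) tensor.
pixel_values = image_transform(Image.new("RGB", (640, 480)))
assert pixel_values.shape == (3, image_size, image_size)
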
@@ -150,9 +125,19 @@ def __call__(

 class GLM4VProcessingInfo(BaseProcessingInfo):

-    def __init__(self, ctx):
-        super().__init__(ctx)
-        self._pre_calculate()
+    def get_tokenizer(self):
+        tokenizer = self.ctx.tokenizer
+        assert isinstance(tokenizer, PreTrainedTokenizer)
+        return tokenizer
+
+    def get_hf_config(self):
+        return self.ctx.get_hf_config(ChatGLMConfig)
+
+    def get_hf_processor(self) -> GLM4VProcessor:
+        return GLM4VProcessor(
+            self.get_hf_config(),
+            self.get_tokenizer(),
+        )

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": 1}
@@ -162,27 +147,21 @@ def get_mm_max_tokens_per_item(
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
+        return {"image": self.get_num_image_feature_tokens()}

-        return {"image": self.image_token_num + 2}
-
-    def _pre_calculate(self):
+    def get_num_image_tokens(self) -> int:
         hf_config = self.get_hf_config()
-        vision_config = hf_config.vision_config
-        self.image_token_num = calculate_image_placeholder(vision_config)
-        self.image_size = vision_config["image_size"]
+        if not (vision_config := getattr(hf_config, "vision_config", None)):
+            return 0

-    def get_num_image_tokens(self) -> int:
-        return self.image_token_num + 2
+        image_size = vision_config["image_size"]
+        patch_size = vision_config["patch_size"]
+        grid_length = image_size // patch_size // 2
+        return grid_length * grid_length

-    def get_image_size(self) -> ImageSize:
-
-        return ImageSize(height=self.image_size, width=self.image_size)
-
-    def get_hf_processor(self) -> GLM4VProcessor:
-        return GLM4VProcessor(
-            self.get_hf_config(),
-            self.get_tokenizer(),
-        )
+    def get_num_image_feature_tokens(self) -> int:
+        # EVA2CLIPModel has embeddings for boi and eoi tokens as well
+        return self.get_num_image_tokens() + 2


 class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
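A worked example of the new token accounting, assuming the vision_config values shipped with THUDM/glm-4v-9b (image_size 1120, patch_size 14); substitute the real values if the checkpoint differs:

# Assumption: glm-4v-9b ships vision_config with image_size=1120, patch_size=14.
image_size = 1120
patch_size = 14

grid_length = image_size // patch_size // 2      # 1120 // 14 // 2 = 40
num_image_tokens = grid_length * grid_length     # 40 * 40 = 1600
num_image_feature_tokens = num_image_tokens + 2  # plus boi/eoi embeddings

assert (num_image_tokens, num_image_feature_tokens) == (1600, 1602)
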
@@ -192,18 +171,24 @@ def get_dummy_processor_inputs(
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
+        hf_config = self.info.get_hf_config()
+        if not (vision_config := getattr(hf_config, "vision_config", None)):
+            return ProcessorInputs(prompt_text="", mm_data={})
+
+        target_width = target_height = vision_config["image_size"]
         num_images = mm_counts.get("image", 0)
-        target_width, target_height = self.info.get_image_size()

         mm_data = {
             "image":
             self._get_dummy_images(width=target_width,
                                    height=target_height,
                                    num_images=num_images)
         }
-        text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
+
+        base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
+
         return ProcessorInputs(
-            prompt_text=text,
+            prompt_text=base_text * num_images,
             mm_data=mm_data,
         )

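The profiling prompt now repeats the placeholder trio once per requested image instead of emitting it a single time; a tiny sketch:

base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
num_images = 2  # e.g. mm_counts = {"image": 2}

# One boi/endoftext/eoi trio per dummy image in the profiling prompt.
prompt_text = base_text * num_images
assert prompt_text.count("<|begin_of_image|>") == num_images
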
@@ -223,47 +208,28 @@ def _get_prompt_replacements(
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
+        hf_config = self.info.get_hf_config()
+        if not hasattr(hf_config, "vision_config"):
+            return []
+
+        boi_token_id = hf_config.boi_token_id
+        image_token_id = hf_config.pad_token_id
+        eoi_token_id = hf_config.eoi_token_id

         def get_replacement(item_idx: int):
-            image_tokens = self.info.image_token_num
-            return [IMAGE_TOKEN_ID] * image_tokens
+            num_image_tokens = self.info.get_num_image_tokens()
+            image_tokens = [image_token_id] * num_image_tokens
+
+            return [boi_token_id] + image_tokens + [eoi_token_id]

         return [
             PromptReplacement(
                 modality="image",
-                target=[IMAGE_TOKEN_ID],
+                target=[boi_token_id, image_token_id, eoi_token_id],
                 replacement=get_replacement,
             ),
         ]

-    def _apply_prompt_replacements(
-        self,
-        token_ids: list[int],
-        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
-        mm_item_counts: Mapping[str, int],
-    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
-        token_ids, text, placeholders = super()._apply_prompt_replacements(
-            token_ids=token_ids,
-            mm_prompt_repls=mm_prompt_repls,
-            mm_item_counts=mm_item_counts,
-        )
-        hf_config = self.info.get_hf_config()
-        boi_token_id = hf_config.boi_token_id
-        eoi_token_id = hf_config.eoi_token_id
-        placeholders = {
-            modality: [
-                PlaceholderFeaturesInfo(
-                    modality=p.modality,
-                    item_idx=p.item_idx,
-                    start_idx=p.start_idx - 1,
-                    tokens=[boi_token_id] + p.tokens + [eoi_token_id],
-                ) for p in ps
-            ]
-            for modality, ps in placeholders.items()
-        }
-
-        return token_ids, text, placeholders
-

 class GLMAttention(nn.Module):

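A minimal sketch of what the new target/replacement pair does, with made-up token IDs standing in for hf_config.boi_token_id, pad_token_id and eoi_token_id and a small token count. Because boi/eoi are now part of the replacement itself, the old _apply_prompt_replacements override that re-attached them afterwards is no longer needed:

# Hypothetical IDs and count for illustration only; the processor reads the real
# values from the checkpoint config and get_num_image_tokens().
BOI, IMG, EOI = 901, 902, 903
NUM_IMAGE_TOKENS = 4

def expand_image_placeholder(token_ids: list[int]) -> list[int]:
    """Replace each [BOI, IMG, EOI] triple with BOI + IMG * N + EOI,
    mirroring the registered target/replacement pair."""
    out: list[int] = []
    i = 0
    while i < len(token_ids):
        if token_ids[i:i + 3] == [BOI, IMG, EOI]:
            out += [BOI] + [IMG] * NUM_IMAGE_TOKENS + [EOI]
            i += 3
        else:
            out.append(token_ids[i])
            i += 1
    return out

assert expand_image_placeholder([1, BOI, IMG, EOI, 2]) == [1, BOI] + [IMG] * 4 + [EOI, 2]
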
@@ -618,7 +584,7 @@ def get_input_embeddings(
                 multimodal_embeddings=multimodal_embeddings,
                 placeholder_token_id=[
                     self.config.boi_token_id,
-                    IMAGE_TOKEN_ID,
+                    self.config.pad_token_id,
                     self.config.eoi_token_id,
                 ],
             )
