@@ -4,8 +4,8 @@
 # https://github.com/THUDM/CogAgent
 """Inference-only CogAgent model compatible with THUDM weights."""
 from argparse import Namespace
-from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple,
-                    TypedDict, Union)
+from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
+                    Union)
 
 import torch
 from torch import nn
@@ -19,7 +19,6 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -37,12 +36,10 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
-from vllm.multimodal.parse import ImageSize, MultiModalDataItems
+from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, BatchFeature,
-                                        BoundPromptReplacement,
                                         MultiModalFieldConfig,
-                                        PlaceholderFeaturesInfo,
                                         PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -53,39 +50,6 @@
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)
 
-logger = init_logger(__name__)
-
-IMAGE_TOKEN_ID = 151329
-
-
-def build_normalization_transform(image_size: int) -> transforms.Compose:
-    """
-    Build a normalization transform which can be applied to one or
-    more input images from which we want to extract visual features.
-
-    Args:
-        image_size: size of the image to be processed for visual embeddings.
-
-    Returns:
-        Callable transform for normalizing and resizing one RGB image.
-    """
-
-    return transforms.Compose([
-        transforms.Resize(
-            (image_size, image_size),
-            interpolation=InterpolationMode.BICUBIC,
-        ),
-        transforms.ToTensor(),
-        transforms.Normalize(
-            (0.48145466, 0.4578275, 0.40821073),
-            (0.26862954, 0.26130258, 0.27577711),
-        ),
-    ])
-
-
-def calculate_image_placeholder(vision_config):
-    return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2
-
 
 class GLMImagePixelInputs(TypedDict):
     pixel_values: torch.Tensor
@@ -109,9 +73,20 @@ def __init__(
         self.config = config
         self.tokenizer = tokenizer
 
-        if hasattr(self.config, "vision_config"):
-            self.image_transform = build_normalization_transform(
-                config.vision_config["image_size"])
+        if vision_config := getattr(config, "vision_config", None):
+            image_size = vision_config["image_size"]
+
+            self.image_transform = transforms.Compose([
+                transforms.Resize(
+                    (image_size, image_size),
+                    interpolation=InterpolationMode.BICUBIC,
+                ),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=(0.48145466, 0.4578275, 0.40821073),
+                    std=(0.26862954, 0.26130258, 0.27577711),
+                ),
+            ])
         else:
             self.image_transform = None
 
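The transform is now built inline from the checkpoint's vision_config instead of through the removed module-level helper. A minimal usage sketch of the same pipeline, with an illustrative image_size and a synthetic input image (both are assumptions, not values from this file):

from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode

image_size = 224  # illustrative; the real value comes from config.vision_config["image_size"]
image_transform = transforms.Compose([
    transforms.Resize((image_size, image_size),
                      interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])

# Any RGB image works; here a blank one stands in for a real input.
pixel_values = image_transform(Image.new("RGB", (640, 480)))
# pixel_values has shape (3, image_size, image_size), normalized with CLIP statistics.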
@@ -150,9 +125,19 @@ def __call__(
 
 class GLM4VProcessingInfo(BaseProcessingInfo):
 
-    def __init__(self, ctx):
-        super().__init__(ctx)
-        self._pre_calculate()
+    def get_tokenizer(self):
+        tokenizer = self.ctx.tokenizer
+        assert isinstance(tokenizer, PreTrainedTokenizer)
+        return tokenizer
+
+    def get_hf_config(self):
+        return self.ctx.get_hf_config(ChatGLMConfig)
+
+    def get_hf_processor(self) -> GLM4VProcessor:
+        return GLM4VProcessor(
+            self.get_hf_config(),
+            self.get_tokenizer(),
+        )
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": 1}
@@ -162,27 +147,21 @@ def get_mm_max_tokens_per_item(
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
+        return {"image": self.get_num_image_feature_tokens()}
 
-        return {"image": self.image_token_num + 2}
-
-    def _pre_calculate(self):
+    def get_num_image_tokens(self) -> int:
         hf_config = self.get_hf_config()
-        vision_config = hf_config.vision_config
-        self.image_token_num = calculate_image_placeholder(vision_config)
-        self.image_size = vision_config["image_size"]
+        if not (vision_config := getattr(hf_config, "vision_config", None)):
+            return 0
 
-    def get_num_image_tokens(self) -> int:
-        return self.image_token_num + 2
+        image_size = vision_config["image_size"]
+        patch_size = vision_config["patch_size"]
+        grid_length = image_size // patch_size // 2
+        return grid_length * grid_length
 
-    def get_image_size(self) -> ImageSize:
-
-        return ImageSize(height=self.image_size, width=self.image_size)
-
-    def get_hf_processor(self) -> GLM4VProcessor:
-        return GLM4VProcessor(
-            self.get_hf_config(),
-            self.get_tokenizer(),
-        )
+    def get_num_image_feature_tokens(self) -> int:
+        # EVA2CLIPModel has embeddings for boi and eoi tokens as well
+        return self.get_num_image_tokens() + 2
 
 
 class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
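The new helpers replace the removed _pre_calculate bookkeeping with a pure computation: the square image is patchified and the patch grid is halved in each direction, so the placeholder count is (image_size // patch_size // 2) squared, and get_num_image_feature_tokens adds 2 for the boi/eoi embeddings. A worked example with illustrative vision_config values (not taken from this diff):

# Illustrative values only; in practice they come from hf_config.vision_config.
image_size = 1120
patch_size = 14

grid_length = image_size // patch_size // 2        # 1120 // 14 // 2 = 40
num_image_tokens = grid_length * grid_length       # 40 * 40 = 1600
num_image_feature_tokens = num_image_tokens + 2    # + boi/eoi = 1602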
@@ -192,18 +171,24 @@ def get_dummy_processor_inputs(
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
+        hf_config = self.info.get_hf_config()
+        if not (vision_config := getattr(hf_config, "vision_config", None)):
+            return ProcessorInputs(prompt_text="", mm_data={})
+
+        target_width = target_height = vision_config["image_size"]
         num_images = mm_counts.get("image", 0)
-        target_width, target_height = self.info.get_image_size()
 
         mm_data = {
             "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
         }
-        text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
+
+        base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
+
         return ProcessorInputs(
-            prompt_text=text,
+            prompt_text=base_text * num_images,
            mm_data=mm_data,
         )
 
@@ -223,47 +208,28 @@ def _get_prompt_replacements(
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
+        hf_config = self.info.get_hf_config()
+        if not hasattr(hf_config, "vision_config"):
+            return []
+
+        boi_token_id = hf_config.boi_token_id
+        image_token_id = hf_config.pad_token_id
+        eoi_token_id = hf_config.eoi_token_id
 
         def get_replacement(item_idx: int):
-            image_tokens = self.info.image_token_num
-            return [IMAGE_TOKEN_ID] * image_tokens
+            num_image_tokens = self.info.get_num_image_tokens()
+            image_tokens = [image_token_id] * num_image_tokens
+
+            return [boi_token_id] + image_tokens + [eoi_token_id]
 
         return [
             PromptReplacement(
                 modality="image",
-                target=[IMAGE_TOKEN_ID],
+                target=[boi_token_id, image_token_id, eoi_token_id],
                 replacement=get_replacement,
             ),
         ]
 
-    def _apply_prompt_replacements(
-        self,
-        token_ids: list[int],
-        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
-        mm_item_counts: Mapping[str, int],
-    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
-        token_ids, text, placeholders = super()._apply_prompt_replacements(
-            token_ids=token_ids,
-            mm_prompt_repls=mm_prompt_repls,
-            mm_item_counts=mm_item_counts,
-        )
-        hf_config = self.info.get_hf_config()
-        boi_token_id = hf_config.boi_token_id
-        eoi_token_id = hf_config.eoi_token_id
-        placeholders = {
-            modality: [
-                PlaceholderFeaturesInfo(
-                    modality=p.modality,
-                    item_idx=p.item_idx,
-                    start_idx=p.start_idx - 1,
-                    tokens=[boi_token_id] + p.tokens + [eoi_token_id],
-                ) for p in ps
-            ]
-            for modality, ps in placeholders.items()
-        }
-
-        return token_ids, text, placeholders
-
 
 class GLMAttention(nn.Module):
 
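Because the replacement target is now the whole boi/image/eoi triple and the replacement re-emits boi and eoi around the expanded pad tokens, the _apply_prompt_replacements override that previously re-attached those tokens is no longer needed. A standalone sketch of the intended expansion, using made-up token ids rather than the real ChatGLMConfig values, and not the actual vLLM replacement machinery:

# Hypothetical ids for illustration; the real ones are hf_config.boi_token_id,
# hf_config.pad_token_id and hf_config.eoi_token_id.
boi_token_id, image_token_id, eoi_token_id = 101, 102, 103
num_image_tokens = 4  # would come from GLM4VProcessingInfo.get_num_image_tokens()

prompt = [7, boi_token_id, image_token_id, eoi_token_id, 8]

def expand(tokens: list[int]) -> list[int]:
    """Replace each boi/image/eoi triple with boi + N pad tokens + eoi."""
    out, i = [], 0
    while i < len(tokens):
        if tokens[i:i + 3] == [boi_token_id, image_token_id, eoi_token_id]:
            out += [boi_token_id] + [image_token_id] * num_image_tokens + [eoi_token_id]
            i += 3
        else:
            out.append(tokens[i])
            i += 1
    return out

print(expand(prompt))  # [7, 101, 102, 102, 102, 102, 103, 8]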
@@ -618,7 +584,7 @@ def get_input_embeddings(
             multimodal_embeddings=multimodal_embeddings,
             placeholder_token_id=[
                 self.config.boi_token_id,
-                IMAGE_TOKEN_ID,
+                self.config.pad_token_id,
                 self.config.eoi_token_id,
             ],
         )