@@ -20,8 +20,8 @@ def __init__(self, model):
         self.model = model
 
     def forward(self, pixel_values):
-        vit_embeds = self.model.extract_feature(pixel_values)
-        return vit_embeds
+        vision_embeds = self.model.extract_feature(pixel_values)
+        return vision_embeds
 
 
 class QEffInternDecoderWrapper(nn.Module):
@@ -31,21 +31,21 @@ def __init__(self, model):
         self.config = self.model.language_model.config
         self.language_model = self.model.language_model
 
-    def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
+    def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
         input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
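         # Build gather indices: the running count of IMG_CONTEXT tokens maps each image
         # placeholder onto the corresponding row of the flattened vision embeddings
         # (an export-friendly alternative to boolean-mask assignment).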
         selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
-        image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
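         # Single-token (decode) steps carry no image placeholders, so keep the plain text embeddings.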
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
-        return outputs.logits, vit_embeds, outputs.past_key_values
+        return outputs.logits, vision_embeds, outputs.past_key_values
 
 
 class QEffInternVLModel(nn.Module):
@@ -122,7 +122,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vit_embeds"] = {0: "num_patches"}
+        lang_dynamic_axes["vision_embeds"] = {0: "num_patches"}
         vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
 
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
@@ -139,15 +139,15 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         return dynamic_axes
 
     def get_output_names(self, kv_offload: bool = False):
-        vision_output_names = ["vit_embeds"]
+        vision_output_names = ["vision_embeds"]
         lang_output_names = ["logits"]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
                 lang_output_names.append(f"past_{kv}.{i}_RetainedState")
 
         output_names = {}
         if kv_offload:
-            lang_output_names.insert(1, "vit_embeds_RetainedState")
+            lang_output_names.insert(1, "vision_embeds_RetainedState")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
@@ -175,7 +175,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         # Define shapes
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
-        inputs_shapes["vit_embeds"] = (
+        inputs_shapes["vision_embeds"] = (
             constants.INTERN_NUM_PATCHES,
             constants.INTERN_FEATURE_SIZE,
             self.language_model.config.hidden_size,
@@ -196,7 +196,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         lang_inputs = {}
         vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
         lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
-        lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32)
+        lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
         lang_inputs["position_ids"] = (
             torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
             .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
@@ -220,21 +220,21 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             inputs["vision"] = vision_inputs
             inputs["lang"] = lang_inputs
         else:
-            lang_inputs.pop("vit_embeds")
+            lang_inputs.pop("vision_embeds")
             inputs = {**vision_inputs, **lang_inputs}
 
         return inputs
 
     def forward(self, input_ids, pixel_values, position_ids, past_key_values):
         input_embeds = self.language_model.get_input_embeddings()(input_ids)
-        vit_embeds = self.extract_feature(pixel_values)
+        vision_embeds = self.extract_feature(pixel_values)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
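         # Same IMG_CONTEXT gather as in QEffInternDecoderWrapper.forward above: splice the
         # vision embeddings into the image-token positions of the input embeddings.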
         selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
-        image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         outputs = self.language_model(