 from QEfficient.utils.logging_utils import logger
 
 
+class QEffInternEncoderWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, pixel_values):
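+        # extract_feature is InternVL's own helper (vision tower, pixel shuffle and mlp1 projector),
+        # so this wrapper covers the full image-to-embedding path.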
+        vit_embeds = self.model.extract_feature(pixel_values)
+        return vit_embeds
+
+
+class QEffInternDecoderWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.config = self.model.language_model.config
+
+    def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
+        # TODO: Check if hardcoding this is okay, i.e. check if this value is common for all Intern models
+        IMG_CONTEXT_TOKEN = 151667
+
+        input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
+        B, N, C = input_embeds.shape
+        image_input_embeds = input_embeds.reshape(B * N, C)
+        image_input_ids = input_ids.reshape(B * N)
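+        # Build a gather index into the flattened vit_embeds: a running count over the IMG_CONTEXT_TOKEN
+        # mask, then keep the gathered image feature only at placeholder positions.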
+        selected = image_input_ids == IMG_CONTEXT_TOKEN
+        indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
+        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
+        image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
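+        # During decode (single-token inputs) there are no image placeholders, so keep the plain text embeddings.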
+        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
+        outputs = self.model.language_model(
+            inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+        )
+        return outputs.logits, vit_embeds, outputs.past_key_values
+
+
 class QEffInternVLModel(nn.Module):
+    def get_qeff_vision_encoder(self):
+        return QEffInternEncoderWrapper(self)
+
+    def get_qeff_language_decoder(self):
+        return QEffInternDecoderWrapper(self)
+
     def get_specializations(
-        self, batch_size: int, prefill_seq_len: int, ctx_len: int, img_size: int, **compiler_options
+        self,
+        batch_size: int,
+        prefill_seq_len: int,
+        ctx_len: int,
+        img_size: int,
+        kv_offload: bool = False,
+        **compiler_options,
     ):
         # TODO: check if this should be named num_patches or something else
         num_patches = compiler_options.pop("num_patches", None)
@@ -33,8 +81,18 @@ def get_specializations(
         elif img_size is None:
             img_size = 448
             logger.warning("Setting img_size to be 448, as it was neither passed nor found in vision_config")
-
-        specializations = [
+        if img_size != 448 and kv_offload:
+            raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")
+        vision = [
+            {
+                "batch_size": batch_size,
+                "num_patches": num_patches,
+                "img_size": img_size,
+                "seq_len": prefill_seq_len,
+                "ctx_len": ctx_len,
+            }
+        ]
+        lang = [
             {
                 "batch_size": batch_size,
                 "seq_len": prefill_seq_len,
@@ -50,61 +108,92 @@ def get_specializations(
                 "img_size": img_size,
             },
         ]
-        return specializations, compiler_options
 
-    def get_onnx_dynamic_axes(
-        self,
-    ):
+        specializations = {}
+
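+        # With kv_offload the vision encoder and language decoder are exported as separate models,
+        # so the specializations are returned per sub-model.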
+        if kv_offload:
+            specializations["vision"] = vision
+            specializations["lang"] = lang
+            return specializations, compiler_options
+        else:
+            return lang, compiler_options
+
+    def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         # Define dynamic axes
-        dynamic_axes = {}
-        dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
-        dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
+        vision_dynamic_axes = {}
+        lang_dynamic_axes = {}
+        lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
+        lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
+        vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
 
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
-                dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
+                lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
 
+        dynamic_axes = {}
+        if kv_offload:
+            dynamic_axes["vision"] = vision_dynamic_axes
+            dynamic_axes["lang"] = lang_dynamic_axes
+        else:
+            dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes}
         return dynamic_axes
 
-    def get_output_names(
-        self,
-    ):
-        output_names = ["logits", "pixel_values_RetainedState"]
+    def get_output_names(self, kv_offload: bool = False):
+        vision_output_names = ["vit_embeds"]
+        lang_output_names = ["logits"]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
-                output_names.append(f"past_{kv}.{i}_RetainedState")
+                lang_output_names.append(f"past_{kv}.{i}_RetainedState")
+
+        output_names = {}
+        if kv_offload:
+            lang_output_names.insert(1, "vit_embeds_RetainedState")
+            output_names["vision"] = vision_output_names
+            output_names["lang"] = lang_output_names
+        else:
+            lang_output_names.insert(1, "pixel_values_RetainedState")
+            return lang_output_names
         return output_names
 
     def get_dummy_inputs(self, kv_offload: bool = False):
-        if kv_offload:
-            raise ValueError("kv_offload method not supported for InternVL yet!")
         num_patches = 13
         C = 3
         if vis_cfg := getattr(self.config, "vision_config", None):
             img_size = getattr(vis_cfg, "image_size", 448)
         else:
             img_size = 448
+        if img_size != 448 and kv_offload:
+            raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")
+
+        # Taken from the modeling files of OpenGVLab/InternVL2_5-1B
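+        # feature_size is the number of visual tokens per image tile fed to the language model
+        # (e.g. 256 when the vision hidden_size is 1024 and downsample_ratio is 0.5).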
+        feature_size = int((((self.config.vision_config.hidden_size**0.5) * self.config.downsample_ratio) ** 2))
 
         # Define shapes
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
+        inputs_shapes["vit_embeds"] = (
+            num_patches,
+            feature_size,
+            self.language_model.config.hidden_size,
+        )
         inputs_shapes["position_ids"] = (
             constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
         inputs_shapes["pixel_values"] = (num_patches, C, img_size, img_size)
 
         # Define inputs
-        inputs = {}
-        inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
-        inputs["position_ids"] = (
+        vision_inputs = {}
+        lang_inputs = {}
+        vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
+        lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
+        lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32)
+        lang_inputs["position_ids"] = (
             torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
             .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
             .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
         )
-        inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
 
         # Add data for KV
         kv_cache_shape = get_padding_shape_from_config(
@@ -113,10 +202,18 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
 
-        inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
+        lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
-                inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+
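+        # The single-QPC path keeps pixel_values as a model input and drops the intermediate vit_embeds.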
+        inputs = {}
+        if kv_offload:
+            inputs["vision"] = vision_inputs
+            inputs["lang"] = lang_inputs
+        else:
+            lang_inputs.pop("vit_embeds")
+            inputs = {**vision_inputs, **lang_inputs}
 
         return inputs
 
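A minimal sketch of how the new kv_offload flag is meant to be exercised (illustrative only, not part of the commit; `model` is assumed to be an InternVL instance already patched with QEffInternVLModel):

    vision_encoder = model.get_qeff_vision_encoder()    # wraps extract_feature for the vision export
    lang_decoder = model.get_qeff_language_decoder()    # wraps the language model plus the embedding merge

    # With kv_offload=True each helper returns a dict keyed by "vision" and "lang", so the two
    # wrappers can be exported and compiled separately; with kv_offload=False the existing
    # single-model path (pixel_values fed directly to the decoder) is kept.
    specializations, compiler_options = model.get_specializations(
        batch_size=1, prefill_seq_len=128, ctx_len=4096, img_size=448, kv_offload=True
    )
    dynamic_axes = model.get_onnx_dynamic_axes(kv_offload=True)
    output_names = model.get_output_names(kv_offload=True)
    dummy_inputs = model.get_dummy_inputs(kv_offload=True)
    assert set(dummy_inputs) == {"vision", "lang"}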