From 7e4c1a8e0667d4f4db596b6ce103b9c4074ed0a0 Mon Sep 17 00:00:00 2001
From: Dhiraj Kumar Sah
Date: Wed, 10 Sep 2025 08:16:00 +0000
Subject: [PATCH] Update run_vlm_kv_model_on_pytorch and run_vlm_kv_model_on_ort
 for the dual QPC setup

Update the run_vlm_kv_model_on_pytorch and run_vlm_kv_model_on_ort methods to
run with the latest dual QPC setup, along with the corresponding changes to the
VLM input handlers. Also update how head_dim is calculated for past_key_value
creation: certain models now provide an explicit head_dim in their config, so
that value is used when present, falling back to the previous
hidden_size // num_attention_heads computation when the parameter is not found.

Signed-off-by: Dhiraj Kumar Sah
---
 QEfficient/utils/generate_inputs.py | 11 +++++++----
 QEfficient/utils/run_utils.py       | 32 +++++++++++++++++++++++++-------
 2 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py
index 361be3080..eb1f7c8e6 100644
--- a/QEfficient/utils/generate_inputs.py
+++ b/QEfficient/utils/generate_inputs.py
@@ -249,7 +249,7 @@ def prepare_pytorch_inputs(self):
 
         num_hidden_layers = txt_cfg.num_hidden_layers
         num_key_value_heads = txt_cfg.num_key_value_heads
-        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+        head_dim = getattr(txt_cfg, "head_dim", txt_cfg.hidden_size // txt_cfg.num_attention_heads)
 
         if hasattr(txt_cfg, "cross_attention_layers"):
             cross_attention_layers = txt_cfg.cross_attention_layers
@@ -287,7 +287,7 @@ def prepare_vlm_ort_inputs(self):
             txt_cfg = self.config.llm_config
         num_hidden_layers = txt_cfg.num_hidden_layers
         num_key_value_heads = txt_cfg.num_key_value_heads
-        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+        head_dim = getattr(txt_cfg, "head_dim", txt_cfg.hidden_size // txt_cfg.num_attention_heads)
         if hasattr(txt_cfg, "cross_attention_layers"):
             cross_attention_layers = txt_cfg.cross_attention_layers
             vis_cfg = self.config.vision_config
@@ -298,6 +298,7 @@ def prepare_vlm_ort_inputs(self):
         if "attention_mask" in inputs.keys():
             inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
         inputs["past_key_values"] = []
+        inputs["image_idx"] = np.array([[0]])
 
         vision_inputs = {
             k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"}
@@ -349,6 +350,7 @@ def update_vlm_ort_outputs(self, ort_outputs):
         outputs["image_features_RetainedState"] = (
             ort_outputs["image_features_RetainedState"] if "image_features_RetainedState" in ort_outputs else None
         )
+        outputs["image_idx"] = ort_outputs["image_idx_output"]
         return outputs
 
     def update_vlm_ort_inputs(self, inputs, ort_outputs):
@@ -414,7 +416,7 @@ def prepare_pytorch_inputs(self):
 
         num_hidden_layers = txt_cfg.num_hidden_layers
         num_key_value_heads = txt_cfg.num_key_value_heads
-        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+        head_dim = getattr(txt_cfg, "head_dim", txt_cfg.hidden_size // txt_cfg.num_attention_heads)
 
         inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
         inputs["past_key_values"] = []
@@ -435,7 +437,7 @@ def prepare_vlm_ort_inputs(self):
             txt_cfg = self.config.llm_config
         num_hidden_layers = txt_cfg.num_hidden_layers
         num_key_value_heads = txt_cfg.num_key_value_heads
-        head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
+        head_dim = getattr(txt_cfg, "head_dim", txt_cfg.hidden_size // txt_cfg.num_attention_heads)
 
         question = "<image>\n" + self.prompt
         pixel_values = self.processor.load_image(self.image, max_num=12)
@@ -449,6 +451,7 @@ def prepare_vlm_ort_inputs(self):
         if "attention_mask" in inputs.keys():
            inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
         inputs["past_key_values"] = []
+        inputs["image_idx"] = np.array([[0]])
 
         vision_inputs = {
             k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"}
diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py
index aefd511d0..170845e21 100644
--- a/QEfficient/utils/run_utils.py
+++ b/QEfficient/utils/run_utils.py
@@ -129,7 +129,6 @@ def run_kv_model_on_pytorch(self, model):
 
         generated_ids = []
         inputs = self.input_handler.prepare_pytorch_inputs()
-
         pt_outputs = model(**inputs)
         for _ in range(1, self.gen_len):
             generated_ids.append(pt_outputs["logits"].argmax(-1).reshape(-1, 1))
@@ -291,9 +290,11 @@ def run_vlm_kv_model_on_pytorch(self, model):
         generation_len = self.gen_len
         generated_ids = torch.full((self.batch_size, generation_len), self.processor.tokenizer.pad_token_id)
         inputs = self.input_handler_vlm.prepare_pytorch_inputs()
+        inputs["image_idx"] = torch.tensor([[0]])
 
         outputs = model(**inputs)
         inputs["input_ids"] = outputs[0].argmax(2)
+        inputs["image_idx"] = outputs[2]
         if "cross_attention_mask" in inputs:
             bs, _, num_images, img_tiles = inputs["cross_attention_mask"].shape
             inputs["cross_attention_mask"] = torch.ones((bs, 1, num_images, img_tiles), dtype=torch.int64)
@@ -308,6 +309,7 @@ def run_vlm_kv_model_on_pytorch(self, model):
         for num_token in range(1, self.gen_len):
             outputs = model(**inputs)
             inputs["input_ids"] = outputs[0].argmax(2)
+            inputs["image_idx"] = outputs[2]
             inputs["position_ids"] += 1
             streamer.put(inputs["input_ids"])
             generated_ids[:, num_token] = inputs["input_ids"].squeeze(1)
@@ -363,15 +365,23 @@ def run_vlm_kv_model_on_ort(self, model_path):
         added_initializers, decoder_session = self.setup_ort_session(decoder_path)
         generated_ids = []
+        finished_sequences = lang_inputs["input_ids"] == self.processor.tokenizer.eos_token_id
 
         ort_outputs = self.run_ort_session(lang_inputs, session=decoder_session)
         ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs)
 
+        generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+        lang_inputs = self.input_handler_vlm.update_vlm_ort_inputs(lang_inputs, ort_outputs)
+
         for _ in range(1, self.gen_len):
-            generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
-            lang_inputs = self.input_handler_vlm.update_vlm_ort_inputs(lang_inputs, ort_outputs)
+            finished_sequences |= lang_inputs["input_ids"] == self.processor.tokenizer.eos_token_id
+            if finished_sequences.all():
+                break
+
             ort_outputs = self.run_ort_session(lang_inputs, decoder_session)
             ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs)
-            generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+            generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+            lang_inputs = self.input_handler_vlm.update_vlm_ort_inputs(lang_inputs, ort_outputs)
+
         generated_ids = np.concatenate(generated_ids, axis=1)
         predicted_string = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
         print("ORT KV_OFFLOAD Session Outputs:")
@@ -383,14 +393,22 @@ def run_vlm_kv_model_on_ort(self, model_path):
         added_initializers, session = self.setup_ort_session(model_path)
         generated_ids = []
         inputs = {**vision_inputs, **lang_inputs}
+        finished_sequences = inputs["input_ids"] == self.processor.tokenizer.eos_token_id
+
         ort_outputs = self.run_ort_session(inputs, session=session)
         ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs)
+        generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+        inputs = self.input_handler_vlm.update_vlm_ort_inputs(inputs, ort_outputs)
+
         for _ in range(1, self.gen_len):
-            generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
-            inputs = self.input_handler_vlm.update_vlm_ort_inputs(inputs, ort_outputs)
+            finished_sequences |= inputs["input_ids"] == self.processor.tokenizer.eos_token_id
+            if finished_sequences.all():
+                break
             ort_outputs = self.run_ort_session(inputs, session)
             ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs)
-            generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+            generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1))
+            inputs = self.input_handler_vlm.update_vlm_ort_inputs(inputs, ort_outputs)
+
         generated_ids = np.concatenate(generated_ids, axis=1)
         predicted_string = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
         print("ORT Session Outputs:")
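Illustration only, not part of the patch: a minimal standalone sketch of the two
behaviours the changes above rely on, the head_dim fallback and EOS-based early
stopping in the greedy decode loop. The config objects and fake_decode_step
below are stand-ins for the real model config and the ORT session helpers, not
QEfficient APIs.

import numpy as np
from types import SimpleNamespace

# head_dim fallback: prefer an explicit head_dim on the text config when the
# model provides one, otherwise derive it as hidden_size // num_attention_heads.
def resolve_head_dim(txt_cfg):
    return getattr(txt_cfg, "head_dim", txt_cfg.hidden_size // txt_cfg.num_attention_heads)

cfg_explicit = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=256)
cfg_derived = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
assert resolve_head_dim(cfg_explicit) == 256   # explicit value wins
assert resolve_head_dim(cfg_derived) == 128    # 4096 // 32 fallback

# EOS-based early stopping, mirroring the decode loops in run_vlm_kv_model_on_ort
# with the model call stubbed out.
eos_token_id = 2
batch_size, gen_len = 2, 8
rng = np.random.default_rng(0)

def fake_decode_step(input_ids):
    # Stand-in for run_ort_session + update_vlm_ort_outputs/update_vlm_ort_inputs:
    # returns the next-token ids for each sequence in the batch.
    return rng.integers(0, 5, size=(batch_size, 1))

input_ids = np.array([[4], [1]])                # last prompt token per sequence
finished_sequences = input_ids == eos_token_id
generated_ids = []

next_ids = fake_decode_step(input_ids)          # prefill step
generated_ids.append(next_ids)
input_ids = next_ids

for _ in range(1, gen_len):
    finished_sequences |= input_ids == eos_token_id
    if finished_sequences.all():                # every sequence has emitted EOS
        break
    next_ids = fake_decode_step(input_ids)
    generated_ids.append(next_ids)
    input_ids = next_ids

print(np.concatenate(generated_ids, axis=1))    # shape: (batch_size, tokens generated)

The per-sequence boolean mask plays the role of finished_sequences in the patch:
decoding stops as soon as all sequences in the batch have produced eos_token_id,
instead of always running the loop for the full gen_len steps.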