Commit 2c9e1b4

Author: Mohit Soni (committed)
Testing Changes
Signed-off-by: Mohit Soni <[email protected]>
1 parent 849aee7 commit 2c9e1b4

4 files changed · +160 −48 lines changed


QEfficient/transformers/models/modeling_auto.py

Lines changed: 2 additions & 5 deletions
@@ -818,12 +818,9 @@ def kv_offload_generate(
             in {"pixel_values", "image_masks", "image_input_idx", "valid_idx", "aspect_ratio_ids", "aspect_ratio_mask"}
         }

-        molmo = hasattr(self.model.config, "model_type") and self.model.config.model_type == "molmo"
+        vision_inputs_fp16 = {"pixel_values", "image_masks"}
+        vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs})

-        if vision_inputs:
-            vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16")
-            if molmo:
-                vision_inputs["image_masks"] = vision_inputs["image_masks"].astype("float16")
         vision_start = perf_counter()

         vision_outputs = {}
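For context, the replacement collapses the Molmo-specific branching into a single allow-list-driven cast: any vision input named in vision_inputs_fp16 is converted to float16, and keys the processor did not produce (e.g. image_masks for non-Molmo models) are simply skipped. A minimal standalone sketch of the same keyed-cast pattern, using NumPy arrays with made-up shapes rather than the real vision_inputs dict:

import numpy as np

# Hypothetical stand-in for the vision inputs gathered earlier in kv_offload_generate.
vision_inputs = {
    "pixel_values": np.zeros((1, 5, 588, 1024), dtype=np.float32),
    "image_input_idx": np.zeros((1, 5, 144), dtype=np.int64),  # integer inputs are left untouched
}

# Only the keys listed here are cast to fp16, and only if they are present.
vision_inputs_fp16 = {"pixel_values", "image_masks"}
vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs})

print({k: v.dtype for k, v in vision_inputs.items()})
# {'pixel_values': dtype('float16'), 'image_input_idx': dtype('int64')}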

QEfficient/transformers/models/molmo/modeling_molmo.py

Lines changed: 16 additions & 42 deletions
@@ -235,10 +235,7 @@ def forward(
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         if not self.config.norm_after:
-            if self._activation_checkpoint_fn is not None:
-                atten_in = self._activation_checkpoint_fn(self.attn_norm, x)
-            else:
-                atten_in = self.attn_norm(x)
+            atten_in = self.attn_norm(x)
         else:
             atten_in = x
         qkv = self.att_proj(atten_in)
@@ -249,34 +246,19 @@ def forward(
         q, k, v = qkv.split(self.fused_dims, dim=-1)

         # Get attention scores.
-        if self._activation_checkpoint_fn is not None:
-            att, cache = self._activation_checkpoint_fn( # type: ignore
-                self.attention,
-                q,
-                k,
-                v,
-                attention_bias,
-                position_ids=position_ids,
-                layer_past=layer_past,
-                use_cache=use_cache,
-            )
-        else:
-            att, cache = self.attention(
-                q,
-                k,
-                v,
-                attention_bias,
-                position_ids=position_ids,
-                layer_past=layer_past,
-                batch_index=batch_index,
-                use_cache=use_cache,
-            )
+        att, cache = self.attention(
+            q,
+            k,
+            v,
+            attention_bias,
+            position_ids=position_ids,
+            layer_past=layer_past,
+            batch_index=batch_index,
+            use_cache=use_cache,
+        )

         if self.config.norm_after:
-            if self._activation_checkpoint_fn is not None:
-                att = self._activation_checkpoint_fn(self.attn_norm, att)
-            else:
-                att = self.attn_norm(att)
+            att = self.attn_norm(att)

         # Add attention scores.
         # shape: (B, T, C)
@@ -287,23 +269,15 @@ def forward(
         og_x = x

         if not self.config.norm_after:
-            if self._activation_checkpoint_fn is not None:
-                x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
-            else:
-                x = self.ff_norm(x)
+            x = self.ff_norm(x)

         x = self.ff_proj(x)
-        if self._activation_checkpoint_fn is not None:
-            x = self._activation_checkpoint_fn(self.act, x) # type: ignore
-        else:
-            x = self.act(x)
+
+        x = self.act(x)
         x = self.ff_out(x)

         if self.config.norm_after:
-            if self._activation_checkpoint_fn is not None:
-                x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
-            else:
-                x = self.ff_norm(x)
+            x = self.ff_norm(x)

         x = self.dropout(x)
         x = og_x + x

QEfficient/utils/run_utils.py

Lines changed: 33 additions & 0 deletions
@@ -439,3 +439,36 @@ def run_vlm_hf_model_on_pytorch(self, model, inputs, generation_config):
         print("Original HF Model Outputs (Torch CPU):")
         print("Completion:", repr(py_output))
         return generated_ids
+
+
+class ApiRunnerMolmo(ApiRunnerVlm):
+    """
+    ApiRunner for Molmo models:
+    ---------
+
+    1. HuggingFace ``PyTorch`` model
+    2. Transformed KV Pytorch Model
+    3. ``ONNX`` model on ONNXRT
+    4. ``ONNX`` model on Cloud AI 100
+    """
+
+    def __init__(self, batch_size, processor, config, image, prompt, prompt_len, ctx_len, max_gen_len, n_layer):
+        self.processor = processor
+        self.ctx_len = ctx_len
+        self.prompt_len = prompt_len
+        self.batch_size = batch_size
+        self.config = config
+        self.gen_len = max_gen_len
+
+    @torch.no_grad()
+    def run_vlm_hf_model_on_pytorch(self, model, inputs, generation_config):
+        outputs = model.generate_from_batch(
+            inputs, generation_config, tokenizer=self.processor.tokenizer, do_sample=False
+        )
+
+        generated_ids = outputs[0, inputs["input_ids"].size(1) :]
+
+        py_output = self.processor.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
+        print("Original HF Model Outputs (Torch CPU):")
+        print("Completion:", repr(py_output))
+        return generated_ids
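For orientation, the sketch below shows how the new runner is driven. It mirrors the Molmo test added later in this commit and is not self-contained: processor, config, image, query, and the loaded HF model (model_hf) are assumed to have been prepared beforehand, exactly as in that test.

from transformers import GenerationConfig

from QEfficient.utils.run_utils import ApiRunnerMolmo

# Sketch only: processor, config, image, query, and model_hf are assumed to exist.
api_runner = ApiRunnerMolmo(
    batch_size=1,
    processor=processor,
    config=config,
    image=image,
    prompt=query,
    prompt_len=128,
    ctx_len=4096,
    max_gen_len=10,
    n_layer=(2, 2),
)

inputs = processor.process(images=[image], text=query)
inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}  # add a batch dimension

generation_config = GenerationConfig(max_new_tokens=10, stop_strings="<|endoftext|>")
hf_tokens = api_runner.run_vlm_hf_model_on_pytorch(model_hf, inputs, generation_config)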

tests/transformers/models/test_image_text_to_text_models.py

Lines changed: 109 additions & 1 deletion
@@ -19,6 +19,7 @@
     AutoModelForImageTextToText,
     AutoProcessor,
     AutoTokenizer,
+    GenerationConfig,
     TextStreamer,
 )

@@ -27,7 +28,7 @@
 from QEfficient.utils._utils import create_json, get_num_layers_vlm
 from QEfficient.utils.constants import QnnConstants
 from QEfficient.utils.device_utils import get_available_device_id
-from QEfficient.utils.run_utils import ApiRunnerInternVL, ApiRunnerVlm
+from QEfficient.utils.run_utils import ApiRunnerInternVL, ApiRunnerMolmo, ApiRunnerVlm
 from QEfficient.utils.test_utils import InternProcessor

 NEW_GENERATION_TOKENS = 10
@@ -146,6 +147,19 @@
     # ), # commented becuase QNN Convertor is not supported for this model yet.
 ]

+molmo_model_config = [
+    (
+        "allenai/Molmo-7B-D-0924",
+        True,
+        1,
+        128,
+        4096,
+        "https://picsum.photos/id/237/536/354",
+        "Can you describe the image in detail.",
+        2,
+    ),
+]
+

 def load_image_text_to_text_model(model_config):
     model_path = hf_download(
@@ -185,6 +199,8 @@ def set_num_layers(config, n_layer=1):
     elif hasattr(config, "llm_config"):
         config.llm_config.num_hidden_layers = n_layer
         config.vision_config.num_hidden_layers = n_layer
+    else:
+        config.num_hidden_layers = n_layer
     return config


@@ -276,6 +292,77 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     return


+def check_molmo_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
+    model_name: str,
+    img_url: str,
+    query: str,
+    prompt_len: int,
+    ctx_len: int,
+    max_gen_len: int = 20,
+    batch_size: int = 1,
+    n_layer: int = 1,
+    kv_offload: bool = False,
+    num_devices: int = 1,
+    enable_qnn: Optional[bool] = False,
+    qnn_config: Optional[str] = None,
+):
+    model_config = {"model_name": model_name}
+
+    config = AutoConfig.from_pretrained(model_config["model_name"], trust_remote_code=True)
+    config._attn_implementation = "eager"
+    config = set_num_layers(config, n_layer=n_layer)
+    model_hf, _ = load_image_text_to_text_model(config)
+    n_layer = (n_layer, n_layer)
+
+    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True, padding=True)
+    img = requests.get(img_url, stream=True)
+    image = Image.open(BytesIO(img.content)).convert("RGB")
+    image = image.resize((536, 354))
+
+    api_runner = ApiRunnerMolmo(
+        batch_size,
+        processor,
+        config,
+        image,
+        query,
+        prompt_len,
+        ctx_len,
+        max_gen_len,
+        n_layer,
+    )
+
+    inputs = processor.process(images=[image], text=query)
+    inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
+
+    generation_config = GenerationConfig(max_new_tokens=NEW_GENERATION_TOKENS, stop_strings="<|endoftext|>")
+    pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch(model_hf, inputs, generation_config)
+
+    batch_size, prompt_len = inputs["input_ids"].shape
+    inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64)
+    valid = inputs["image_input_idx"] > 0
+    valid = valid.reshape(1, -1)
+    inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0)
+    inputs["pixel_values"] = inputs.pop("images")
+
+    qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
+        model_config["model_name"],
+        kv_offload=kv_offload,
+        config=config,
+    )
+
+    streamer = TextStreamer(processor.tokenizer)
+    qeff_model.export()
+
+    if not get_available_device_id():
+        pytest.skip("No available devices to run model on Cloud AI 100")
+    qeff_model.compile(num_devices=num_devices, prefill_seq_len=prompt_len, ctx_len=ctx_len, mxfp6=False)
+    print("QPC Outputs (QAIC):")
+    output = qeff_model.generate(inputs=inputs, generation_len=NEW_GENERATION_TOKENS, streamer=streamer)
+    qpc_tokens = output.generated_ids[:, :-1]
+    assert (pytorch_hf_tokens == qpc_tokens).all(), "Tokens don't match for pytorch HF output and QPC output"
+    return
+
+
 def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     model_name: str,
     img_url: str,
@@ -427,6 +514,27 @@ def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_qnn(
     )


+@pytest.mark.on_qaic
+@pytest.mark.multimodal
+@pytest.mark.parametrize(
+    "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, query, n_layer", molmo_model_config
+)
+def test_image_text_to_text_molmo_pytorch_vs_kv_vs_ort_vs_ai100(
+    model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, query, n_layer
+):
+    check_molmo_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
+        model_name=model_name,
+        prompt_len=prompt_len,
+        ctx_len=ctx_len,
+        max_gen_len=NEW_GENERATION_TOKENS,
+        img_url=img_url,
+        query=query,
+        n_layer=n_layer,
+        batch_size=batch_size,
+        kv_offload=kv_offload,
+    )
+
+
 @pytest.mark.on_qaic
 @pytest.mark.multimodal
 @pytest.mark.parametrize(
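One detail worth calling out from check_molmo_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100 is how valid_idx is derived: image_input_idx maps image patch embeddings to positions in the text sequence, and the test treats entries greater than zero as valid patch slots. A toy, self-contained illustration of that preprocessing step, using a made-up index tensor:

import torch

# Made-up image_input_idx: entries > 0 point at real patch positions, the rest are padding.
image_input_idx = torch.tensor([[3, 4, -100, 5, -100, 6]])

valid = image_input_idx > 0                           # boolean mask over patch slots
valid = valid.reshape(1, -1)                          # flatten to (1, num_slots)
valid_idx = torch.nonzero(valid)[:, 1].unsqueeze(0)   # flattened positions of the valid slots

print(valid_idx)  # tensor([[0, 1, 3, 5]])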
