Skip to content

Commit f66ad3b

Browse files
committed
Inference for OmniSVG
Signed-off-by: Dipankar Sarkar <[email protected]>
1 parent 6aaa75a commit f66ad3b

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import os
2+
3+
import torch
4+
import torch.nn as nn
5+
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, Qwen2_5_VLForConditionalGeneration
6+
7+
from QEfficient import QEFFAutoModelForImageTextToText
8+
9+
10+
class SketchDecoder(nn.Module):
    """Autoregressive generative model built on a Qwen2.5-VL backbone.

    Construction loads ``Qwen/Qwen2.5-VL-3B-Instruct`` with an overridden
    vocabulary size and special-token ids and stores it as
    ``self.transformer``.  The forward pass is intentionally not provided;
    callers are expected to use ``self.transformer`` directly.

    Attributes:
        vocab_size: Size of the token vocabulary used by the backbone.
        bos_token_id: Beginning-of-sequence token id.
        eos_token_id: End-of-sequence token id (the last id in the vocabulary).
        pad_token_id: Padding token id (same value as ``bos_token_id``).
        transformer: The ``Qwen2_5_VLForConditionalGeneration`` instance.
    """

    def __init__(self, **kwargs):
        """Build the backbone with the custom vocabulary configuration.

        Args:
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__()
        self.vocab_size = 196042
        self.bos_token_id = 151643
        self.eos_token_id = 196041
        self.pad_token_id = 151643

        # Single source for the checkpoint name used by both the config and
        # the model load below.
        base_model = "Qwen/Qwen2.5-VL-3B-Instruct"

        config = AutoConfig.from_pretrained(
            base_model,
            vocab_size=self.vocab_size,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
        )

        # The config's vocab size differs from the published checkpoint, so
        # shape mismatches on load are tolerated instead of raising.
        self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            base_model,
            config=config,
            attn_implementation="eager",
            ignore_mismatched_sizes=True,
        )

        # Ensure the embedding matrix has exactly ``vocab_size`` rows.
        self.transformer.resize_token_embeddings(self.vocab_size)

    def forward(self, *args, **kwargs):
        """Always raise; the forward pass is not part of this release.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError("Forward pass not included in open-source version")
38+
39+
40+
# Export the OmniSVG sketch decoder's transformer through QEfficient and
# compile it.  Runs top to bottom as a script.
model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
sketch_decoder = SketchDecoder()

# Directory holding the fine-tuned checkpoint.  Set OMNISVG_WEIGHT_PATH to
# override; the default is the original hard-coded location.
weight_path = os.environ.get("OMNISVG_WEIGHT_PATH", "/home/dipankar/omnisvg/OmniSVG")
sketch_weight_file = os.path.join(weight_path, "pytorch_model.bin")
if not os.path.exists(sketch_weight_file):
    raise FileNotFoundError(f"pytorch_model.bin not found in {weight_path}")

# map_location="cpu" lets a checkpoint saved from an accelerator load on a
# host without that device; load_state_dict copies values into the existing
# parameters either way.
# NOTE(review): torch.load unpickles the file -- only point this at trusted
# checkpoints (consider weights_only=True if the file is a plain state dict).
sketch_decoder.load_state_dict(torch.load(sketch_weight_file, map_location="cpu"))
sketch_decoder.transformer.eval()

qeff_model = QEFFAutoModelForImageTextToText(sketch_decoder.transformer)
qeff_model.export()

# Tokenizer/processor come from the same base checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")

# NOTE(review): compile settings are kept exactly as committed; their
# suitability for a given target (cores, devices, image size) should be
# confirmed against the QEfficient documentation.
path = qeff_model.compile(
    batch_size=1,
    prefill_seq_len=128,
    ctx_len=4096,
    num_cores=16,
    num_devices=8,
    height=354,
    width=536,
    mxfp6_matmul=False,
    aic_enable_depth_first=True,
    skip_vision=True,
    mos=1,
)

0 commit comments

Comments
 (0)