Skip to content

Commit c48d689

Browse files
committed
adding Compute Context Length (CCL) specialization
Signed-off-by: Vahid Janfaza <[email protected]>
1 parent 6a5f283 commit c48d689

File tree

4 files changed

+36
-15
lines changed

4 files changed

+36
-15
lines changed

QEfficient/generation/text_generation_inference.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,11 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
808808
if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths[prefill_ccl_id]:
809809
prefill_ccl_id += 1
810810
if prefill_ccl_id >= self.prefill_ccl_len:
811-
prefill_ccl_id = self.prefill_ccl_len - 1
811+
prefill_ccl_id = (
812+
(self.prefill_ccl_len - 1)
813+
if self.prefill_ccl_len != 0
814+
else min(prefill_ccl_id, len(self.comp_ctx_lengths) - 1)
815+
)
812816
inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[prefill_ccl_id]
813817

814818
chunk_inputs = inputs.copy()
@@ -839,7 +843,6 @@ def initialize_ccl(self, decode_inputs):
839843
ccl_id = i
840844
break
841845

842-
print(f"Decode CCL: {self.comp_ctx_lengths[ccl_id]}")
843846
return ccl_id, max_ccl_id
844847

845848
def run_continuous_batching_decode(self, prompt_queue, generation_len):

QEfficient/transformers/models/modeling_auto.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,11 @@ def kv_offload_generate(
862862
if (i + 1) * prefill_seq_len > self.comp_ctx_lengths[prefill_ccl_id]:
863863
prefill_ccl_id += 1
864864
if prefill_ccl_id >= self.prefill_ccl_len:
865-
prefill_ccl_id = self.prefill_ccl_len - 1
865+
prefill_ccl_id = (
866+
(self.prefill_ccl_len - 1)
867+
if self.prefill_ccl_len != 0
868+
else min(prefill_ccl_id, len(self.comp_ctx_lengths) - 1)
869+
)
866870
chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[prefill_ccl_id]
867871

868872
chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
@@ -1196,7 +1200,11 @@ def cloud_ai_100_generate(
11961200
if (i + 1) * prefill_seq_len > self.comp_ctx_lengths[prefill_ccl_id]:
11971201
prefill_ccl_id += 1
11981202
if prefill_ccl_id >= self.prefill_ccl_len:
1199-
prefill_ccl_id = self.prefill_ccl_len - 1
1203+
prefill_ccl_id = (
1204+
(self.prefill_ccl_len - 1)
1205+
if self.prefill_ccl_len != 0
1206+
else min(prefill_ccl_id, len(self.comp_ctx_lengths) - 1)
1207+
)
12001208
chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[prefill_ccl_id]
12011209

12021210
chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
@@ -1784,8 +1792,8 @@ def build_decode_specialization(
17841792
full_batch_size: Optional[int] = None,
17851793
num_speculative_tokens: Optional[int] = None,
17861794
):
1787-
if prefill_seq_len == 1 and not self.continuous_batching:
1788-
return None # Avoid duplication with prefill
1795+
if prefill_seq_len == 1 and not self.continuous_batching and not self.comp_ctx_lengths:
1796+
return None # Avoid duplication with prefill in non-CCL
17891797
spec = {
17901798
"batch_size": full_batch_size if self.continuous_batching else batch_size,
17911799
"seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1,
@@ -1908,6 +1916,8 @@ def compile(
19081916
specializations = []
19091917
if prefill_only is None or prefill_only or prefill_seq_len == 1:
19101918
if self.comp_ctx_lengths is not None:
1919+
if prefill_seq_len != 1 and self.prefill_ccl_len == 0:
1920+
raise ValueError("When prefill_seq_len > 1, prefill_ccl_len must be greater than 0.")
19111921
# Adding elements from self.comp_ctx_lengths to prefill_specialization
19121922
for i in range(0, self.prefill_ccl_len):
19131923
specializations.append(

examples/granite_example/ccl_granitemoe_inference.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818

1919
comp_ctx_lengths = [256, 512, 1024, 2048] # None
2020

21-
## Prefill_ccl_len shows how many numbers in the comp_ctx_lengths list is related to prefilling and the rest would be for decoding. The default value is 1.
22-
prefill_ccl_len = 2
21+
"""
22+
# Prefill_ccl_len shows how many values in the comp_ctx_lengths list are used for prefilling; the rest are used for decoding. The default value of 1 means the first value is for prefilling and the rest are for decoding.
23+
# In MoE models with prefill_seq_len=1, we can pass prefill_ccl_len=0 to use all CCL values for both the prefilling and decoding steps.
24+
"""
25+
prefill_ccl_len = 0
2326

2427
model = QEFFAutoModelForCausalLM.from_pretrained(
2528
model_name, comp_ctx_lengths=comp_ctx_lengths, prefill_ccl_len=prefill_ccl_len, continuous_batching=False

examples/qwen3moe_example/ccl_qwen3moe_inference.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,28 @@
1515
# For CB inference, set continuous_batching to True and add full_batch_size,mxfp6,mint8 argument in compile function
1616
# We will use prompt_len=1 for compilation for both cb and non-cb inference
1717
"""
18+
ctx_len = 1024
19+
batch_size = 1
20+
comp_ctx_lengths = [128, 256, 512, 1024]
1821

19-
comp_ctx_lengths = [192, 256, 512, 1024] # None
20-
21-
## Prefill_ccl_len shows how many numbers in the comp_ctx_lengths list is related to prefilling and the rest would be for decoding. The default value is 1.
22-
prefill_ccl_len = 2
22+
"""
23+
# Prefill_ccl_len shows how many values in the comp_ctx_lengths list are used for prefilling; the rest are used for decoding. The default value of 1 means the first value is for prefilling and the rest are for decoding.
24+
# In MoE models with prefill_seq_len=1, we can pass prefill_ccl_len=0 to use all CCL values for both the prefilling and decoding steps.
25+
"""
26+
prefill_ccl_len = 0
2327

2428
model = QEFFAutoModelForCausalLM.from_pretrained(
25-
model_name, comp_ctx_lengths=comp_ctx_lengths, prefill_ccl_len=prefill_ccl_len, continuous_batching=True
29+
model_name, comp_ctx_lengths=comp_ctx_lengths, prefill_ccl_len=prefill_ccl_len
2630
)
2731
model.compile(
2832
prefill_seq_len=1,
29-
ctx_len=1024,
30-
full_batch_size=2,
33+
ctx_len=ctx_len,
34+
batch_size=batch_size,
3135
num_cores=16,
3236
num_devices=4,
3337
mxfp6_matmul=True,
3438
mxint8_kv_cache=True,
39+
mos=1,
3540
)
3641
tokenizer = AutoTokenizer.from_pretrained(model_name)
3742
exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer)

0 commit comments

Comments
 (0)