Skip to content

Commit 5f12f0e

Browse files
Authored Dec 1, 2024
Fix chunked prefill when ignore eos (#2290)
1 parent d5b95cb commit 5f12f0e

File tree

2 files changed

+19
-16
lines changed

2 files changed

+19
-16
lines changed
 

‎python/sglang/srt/managers/schedule_policy.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def __init__(
142142

143143
self.req_states = None
144144
self.can_run_list = []
145-
self.new_inflight_req = None
145+
self.new_being_chunked_req = None
146146
self.log_hit_tokens = 0
147147
self.log_input_tokens = 0
148148

@@ -182,7 +182,7 @@ def _prefill_one_req(
182182
self.log_hit_tokens += prefix_len
183183
self.log_input_tokens += extend_input_len
184184

185-
def add_inflight_req(self, req: Req):
185+
def add_being_chunked_req(self, req: Req):
186186
truncated = req.extend_input_len > self.rem_chunk_tokens
187187
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
188188
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -269,10 +269,13 @@ def add_req_state(r, insert_sort=False):
269269
else:
270270
# Chunked prefill
271271
trunc_len = self.rem_chunk_tokens
272+
if trunc_len == 0:
273+
return AddReqResult.OTHER
274+
272275
req.extend_input_len = trunc_len
273276
req.fill_ids = req.fill_ids[:trunc_len]
274277
self.can_run_list.append(req)
275-
self.new_inflight_req = req
278+
self.new_being_chunked_req = req
276279
self._prefill_one_req(0, trunc_len, 0)
277280

278281
return self.budget_state()
@@ -326,7 +329,7 @@ def add_one_req(self, req: Req):
326329
req.extend_input_len = trunc_len
327330
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
328331
self.can_run_list.append(req)
329-
self.new_inflight_req = req
332+
self.new_being_chunked_req = req
330333
self.tree_cache.inc_lock_ref(req.last_node)
331334
self._prefill_one_req(prefix_len, trunc_len, 0)
332335

‎python/sglang/srt/managers/scheduler.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ def handle_embedding_request(
660660

661661
self.waiting_queue.append(req)
662662

663-
def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
663+
def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
664664
if isinstance(self.tree_cache, RadixCache):
665665
self.tree_cache_metrics["total"] += (
666666
adder.log_input_tokens + adder.log_hit_tokens
@@ -684,14 +684,14 @@ def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
684684
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
685685
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
686686
f"#running-req: {running_bs}, "
687-
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
687+
f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
688688
)
689689

690690
if self.enable_metrics:
691691
self.stats.num_running_reqs = running_bs
692692
self.stats.num_used_tokens = num_used
693693
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
694-
self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
694+
self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked
695695
self.stats.cache_hit_rate = tree_cache_hit_rate
696696
self.metrics_collector.log_stats(self.stats)
697697

@@ -752,7 +752,7 @@ def get_next_batch_to_run(self):
752752
# Move the chunked request out of the batch
753753
self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
754754
self.tree_cache.cache_unfinished_req(self.being_chunked_req)
755-
# Inflight request keeps its rid but will get a new req_pool_idx
755+
# being chunked request keeps its rid but will get a new req_pool_idx
756756
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
757757
self.batch_is_full = False
758758

@@ -803,10 +803,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
803803
running_bs if self.is_mixed_chunk else 0,
804804
)
805805

806-
has_inflight = self.being_chunked_req is not None
807-
if has_inflight:
806+
has_being_chunked = self.being_chunked_req is not None
807+
if has_being_chunked:
808808
self.being_chunked_req.init_next_round_input()
809-
self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)
809+
self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)
810810

811811
if self.lora_paths:
812812
lora_set = (
@@ -848,16 +848,16 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
848848
x for x in self.waiting_queue if x not in set(can_run_list)
849849
]
850850

851-
if adder.new_inflight_req is not None:
851+
if adder.new_being_chunked_req is not None:
852852
assert self.being_chunked_req is None
853-
self.being_chunked_req = adder.new_inflight_req
853+
self.being_chunked_req = adder.new_being_chunked_req
854854

855855
if self.being_chunked_req:
856856
self.being_chunked_req.is_being_chunked += 1
857857

858858
# Print stats
859859
if self.tp_rank == 0:
860-
self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)
860+
self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
861861

862862
# Create a new batch
863863
new_batch = ScheduleBatch.init_new(
@@ -1030,7 +1030,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
10301030
if req.grammar is not None:
10311031
req.grammar.accept_token(next_token_id)
10321032
else:
1033-
# Inflight reqs' prefill is not finished
1033+
# being chunked reqs' prefill is not finished
10341034
req.is_being_chunked -= 1
10351035

10361036
if batch.next_batch_sampling_info:
@@ -1058,7 +1058,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
10581058
else:
10591059
self.tree_cache.cache_unfinished_req(req)
10601060
else:
1061-
# Inflight reqs' prefill is not finished
1061+
# being chunked reqs' prefill is not finished
10621062
req.is_being_chunked -= 1
10631063

10641064
self.stream_output(batch.reqs)

0 commit comments

Comments
 (0)
Please sign in to comment.