17 changes: 17 additions & 0 deletions sendnn_inference/envs.py
@@ -25,6 +25,8 @@
SENDNN_INFERENCE_REQUIRE_KNOWN_CONFIG: bool = False
SENDNN_INFERENCE_MODEL_CONFIG_FILE: str | None = None
SENDNN_INFERENCE_CPU_MM_DTYPE: torch.dtype = torch.float16
SENDNN_INFERENCE_MAX_TKV_SHIFT_RATIO: float = 1.5
SENDNN_INFERENCE_MAX_SKIP_COUNT: int = 4

logger = init_logger(__name__)

@@ -152,6 +154,21 @@ def clear_env_cache():
_CPU_MM_DTYPE_PLATFORM_DEFAULTS.get(platform.machine(), "float16"),
)
),
# Chunked-prefill scheduling: soft admission gate on decode-tkv shift.
# When a new prefill candidate would move the decode-batch tkv by more
# than this ratio (new_tkv / current_tkv), the candidate is skipped in
# favor of the next one in the waiting queue. Set to a large value
# (e.g. math.inf) to disable and restore strict FIFO admission.
"SENDNN_INFERENCE_MAX_TKV_SHIFT_RATIO": lambda: float(
os.getenv("SENDNN_INFERENCE_MAX_TKV_SHIFT_RATIO", "1.5")
),
# Chunked-prefill scheduling: anti-starvation bound for the shift-ratio
# gate above. A waiting request that has been skipped this many
# scheduling cycles is force-admitted on the next prefill slot
# regardless of tkv shift.
"SENDNN_INFERENCE_MAX_SKIP_COUNT": lambda: int(
os.getenv("SENDNN_INFERENCE_MAX_SKIP_COUNT", "4")
),
}
# --8<-- [end:env-vars-definition]
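
A small worked example of how these two knobs interact, using illustrative numbers (the 64-token block size and tkv values are assumptions, not taken from this diff):

# decode-batch tkv = 200   -> block-aligned current_tkv = 256
# candidate prompt = 500   -> block-aligned new_tkv     = 512
# ratio = 512 / 256 = 2.0  -> above the 1.5 default, so the candidate is skipped
# After SENDNN_INFERENCE_MAX_SKIP_COUNT such skips (default 4) the request is
# admitted regardless of the ratio; setting the ratio to "inf" disables the
# gate and restores strict FIFO admission.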

100 changes: 96 additions & 4 deletions sendnn_inference/v1/core/scheduler.py
@@ -207,6 +207,14 @@ def __init__(self, *args, **kwargs) -> None:
"Expecting the env var VLLM_DT_MAX_BATCH_TKV_LIMIT to be set in platform.py"
)

# Soft admission gate: skip a new prefill candidate if it would push
# decode tkv by more than this ratio, unless it has aged out.
self.max_tkv_shift_ratio: float = envs_spyre.SENDNN_INFERENCE_MAX_TKV_SHIFT_RATIO
self.max_skip_count: int = envs_spyre.SENDNN_INFERENCE_MAX_SKIP_COUNT
# Per-request skip counter. Incremented when a candidate is passed over by the
# shift-ratio gate. Removed on admission or request completion.
self._skip_counts: dict[str, int] = {}

def update_from_output(self, scheduler_output, model_runner_output):
assert isinstance(model_runner_output, SpyreModelRunnerOutput), (
"Expecting an instance of CPSpyreModelRunnerOutput when doing chunked prefill."
@@ -273,10 +281,44 @@ def schedule(self) -> "SchedulerOutput":
while self.skipped_waiting:
holdback_queue.append(self.skipped_waiting.pop_request())

# Check if new requests can be scheduled for prefill.
# The shift-ratio soft gate may skip candidates that would push
# decode tkv too far; skipped requests are set aside here and
# restored to holdback after admission.
skipped_for_shift: deque[Request] = deque()
if len(holdback_queue) > 0:
n_decoders = sum(
1 for r in self.running if r not in self.ongoing_prefills
)
logger.info(
"[tkv-shift-gate] cycle holdback_depth=%d n_decoders=%d tkv=%d",
len(holdback_queue),
n_decoders,
self.tkv,
)
while holdback_queue:
candidate = holdback_queue[0]

if not self._within_tkv_shift_budget(candidate):
skipped = holdback_queue.popleft()
self._skip_counts[skipped.request_id] = (
self._skip_counts.get(skipped.request_id, 0) + 1
)
skipped_for_shift.append(skipped)
logger.info(
"[tkv-shift-gate] skipping req_id=%s prompt_tokens=%d "
"skip_count=%d/%d tkv=%d",
skipped.request_id,
skipped.num_prompt_tokens,
self._skip_counts[skipped.request_id],
self.max_skip_count,
self.tkv,
)
continue

if self.can_schedule_prefill(candidate):
new_request = holdback_queue.popleft()
self._skip_counts.pop(new_request.request_id, None)

logger.debug(
"Scheduling a new request (%d prompt tokens), holding back %d requests",
@@ -287,10 +329,15 @@
# Add request to the waiting queue
self.waiting.append(new_request)
else:
# Hard constraint failure — stop scanning and let the
# scheduler work with the batch we have
break

# Restore soft-skipped candidates to the front of holdback,
# preserving their original priority order.
while skipped_for_shift:
holdback_queue.appendleft(skipped_for_shift.pop())
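
For illustration, one pass of this loop over a three-request holdback queue (request ids and sizes below are made up, not from this diff):

# holdback = [A(big), B(small), C(small)], tkv = 200, ratio threshold = 1.5
# A: shift ratio 2.0 > 1.5  -> popleft, _skip_counts[A] = 1, set aside
# B: within budget, fits    -> popleft, admitted to self.waiting
# C: hard constraint fails  -> break out of the loop
# restore: A is pushed back to the front, so holdback = [A, C] and A keeps
# its priority when it is retried on the next scheduling cycle.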

assert len(self.ongoing_prefills) <= 1, (
"Only one request can be prefilled at a time, but got %d" % len(self.ongoing_prefills)
)
@@ -398,6 +445,44 @@ def can_schedule_prefill(self, request: Request) -> bool:

return self._satisfies_constraints(request)

def _within_tkv_shift_budget(self, request: Request) -> bool:
"""Soft admission gate: return False if admitting ``request`` would
push decode-batch tkv by more than ``max_tkv_shift_ratio`` while
decoders are running. Force-admit (return True) once the request
has been skipped ``max_skip_count`` times to prevent starvation.
"""
# Ongoing prefills are past the admission decision.
if request in self.ongoing_prefills:
return True

decoding_requests = [r for r in self.running if r not in self.ongoing_prefills]
# No decodes to protect, or no decode tkv yet.
if not decoding_requests or self.tkv <= 0:
return True

# Anti-starvation override.
if self._skip_counts.get(request.request_id, 0) >= self.max_skip_count:
return True

# Compare block-aligned tkvs
current_tkv = round_up_to_block_size(self.tkv)
new_tkv = round_up_to_block_size(max(self.tkv, request.num_prompt_tokens))
ratio = new_tkv / current_tkv
allowed = ratio <= self.max_tkv_shift_ratio
logger.info(
"[tkv-shift-gate] check req_id=%s prompt_tokens=%d tkv=%d "
"current_tkv=%d new_tkv=%d ratio=%.3f threshold=%.3f allowed=%s",
request.request_id,
request.num_prompt_tokens,
self.tkv,
current_tkv,
new_tkv,
ratio,
self.max_tkv_shift_ratio,
allowed,
)
return allowed
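
The block-aligned ratio math above can be sanity-checked with a standalone sketch; it covers only the ratio check (not the early returns), and the rounding helper is an assumption standing in for round_up_to_block_size, not part of this diff:

def _round_up(n: int, block: int = 64) -> int:
    # Round n up to the next multiple of the block size.
    return ((n + block - 1) // block) * block

def _within_budget(tkv: int, prompt_tokens: int, max_ratio: float = 1.5) -> bool:
    current_tkv = _round_up(tkv)
    new_tkv = _round_up(max(tkv, prompt_tokens))
    return new_tkv / current_tkv <= max_ratio

# _within_budget(200, 220) -> True  (256 / 256 = 1.0)
# _within_budget(200, 500) -> False (512 / 256 = 2.0)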

def _satisfies_constraints(self, request: Request) -> bool:
# Use a local variable to check the prefix cache hit length ahead of time without mutating
# request.num_computed_tokens
@@ -600,6 +685,13 @@ def finish_requests(
else [r for r in self.ongoing_prefills if r.request_id not in request_ids]
)

# Delete skip counters for finished requests.
if request_ids is None:
self._skip_counts.clear()
else:
for request_id in request_ids:
self._skip_counts.pop(request_id, None)

return aborted_requests

def calc_cached_tokens(self, prompt_len: int) -> tuple[int, int]:
3 changes: 3 additions & 0 deletions tests/v1/core/test_scheduler_structured_outputs.py
@@ -47,6 +47,9 @@ def mocked_scheduler():
scheduler.block_size = 64
scheduler.n_free_blocks = 100
scheduler.max_batch_tkv_limit = "8192"
scheduler.max_tkv_shift_ratio = float("inf")
scheduler.max_skip_count = 0
scheduler._skip_counts = {}

# Mock the base scheduler's schedule method and can_schedule_prefill,
# but ChunkedPrefillSpyreScheduler.schedule uses the code implementation
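
A follow-up test could exercise the gate directly along these lines (hypothetical, not part of this diff; request construction, the make_request helper, and the 64-token block size are assumptions that depend on the existing fixtures):

# scheduler.max_tkv_shift_ratio = 1.5
# scheduler.max_skip_count = 4
# scheduler.tkv = 200
# scheduler.running = [decoding_request]        # one active decode
# scheduler.ongoing_prefills = []
# big = make_request(num_prompt_tokens=500)     # assumed helper
# assert not scheduler._within_tkv_shift_budget(big)
# scheduler._skip_counts[big.request_id] = scheduler.max_skip_count
# assert scheduler._within_tkv_shift_budget(big)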