
Introduce async scheduler implementation with mixin pattern#941

Draft
GOavi101 wants to merge 6 commits into torch-spyre:main from GOavi101:feature/async-scheduler-mixin-pattern

Conversation

Collaborator

@GOavi101 GOavi101 commented Apr 21, 2026

Async Scheduling for Spyre Generative Models

Wires Spyre’s chunked-prefill scheduler to vLLM’s upstream AsyncScheduler so async scheduling (run-ahead / batch queue as implemented by the pinned vLLM v1 engine) can be used for generative models. Spyre-specific reconcile and worker changes keep scheduling correct under optimistic num_computed_tokens and related state.


Background

Upstream async scheduling lets the engine overlap scheduling with execution: the scheduler can advance before the previous step’s outputs are fully committed (run-ahead). In v1 this is tied to the engine’s async / batch-queue execution path (see comments in scheduler.py referencing step_with_batch_queue — exact wiring lives in vLLM, not reimplemented here).

Spyre’s ChunkedPrefillSpyreScheduler keeps extra mutable state (ongoing_prefills, TKV / volumetric admission, previous_step_was_prefill, …) that the vanilla Scheduler does not own. The base _update_after_schedule optimistically bumps num_computed_tokens for scheduled prefill chunks; the Spyre runner then commits the actual progress, which can differ because of left-padding and prefix-cache hits. Without reconciliation, a speculative schedule(N+1) could read wrong committed vs in-flight positions and make bad admission decisions.


What changed

sendnn_inference/v1/core/async_scheduler.py (new file)

AsyncChunkedPrefillSpyreScheduler is a thin subclass combining ChunkedPrefillSpyreScheduler and the upstream AsyncScheduler via multiple inheritance. Python's MRO ensures super().schedule() hits AsyncScheduler before the base Scheduler. No pooling async variant is created — async scheduling has no benefit for pooling models.
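A minimal sketch of the shape of this class, assuming the upstream import paths of the pinned vLLM version (the real file may differ in details):

from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler

from .scheduler import ChunkedPrefillSpyreScheduler


class AsyncChunkedPrefillSpyreScheduler(ChunkedPrefillSpyreScheduler,
                                        AsyncScheduler):
    """Async variant for generative models. C3 linearization places
    AsyncScheduler between the Spyre scheduler and the base Scheduler,
    so cooperative super() calls hit the async hooks first."""


# MRO: AsyncChunkedPrefillSpyreScheduler -> ChunkedPrefillSpyreScheduler
#      -> AsyncScheduler -> Scheduler -> ...
mro = AsyncChunkedPrefillSpyreScheduler.__mro__
assert mro.index(AsyncScheduler) < mro.index(Scheduler)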

sendnn_inference/v1/core/scheduler.py

  • _inflight_prefill_tokens: dict[str, int] — tracks the optimistic num_computed_tokens delta per request between _update_after_schedule and the runner's actual commit. In sync mode this is a no-op (added and cleared in the same step). Under run-ahead it lets update_from_output reconcile to the runner's real report, and lets the speculative schedule() call distinguish committed from in-flight tokens when evaluating ongoing_prefills and TKV admission (sketched after this list).
  • update_from_output — reconciles num_computed_tokens using the inflight delta.
  • finish_requests — pops specific request IDs (or clears on finish-all) so aborted or mid-prefill-finished requests don't leave stale entries.
  • Defensive assert len(_inflight_prefill_tokens) <= max_num_seqs.
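A rough sketch of the tracking and reconciliation described in the first two items above; the num_committed_prefill_tokens field and the prefill check are illustrative assumptions, not the real names:

def _update_after_schedule(self, scheduler_output):
    # The base class optimistically advances num_computed_tokens here.
    super()._update_after_schedule(scheduler_output)
    for req_id, n in scheduler_output.num_scheduled_tokens.items():
        req = self.requests[req_id]
        if req.num_computed_tokens <= req.num_prompt_tokens:
            # Remember the optimistic bump for this prefill chunk until the
            # runner commits the real progress (real check is more involved).
            self._inflight_prefill_tokens[req_id] = (
                self._inflight_prefill_tokens.get(req_id, 0) + n)
    assert len(self._inflight_prefill_tokens) <= self.scheduler_config.max_num_seqs

def update_from_output(self, scheduler_output, model_runner_output):
    # Replace the optimistic advance with what the runner actually consumed
    # (left-padding / prefix-cache hits can make the two differ).
    committed = getattr(model_runner_output,
                        "num_committed_prefill_tokens", {})
    for req_id, actual in committed.items():
        optimistic = self._inflight_prefill_tokens.pop(req_id, 0)
        self.requests[req_id].num_computed_tokens += actual - optimistic
    return super().update_from_output(scheduler_output, model_runner_output)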

sendnn_inference/v1/worker/spyre_worker.py

The async engine calls execute_model and sample_tokens as separate steps. execute_model defers its result to _pending_sample; sample_tokens drains it. Warmup and cleanup paths that must not leave _pending_sample populated use _execute_and_sample instead of calling execute_model directly. Grammar masking falls back gracefully when _spyre_grammar_output is absent.
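A condensed sketch of this handshake; the class below is standalone for illustration only (the real methods live on SpyreWorker, and names besides execute_model, sample_tokens, _pending_sample and _execute_and_sample are assumptions):

class _DeferredSamplingWorker:
    """Illustration only: the real SpyreWorker subclasses vLLM's worker base."""

    def __init__(self, model_runner):
        self.model_runner = model_runner
        self._pending_sample = None

    def execute_model(self, scheduler_output):
        # Forward pass now; defer the output to the sampling step.
        self._pending_sample = self.model_runner.execute_model(scheduler_output)
        return None  # the async engine only checks this for None

    def sample_tokens(self, grammar_output=None):
        # Drain the deferred result; this is what the engine feeds to
        # scheduler.update_from_output.
        output, self._pending_sample = self._pending_sample, None
        return output

    def _execute_and_sample(self, scheduler_output):
        # Warmup and cleanup paths go through here so _pending_sample can
        # never leak across steps.
        self.execute_model(scheduler_output)
        return self.sample_tokens()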

sendnn_inference/platform.py

get_spyre_scheduler_cls centralises class selection, validates non-string scheduler_cls values as Scheduler subclasses, and logs the selected class name at startup so pooling cases cannot misleadingly appear as async-enabled.
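Roughly, the selection looks like this (a sketch; the pooling check uses an assumed runner_type value and the class names come from this PR, not the literal function body):

def get_spyre_scheduler_cls(vllm_config):
    scheduler_cls = vllm_config.scheduler_config.scheduler_cls
    if scheduler_cls and not isinstance(scheduler_cls, str):
        # User-supplied class objects must at least be Scheduler subclasses.
        if not issubclass(scheduler_cls, Scheduler):
            raise ValueError(f"{scheduler_cls} is not a Scheduler subclass")
        chosen = scheduler_cls
    elif vllm_config.model_config.runner_type == "pooling":
        chosen = PoolingSpyreScheduler  # async is never selected for pooling
    elif vllm_config.scheduler_config.async_scheduling:
        chosen = AsyncChunkedPrefillSpyreScheduler
    else:
        chosen = ChunkedPrefillSpyreScheduler
    logger.info("Using scheduler class: %s", chosen.__name__)
    return chosen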

tests/v1/core/test_scheduler.py

New and extended tests: sync/async schedule-update cycles, _inflight_prefill_tokens correctness under normal flow and mid-prefill abort, MRO ordering (mro.index(AsyncScheduler) < mro.index(Scheduler)), finish_requests cleanup.
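The MRO check reduces to something like this (sketch, assuming the same imports as the scheduler module):

def test_async_scheduler_mro_ordering():
    mro = AsyncChunkedPrefillSpyreScheduler.__mro__
    # Spyre behaviour first, then the async hooks, then the vanilla base.
    assert mro.index(ChunkedPrefillSpyreScheduler) < mro.index(AsyncScheduler)
    assert mro.index(AsyncScheduler) < mro.index(Scheduler)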

Related Issues

Checklist

  • I have read the contributing guidelines
  • My code follows the project's code style (run bash format.sh)
  • I have added tests for my changes (if applicable)
  • I have updated the documentation (if applicable)
  • My commits include a Signed-off-by: line (DCO compliance)

@GOavi101 GOavi101 requested review from dilipgb and joerunde April 21, 2026 08:10
@github-actions

👋 Hi! Thank you for contributing.
Just a reminder: make sure your code passes all the linting checks, otherwise your PR cannot be merged. To do so, run ./format.sh.
Now you are good to go 🚀.

We also recommend installing prek and configuring it to check your code before every local commit.

@GOavi101 GOavi101 force-pushed the feature/async-scheduler-mixin-pattern branch 15 times, most recently from 1a3ecbb to b0e8e83 Compare April 22, 2026 17:20
Comment thread vllm_spyre/v1/core/scheduler.py Outdated
@GOavi101 GOavi101 force-pushed the feature/async-scheduler-mixin-pattern branch from b0e8e83 to d71cfb3 Compare April 22, 2026 17:34
Comment thread sendnn_inference/v1/worker/spyre_worker.py Outdated
Comment thread tests/v1/core/test_async_scheduler.py Outdated
Comment thread sendnn_inference/platform.py Outdated
Comment thread sendnn_inference/platform.py Outdated
# The mixin's pre-filter pattern is not safe under that run-ahead scenario.
# For TP=1 (UniProcExecutor), futures are immediately done so it's safe.
if parallel_config.world_size > 1:
    scheduler_config.async_scheduling = False
Collaborator


Interesting - if we wanted to support this feature, it would likely need to work with TP=4, which is how we run most models. I thought this was only incompatible with pipeline parallel upstream - does it also not work with tensor parallel?

Collaborator Author

@GOavi101 GOavi101 Apr 23, 2026


@joerunde

The fix is SpyreMultiprocExecutor — a thin MultiprocExecutor subclass that overrides max_concurrent_batches to return 1 instead of 2. This forces the engine to use the simpler step() path (strictly schedule → execute → update) rather than step_with_batch_queue, which was the only thing that broke TP>1.
Spyre's forward pass is synchronous, so there's no compute/schedule overlap to lose. The AsyncScheduler base class and its _update_after_schedule TTFT benefit are still fully active — we just removed the run-ahead that its state tracking couldn't handle.
So TP=1, TP=2, and TP=4 should all work with async scheduling now. Not a blocker.
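For reference, the override is as small as this (a sketch; the MultiprocExecutor import path is assumed for the pinned vLLM):

from vllm.v1.executor.multiproc_executor import MultiprocExecutor


class SpyreMultiprocExecutor(MultiprocExecutor):

    @property
    def max_concurrent_batches(self) -> int:
        # A single in-flight batch forces the engine down the plain step()
        # path instead of step_with_batch_queue.
        return 1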

what do you think?

Collaborator


That doesn't quite line up with my understanding. IIUC, the step_with_batch_queue method is what enables the speculative scheduling: the engine runs the scheduler again while the model is running, assuming that the requests in the batch will continue.

Spyre's forward pass is synchronous, so there's no compute/schedule overlap to lose

I don't quite understand this either. The multiproc executor is definitely async: it broadcasts an RPC to the workers to run the model, and the engine gets back a future that it waits on. step_with_batch_queue queues up that future so that it can speculatively schedule the next pass.

This TP=1 profile shows the scheduler running in between the model forward passes, the goal with async scheduling is to get the scheduler running for the next step during the model forward pass instead:


The AsyncScheduler base class and its _update_after_schedule TTFT benefit are still fully active — we just removed the run-ahead that its state tracking couldn't handle.
So TP=1, TP=2, and TP=4 should all work with async scheduling now. Not a blocker.

Based on the above, my understanding is that the run-ahead state is the whole point and we won't gain any performance benefit from this unless we support it, so this is a blocker. Is there something else I'm missing?

Collaborator Author


You're right, thanks for the correction. I'll fix this — snapshot the mixin's mutable state (ongoing_prefills, tkv, previous_step_was_prefill) before delegating to super().schedule() so the run-ahead second schedule() call sees consistent state, and remove SpyreMultiprocExecutor. That way TP≥2 gets the full async scheduling benefit.
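Something along these lines (only a sketch of the idea; the attribute names are the ones mentioned above, everything else is assumed):

import copy

def schedule(self):
    # Capture the mutable Spyre state before the (possibly run-ahead)
    # scheduling pass so reconciliation in update_from_output can work
    # from a consistent pre-schedule view.
    self._pre_schedule_state = (
        copy.copy(self.ongoing_prefills),
        self.tkv,
        self.previous_step_was_prefill,
    )
    return super().schedule()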

Collaborator


Yeah, the step_with_batch_queue is actually what makes async scheduling work at all. In the PR where Woosuk added this, he reused the existing batch queue that was originally intended for Pipeline Parallel (PP). With PP you have to wait N steps for each of the N PP stages. But with async scheduling the intention is for the scheduler to be ahead by 1 step, so the num_output_placeholders were introduced to prevent the scheduler from waiting for the current step.
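Roughly, the placeholder idea is this (a paraphrased sketch, not the actual upstream vLLM code):

# In AsyncScheduler._update_after_schedule (paraphrased):
for req_id in scheduler_output.num_scheduled_tokens:
    request = self.requests[req_id]
    if request.num_computed_tokens >= request.num_tokens:
        # The token for this step has not been sampled yet; count it as a
        # placeholder so the next schedule() does not have to wait for it.
        request.num_output_placeholders += 1

# Later, update_from_output swaps placeholders for the real sampled tokens
# once the step's outputs arrive.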

@joerunde
Collaborator

Thanks @GOavi101!

A few notes:

  1. If this can't be done with tensor parallel, then maybe it's not worth pursuing. Is that a hard blocker?
  2. We need to have an end-to-end test that shows this working, i.e. using an LLM with async scheduling enabled. It would also be good to include an illustrative test at the engine level (see https://github.com/torch-spyre/sendnn-inference/blob/main/tests/e2e/test_spyre_pc_scheduler_steps.py) that shows the effects of async scheduling. From my quick skim it sounds like the engine is speculatively scheduling batches one step ahead, so we should see a "dead token" in some cases where the engine schedules a decode past the end of a sequence.
  3. It would be really great to see a profile of this in action, or at least some minimal vllm bench results showing what kind of performance improvement we can expect.

@GOavi101 GOavi101 force-pushed the feature/async-scheduler-mixin-pattern branch 5 times, most recently from 1bd875b to 2246d48 Compare April 23, 2026 10:03
@GOavi101 GOavi101 force-pushed the feature/async-scheduler-mixin-pattern branch 10 times, most recently from 421b610 to 4777fb6 Compare April 28, 2026 19:01
Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
@GOavi101 GOavi101 force-pushed the feature/async-scheduler-mixin-pattern branch from 4777fb6 to 6af4564 Compare April 28, 2026 19:11
Collaborator

@maxdebayser maxdebayser left a comment


Thanks for taking on this issue, @GOavi101. I think perhaps it would be better to take a step back and rethink this PR a bit.

There are a few unnecessary changes in this PR. I would suggest the following:

  1. Remove the AsyncPoolingSpyreScheduler class and the code that selects it in platform.py as async scheduling is not advantageous for pooling
  2. Undo the mixin class structure. Currently the only purpose of this is for _is_async_scheduler() to run isinstance(self, AsyncScheduler) but the same can be achieved by returning vllm_config.scheduler_config.async_scheduling

In PR 19970 Woosuk added async scheduling based on this idea from the NanoFlow paper:

Asynchronous scheduling: Batch formation, including estimating memory usage, scheduling new requests, retiring finished requests, and adjusting the page table for PagedAttention [17], consumes a non-negligible amount of time on the CPU side [42]. In most serving frameworks [17,58], only after the GPU executes one iteration, the scheduler on the CPU is able to detect EOS tokens, remove the finished request from the batch, and refill the batch with new requests. However, GPU is under-utilized during this time. To avoid this waste, NanoFlow asynchronously schedules batch formation in parallel to the GPU executions. At any iteration i, NanoFlow forms the batch for the next iteration before the end of the current iteration. This means that NanoFlow cannot detect the EOS tokens generated at iteration i. After launching iteration i+1, NanoFlow forms the batch for cycle i+2, detects the EOS token from iteration i, and removes finished requests. Fortunately, since the average decode length surpasses 100 for typical workloads (See Table 4), the overhead of one extra decode token is negligible (< 1%), given the benefit of hiding the batch formation overhead.

In this first PR, no model runner changes were needed. Later additions were made to overlap the model runner execution with other steps such as sampling, and the WorkerBase sample_tokens method was added to separate sampling from the execution of the model.

I think that the main problem to solve is that in Spyre we can't run prefills in the same batch as the decode requests. If it weren't for this fact, probably the only required code change would be implementing the sample_tokens() method in the SpyreWorker.
Since we can currently interleave prefill chunk batches and decode batches, in principle there is no problem in doing so asynchronously.
But once prefill is done, we must either wait and add the request to the current decode batch, or add it to the one after that.

Comment thread sendnn_inference/platform.py Outdated
Comment thread sendnn_inference/v1/worker/spyre_worker.py Outdated
Comment thread sendnn_inference/v1/worker/spyre_model_runner.py Outdated
@GOavi101 GOavi101 force-pushed the feature/async-scheduler-mixin-pattern branch 4 times, most recently from 990a1be to eea3259 Compare May 8, 2026 07:54
Comment thread sendnn_inference/v1/core/scheduler.py Outdated
# not yet committed by ``update_from_output``. The committed prefill
# position for ``req`` is
# ``req.num_computed_tokens - self._inflight_prefill_tokens.get(rid, 0)``.
self._inflight_prefill_tokens: dict[str, int] = {}
Collaborator


Why is this necessary? Upstream the scheduler skips the chunks:

            if request.is_prefill_chunk:
                continue

Collaborator Author


The is_prefill_chunk skip in upstream's _update_after_schedule only avoids bumping num_output_placeholders (which is decode-only). The base Scheduler._update_after_schedule still optimistically advances request.num_computed_tokens by num_scheduled_tokens for prefill chunks — that's what enables schedule(N+1) to pick the next chunk while execute(N) is in flight.
We need _inflight_prefill_tokens because on Spyre the committed prefill position after the runner returns is not necessarily equal to the optimistically-scheduled amount: the model runner adjusts for left-padding and prefix-cache hits and reports back what was actually consumed. Without tracking the per-chunk optimistic delta, update_from_output cannot reconcile num_computed_tokens to the runner's actual report, and (under run-ahead) schedule() cannot tell which portion of num_computed_tokens is committed vs. in-flight when deciding whether a request is still in ongoing_prefills.
In sync mode this is a no-op (added and cleared in the same step). It only matters when the async run-ahead inserts a schedule() call between the optimistic advance and the runner's commit.

Comment thread sendnn_inference/v1/core/scheduler.py Outdated
… collapse mixins

- Remove AsyncPoolingSpyreScheduler; pooling models always use the sync
  PoolingSpyreScheduler (async scheduling is not advantageous for pooling).
- Collapse PoolingSpyreMixin / ChunkedPrefillSpyreMixin into direct
  Scheduler subclasses. AsyncChunkedPrefillSpyreScheduler subclasses
  (ChunkedPrefillSpyreScheduler, AsyncScheduler).
- Replace isinstance-based _is_async_scheduler() with a check on
  vllm_config.scheduler_config.async_scheduling.
- Revert speculative model_runner / worker changes that should land in a
  follow-up that implements SpyreWorker.sample_tokens().
- Update tests accordingly.

Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
@GOavi101 GOavi101 force-pushed the feature/async-scheduler-mixin-pattern branch from eea3259 to 8b4b55e Compare May 9, 2026 07:56
Avishek Goswami added 2 commits May 9, 2026 18:16
…l runner

The engine's async scheduling path uses the sample_tokens future result as the model_output passed to scheduler.update_from_output, while the execute_model future result is only checked for None. The previous implementation folded sampling into execute_model and returned EMPTY_MODEL_RUNNER_OUTPUT from sample_tokens, causing real outputs to be silently dropped under async scheduling and triggering a scheduler/runner desync.

Split the chunked-prefill runner so execute_model only runs the forward pass and stashes the logits, then sample_tokens consumes the stash, applies the grammar bitmask, runs sampling, and returns the real ModelRunnerOutput. apply_grammar_bitmask now takes grammar_output as a parameter (the previous getattr from scheduler_output never matched anything). A new execute_and_sample helper preserves the combined behaviour for tests and warmup paths that drive the runner directly. The base class gets a default sample_tokens that returns EMPTY_MODEL_RUNNER_OUTPUT for subclasses (e.g. pooling) that fold sampling into the forward pass.

Warmup paths in SpyreWorker invoke execute_model directly; after the split they would leak _pending_sample state. Introduce SpyreWorker._execute_and_sample to drain the deferred sampling step and route warmup callsites through it (cleanup paths, which schedule no tokens, keep using execute_model and return a concrete empty output).
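Sketch of the resulting runner contract (names other than execute_model, sample_tokens, execute_and_sample and apply_grammar_bitmask are placeholders for illustration):

class ChunkedPrefillModelRunner:  # illustration of the split only

    def execute_model(self, scheduler_output):
        # Forward pass only; stash the logits for the separate sampling step.
        self._pending_logits = self._forward(scheduler_output)
        return None

    def sample_tokens(self, grammar_output=None):
        logits, self._pending_logits = self._pending_logits, None
        if grammar_output is not None:
            logits = self.apply_grammar_bitmask(logits, grammar_output)
        # Returns a real ModelRunnerOutput, not EMPTY_MODEL_RUNNER_OUTPUT.
        return self._sample(logits)

    def execute_and_sample(self, scheduler_output, grammar_output=None):
        # Combined behaviour for tests and warmup paths that drive the
        # runner directly.
        self.execute_model(scheduler_output)
        return self.sample_tokens(grammar_output)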

Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
…er-mixin-pattern

Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>

# Conflicts:
#	sendnn_inference/v1/worker/spyre_worker.py
@GOavi101 GOavi101 force-pushed the feature/async-scheduler-mixin-pattern branch from 5ffbab8 to d24a91b Compare May 10, 2026 07:13
@GOavi101 GOavi101 requested a review from maxdebayser May 10, 2026 07:47
@GOavi101 GOavi101 force-pushed the feature/async-scheduler-mixin-pattern branch from d24a91b to 85912af Compare May 10, 2026 07:55
- Reconcile optimistic _inflight_prefill_tokens when the runner reports empty
  req_ids (async incomplete prefill); support generic ModelRunnerOutput via
  duck-typed padding fields.
- Worker: split execute_model/sample_tokens; execute_and_sample helper;
  dense req_id_to_index for sampled_token_ids; reconcile input_batch for
  partial decode schedules.
- Scheduler TKV tests use execute_and_sample; conftest tolerates read-only cwd
  for test-sort artifact.
- Omit temporary SDBG instrumentation.

Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
@GOavi101 GOavi101 force-pushed the feature/async-scheduler-mixin-pattern branch from 85912af to d7d2ac6 Compare May 11, 2026 05:45
- Add _discard_stale_pending_sample for empty schedules and incomplete-prefill
  early returns so deferred sampling cannot leak across steps.
- Route warmup _cleanup_model_runner through _execute_and_sample for a single
  contract with ChunkedPrefillModelRunner.
- Resolve grammar for masking via explicit branches (engine arg vs Spyre
  _spyre_grammar_output on SchedulerOutput).
- Document dense req_id_to_index for sampled_token_ids vs vLLM scheduler indexing.
