Skip to content

Commit 421b610

Browse files
author
Avishek Goswami
committed
refactor: replace scheduler factory functions with mixin pattern
Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
1 parent 8c6b96f commit 421b610

9 files changed

Lines changed: 1090 additions & 149 deletions

File tree

sendnn_inference/platform.py

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,70 @@ def import_kernels(cls) -> None:
9595
# Workaround torch.accelerator.empty_cache for torch 2.7.1 and vllm v0.18.0 compatibility
9696
setattr(torch.accelerator, "empty_cache", lambda: None) # noqa
9797

98+
@classmethod
99+
def set_device(cls, device: torch.device) -> None:
100+
"""No-op: Spyre does not require explicit device selection."""
101+
98102
@classmethod
99103
def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
100104
"""
101105
Check if the current platform supports async output.
102106
"""
103107
return False
104108

109+
@classmethod
110+
def get_spyre_scheduler_cls(
111+
cls,
112+
scheduler_config,
113+
is_pooling: bool,
114+
) -> type:
115+
"""Get the appropriate Spyre scheduler class.
116+
117+
This follows the same pattern as vLLM's upstream SchedulerConfig.get_scheduler_cls():
118+
- If scheduler_cls is already set, use it (allows custom schedulers)
119+
- Otherwise, select based on scheduler_config.async_scheduling and model type
120+
121+
The scheduler selection uses factory functions that create classes with the
122+
appropriate base (Scheduler or AsyncScheduler) based on async_scheduling config.
123+
124+
Args:
125+
scheduler_config: The scheduler configuration object
126+
is_pooling: True for pooling/embedding models, False for generative models
127+
128+
Returns:
129+
The scheduler class to use
130+
"""
131+
# If a custom scheduler is explicitly set, use it (str or class both fine)
132+
if scheduler_config.scheduler_cls is not None:
133+
return scheduler_config.scheduler_cls
134+
135+
# Import from appropriate module based on async_scheduling config
136+
# These modules have classes created at module level, so they're importable
137+
if scheduler_config.async_scheduling:
138+
# Use async scheduler variants
139+
if is_pooling:
140+
from sendnn_inference.v1.core.async_scheduler import (
141+
AsyncPoolingSpyreScheduler,
142+
)
143+
144+
return AsyncPoolingSpyreScheduler
145+
else:
146+
from sendnn_inference.v1.core.async_scheduler import (
147+
AsyncChunkedPrefillSpyreScheduler,
148+
)
149+
150+
return AsyncChunkedPrefillSpyreScheduler
151+
else:
152+
# Use sync scheduler variants (default)
153+
if is_pooling:
154+
from sendnn_inference.v1.core.scheduler import PoolingSpyreScheduler
155+
156+
return PoolingSpyreScheduler
157+
else:
158+
from sendnn_inference.v1.core.scheduler import ChunkedPrefillSpyreScheduler
159+
160+
return ChunkedPrefillSpyreScheduler
161+
105162
@classmethod
106163
def get_max_batch_tkv_limit(cls) -> int:
107164
if cls._max_batch_tkv_limit == 0:
@@ -219,8 +276,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
219276
os.environ["FLEX_DEVICE"] = "COMPILE"
220277

221278
if is_decoder:
222-
scheduler_config.scheduler_cls = (
223-
"sendnn_inference.v1.core.scheduler.ChunkedPrefillSpyreScheduler"
279+
# Select scheduler using get_spyre_scheduler_cls(), following upstream's pattern
280+
# This checks scheduler_cls first, then async_scheduling flag
281+
# SchedulerConfig.scheduler_cls accepts str | type | None directly.
282+
scheduler_config.scheduler_cls = cls.get_spyre_scheduler_cls(
283+
scheduler_config=scheduler_config,
284+
is_pooling=False,
224285
)
225286

226287
if (
@@ -249,8 +310,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
249310
# unsetting this config as it was only set to pass vllm scheduler's max_model_len check
250311
vllm_config.scheduler_config.enable_chunked_prefill = False
251312

252-
scheduler_config.scheduler_cls = (
253-
"sendnn_inference.v1.core.scheduler.PoolingSpyreScheduler"
313+
# Select scheduler using get_spyre_scheduler_cls(), following upstream's pattern
314+
# SchedulerConfig.scheduler_cls accepts str | type | None directly.
315+
scheduler_config.scheduler_cls = cls.get_spyre_scheduler_cls(
316+
scheduler_config=scheduler_config,
317+
is_pooling=True,
254318
)
255319

256320
# Apply model-specific configurations using the registry
@@ -287,8 +351,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
287351
envs_spyre.SENDNN_INFERENCE_DYNAMO_BACKEND,
288352
)
289353

290-
# TODO: try to support async scheduling
291-
scheduler_config.async_scheduling = False
354+
logger.info(
355+
"Spyre async scheduling is %s",
356+
"enabled" if scheduler_config.async_scheduling else "disabled",
357+
)
292358

293359
# To disable any paged attention ops in the base scheduler, we:
294360
# - Set the block size (in tokens) to the maximum sequence length
@@ -304,7 +370,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
304370
scheduler_config.max_num_batched_tokens = (
305371
model_config.max_model_len * scheduler_config.max_num_seqs
306372
)
307-
cache_config.block_size = model_config.max_model_len # ty: ignore[invalid-assignment]
373+
cache_config.block_size = model_config.max_model_len
308374
vllm_config.cache_config.enable_prefix_caching = False
309375

310376
else:
@@ -750,7 +816,7 @@ def maybe_ensure_sendnn_configured(cls, model_config: ModelConfig) -> None:
750816
@classmethod
751817
def _set_batch_tkv_limit_from_env(cls) -> None:
752818
try:
753-
cls._max_batch_tkv_limit = int(os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT", "-1")) # ty: ignore
819+
cls._max_batch_tkv_limit = int(os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT", "-1"))
754820
except ValueError as e:
755821
raise ValueError("VLLM_DT_MAX_BATCH_TKV_LIMIT must be an integer") from e
756822

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
4+
5+
from sendnn_inference.v1.core.scheduler import (
6+
ChunkedPrefillSpyreMixin,
7+
PoolingSpyreMixin,
8+
)
9+
10+
11+
class AsyncSpyreScheduler(AsyncScheduler):
12+
"""Base class inheriting from the V1 async scheduler to support static
13+
and continuous batching respecting AIU Spyre constraints."""
14+
15+
def __init__(self, *args, **kwargs) -> None:
16+
# Initialize vLLM async scheduler
17+
super().__init__(*args, **kwargs)
18+
self.model_config = self.vllm_config.model_config
19+
20+
21+
class AsyncPoolingSpyreScheduler(PoolingSpyreMixin, AsyncScheduler):
22+
"""Async scheduler for pooling models with Spyre warmup-shape constraints."""
23+
24+
pass
25+
26+
27+
class AsyncChunkedPrefillSpyreScheduler(ChunkedPrefillSpyreMixin, AsyncScheduler):
28+
"""Async scheduler with Spyre chunked-prefill constraints bypassed in async mode."""
29+
30+
pass
31+
32+
33+
__all__ = [
34+
"AsyncPoolingSpyreScheduler",
35+
"AsyncChunkedPrefillSpyreScheduler",
36+
]

0 commit comments

Comments
 (0)