@@ -95,13 +95,70 @@ def import_kernels(cls) -> None:
9595 # Workaround torch.accelerator.empty_cache for torch 2.7.1 and vllm v0.18.0 compatibility
9696 setattr (torch .accelerator , "empty_cache" , lambda : None ) # noqa
9797
98+ @classmethod
99+ def set_device (cls , device : torch .device ) -> None :
100+ """No-op: Spyre does not require explicit device selection."""
101+
98102 @classmethod
99103 def is_async_output_supported (cls , enforce_eager : bool | None ) -> bool :
100104 """
101105 Check if the current platform supports async output.
102106 """
103107 return False
104108
109+ @classmethod
110+ def get_spyre_scheduler_cls (
111+ cls ,
112+ scheduler_config ,
113+ is_pooling : bool ,
114+ ) -> type :
115+ """Get the appropriate Spyre scheduler class.
116+
117+ This follows the same pattern as vLLM's upstream SchedulerConfig.get_scheduler_cls():
118+ - If scheduler_cls is already set, use it (allows custom schedulers)
119+ - Otherwise, select based on scheduler_config.async_scheduling and model type
120+
121+ The scheduler selection uses factory functions that create classes with the
122+ appropriate base (Scheduler or AsyncScheduler) based on async_scheduling config.
123+
124+ Args:
125+ scheduler_config: The scheduler configuration object
126+ is_pooling: True for pooling/embedding models, False for generative models
127+
128+ Returns:
129+ The scheduler class to use
130+ """
131+ # If a custom scheduler is explicitly set, use it (str or class both fine)
132+ if scheduler_config .scheduler_cls is not None :
133+ return scheduler_config .scheduler_cls
134+
135+ # Import from appropriate module based on async_scheduling config
136+ # These modules have classes created at module level, so they're importable
137+ if scheduler_config .async_scheduling :
138+ # Use async scheduler variants
139+ if is_pooling :
140+ from sendnn_inference .v1 .core .async_scheduler import (
141+ AsyncPoolingSpyreScheduler ,
142+ )
143+
144+ return AsyncPoolingSpyreScheduler
145+ else :
146+ from sendnn_inference .v1 .core .async_scheduler import (
147+ AsyncChunkedPrefillSpyreScheduler ,
148+ )
149+
150+ return AsyncChunkedPrefillSpyreScheduler
151+ else :
152+ # Use sync scheduler variants (default)
153+ if is_pooling :
154+ from sendnn_inference .v1 .core .scheduler import PoolingSpyreScheduler
155+
156+ return PoolingSpyreScheduler
157+ else :
158+ from sendnn_inference .v1 .core .scheduler import ChunkedPrefillSpyreScheduler
159+
160+ return ChunkedPrefillSpyreScheduler
161+
105162 @classmethod
106163 def get_max_batch_tkv_limit (cls ) -> int :
107164 if cls ._max_batch_tkv_limit == 0 :
@@ -219,8 +276,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
219276 os .environ ["FLEX_DEVICE" ] = "COMPILE"
220277
221278 if is_decoder :
222- scheduler_config .scheduler_cls = (
223- "sendnn_inference.v1.core.scheduler.ChunkedPrefillSpyreScheduler"
279+ # Select scheduler using get_spyre_scheduler_cls(), following upstream's pattern
280+ # This checks scheduler_cls first, then async_scheduling flag
281+ # SchedulerConfig.scheduler_cls accepts str | type | None directly.
282+ scheduler_config .scheduler_cls = cls .get_spyre_scheduler_cls (
283+ scheduler_config = scheduler_config ,
284+ is_pooling = False ,
224285 )
225286
226287 if (
@@ -249,8 +310,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
249310 # unsetting this config as it was only set to pass vllm scheduler's max_model_len check
250311 vllm_config .scheduler_config .enable_chunked_prefill = False
251312
252- scheduler_config .scheduler_cls = (
253- "sendnn_inference.v1.core.scheduler.PoolingSpyreScheduler"
313+ # Select scheduler using get_spyre_scheduler_cls(), following upstream's pattern
314+ # SchedulerConfig.scheduler_cls accepts str | type | None directly.
315+ scheduler_config .scheduler_cls = cls .get_spyre_scheduler_cls (
316+ scheduler_config = scheduler_config ,
317+ is_pooling = True ,
254318 )
255319
256320 # Apply model-specific configurations using the registry
@@ -287,8 +351,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
287351 envs_spyre .SENDNN_INFERENCE_DYNAMO_BACKEND ,
288352 )
289353
290- # TODO: try to support async scheduling
291- scheduler_config .async_scheduling = False
354+ logger .info (
355+ "Spyre async scheduling is %s" ,
356+ "enabled" if scheduler_config .async_scheduling else "disabled" ,
357+ )
292358
293359 # To disable any paged attention ops in the base scheduler, we:
294360 # - Set the block size (in tokens) to the maximum sequence length
@@ -304,7 +370,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
304370 scheduler_config .max_num_batched_tokens = (
305371 model_config .max_model_len * scheduler_config .max_num_seqs
306372 )
307- cache_config .block_size = model_config .max_model_len # ty: ignore[invalid-assignment]
373+ cache_config .block_size = model_config .max_model_len
308374 vllm_config .cache_config .enable_prefix_caching = False
309375
310376 else :
@@ -750,7 +816,7 @@ def maybe_ensure_sendnn_configured(cls, model_config: ModelConfig) -> None:
750816 @classmethod
751817 def _set_batch_tkv_limit_from_env (cls ) -> None :
752818 try :
753- cls ._max_batch_tkv_limit = int (os .getenv ("VLLM_DT_MAX_BATCH_TKV_LIMIT" , "-1" )) # ty: ignore
819+ cls ._max_batch_tkv_limit = int (os .getenv ("VLLM_DT_MAX_BATCH_TKV_LIMIT" , "-1" ))
754820 except ValueError as e :
755821 raise ValueError ("VLLM_DT_MAX_BATCH_TKV_LIMIT must be an integer" ) from e
756822
0 commit comments