Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sendnn_inference/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def patched_parse_args(
namespace: argparse.Namespace | None = None,
) -> argparse.Namespace:
result = original_parse_args(self, args, namespace)
assert result is not None # type: ignore[redundant-expr]
assert result is not None

if args is None or len(args) == 0:
# Don't override anything if there were no args parsed
Expand Down
96 changes: 85 additions & 11 deletions sendnn_inference/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import math
import operator
import os
from typing import TYPE_CHECKING, cast, Literal
from typing import TYPE_CHECKING, Any, cast, Literal

import torch
import huggingface_hub
Expand All @@ -29,17 +29,20 @@
if TYPE_CHECKING:
# NB: We can't eagerly import many things from vllm since vllm.config
# will import this file. These would lead to circular imports
from vllm.config import ModelConfig, VllmConfig
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.inputs import EngineInput, TokensInput
from vllm.v1.core.sched.scheduler import Scheduler
else:
ModelConfig = None
SchedulerConfig = None
VllmConfig = None
SamplingParams = None
PoolingParams = None
EngineInput = None
TokensInput = None
Scheduler = None
from vllm.platforms import Platform, PlatformEnum

import sendnn_inference.envs as envs_spyre
Expand Down Expand Up @@ -133,13 +136,75 @@ def import_kernels(cls) -> None:
# Workaround torch.accelerator.empty_cache for torch 2.7.1 and vllm v0.18.0 compatibility
setattr(torch.accelerator, "empty_cache", lambda: None) # noqa

@classmethod
def set_device(cls, device: torch.device) -> None:
    """Intentionally a no-op: Spyre needs no explicit device selection."""
    return None

@classmethod
def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
"""
Check if the current platform supports async output.
"""
return False

@classmethod
def get_spyre_scheduler_cls(
    cls,
    scheduler_config: "SchedulerConfig",
    is_pooling: bool,
) -> "type[Scheduler] | str":
    """Pick the Spyre scheduler implementation to use.

    Mirrors the selection pattern of vLLM's upstream
    ``SchedulerConfig.get_scheduler_cls()``:

    - An explicitly configured ``scheduler_cls`` wins (custom schedulers).
    - Pooling models always get the synchronous ``PoolingSpyreScheduler``;
      async scheduling is not advantageous for pooling.
    - Generative models get a chunked-prefill scheduler, async or sync
      depending on ``scheduler_config.async_scheduling``.

    NOTE(review): the "already set" short-circuit assumes ``scheduler_cls``
    defaults to ``None`` in the pinned vLLM version — confirm, since a
    non-None default would bypass Spyre scheduler selection entirely.

    Args:
        scheduler_config: The scheduler configuration object.
        is_pooling: True for pooling/embedding models, False for
            generative models.

    Returns:
        A scheduler class, or a dotted import path string that vLLM
        resolves during engine initialization.

    Raises:
        TypeError: if an explicit ``scheduler_cls`` is neither a string
            nor a vLLM ``Scheduler`` subclass.
    """
    configured = scheduler_config.scheduler_cls
    if configured is not None:
        if isinstance(configured, str):
            # Dotted import paths are resolved by vLLM during engine init.
            return configured
        if isinstance(configured, type):
            from vllm.v1.core.sched.scheduler import Scheduler as VllmScheduler

            if issubclass(configured, VllmScheduler):
                return configured
        # Reached for a non-Scheduler type and for any other object alike;
        # both original branches produced this exact message text.
        raise TypeError(
            "scheduler_cls must be a vLLM Scheduler subclass or a dotted "
            f"import path string, got {configured!r}"
        )

    if is_pooling:
        # Pooling always runs the synchronous scheduler.
        from sendnn_inference.v1.core.scheduler import PoolingSpyreScheduler

        return PoolingSpyreScheduler

    if scheduler_config.async_scheduling:
        from sendnn_inference.v1.core.async_scheduler import (
            AsyncChunkedPrefillSpyreScheduler,
        )

        return AsyncChunkedPrefillSpyreScheduler

    from sendnn_inference.v1.core.scheduler import ChunkedPrefillSpyreScheduler

    return ChunkedPrefillSpyreScheduler

@classmethod
def get_max_batch_tkv_limit(cls) -> int:
if cls._max_batch_tkv_limit == 0:
Expand Down Expand Up @@ -260,8 +325,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
os.environ["FLEX_DEVICE"] = "COMPILE"

if is_decoder:
scheduler_config.scheduler_cls = (
"sendnn_inference.v1.core.scheduler.ChunkedPrefillSpyreScheduler"
scheduler_config.scheduler_cls = cls.get_spyre_scheduler_cls(
scheduler_config=scheduler_config,
is_pooling=False,
)

if (
Expand Down Expand Up @@ -290,8 +356,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# unsetting this config as it was only set to pass vllm scheduler's max_model_len check
vllm_config.scheduler_config.enable_chunked_prefill = False

scheduler_config.scheduler_cls = (
"sendnn_inference.v1.core.scheduler.PoolingSpyreScheduler"
scheduler_config.scheduler_cls = cls.get_spyre_scheduler_cls(
scheduler_config=scheduler_config,
is_pooling=True,
)

# Apply model-specific configurations using the registry
Expand Down Expand Up @@ -328,8 +395,15 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
envs_spyre.SENDNN_INFERENCE_DYNAMO_BACKEND,
)

# TODO: try to support async scheduling
scheduler_config.async_scheduling = False
selected = scheduler_config.scheduler_cls
selected_display = (
selected if isinstance(selected, str) else getattr(selected, "__name__", repr(selected))
)
logger.info(
"Spyre scheduler class: %s (config async_scheduling=%s)",
selected_display,
scheduler_config.async_scheduling,
)

# To disable any paged attention ops in the base scheduler, we:
# - Set the block size (in tokens) to the maximum sequence length
Expand All @@ -345,7 +419,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
scheduler_config.max_num_batched_tokens = (
model_config.max_model_len * scheduler_config.max_num_seqs
)
cache_config.block_size = model_config.max_model_len # ty: ignore[invalid-assignment]
cache_config.block_size = model_config.max_model_len
vllm_config.cache_config.enable_prefix_caching = False

else:
Expand Down Expand Up @@ -721,7 +795,7 @@ def _patch_tokenizer_registry_get_config(cls) -> None:
"""
import vllm.tokenizers.registry as tokenizer_registry

original_get_config = tokenizer_registry.get_config
original_get_config = cast(Any, tokenizer_registry).get_config

def safe_get_config(*args, **kwargs):
try:
Expand Down Expand Up @@ -821,7 +895,7 @@ def maybe_ensure_sendnn_configured(cls, model_config: ModelConfig) -> None:
@classmethod
def _set_batch_tkv_limit_from_env(cls) -> None:
    """Read ``VLLM_DT_MAX_BATCH_TKV_LIMIT`` into ``cls._max_batch_tkv_limit``.

    Falls back to -1 when the environment variable is unset.

    Raises:
        ValueError: if the variable is set to a non-integer value.
    """
    # Hoist the read so only the int() conversion sits inside the try.
    raw = os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT", "-1")
    try:
        cls._max_batch_tkv_limit = int(raw)
    except ValueError as e:
        raise ValueError("VLLM_DT_MAX_BATCH_TKV_LIMIT must be an integer") from e

Expand Down
23 changes: 23 additions & 0 deletions sendnn_inference/v1/core/async_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0

from vllm.v1.core.sched.async_scheduler import AsyncScheduler

from sendnn_inference.v1.core.scheduler import ChunkedPrefillSpyreScheduler


class AsyncChunkedPrefillSpyreScheduler(ChunkedPrefillSpyreScheduler, AsyncScheduler):
    """``ChunkedPrefillSpyreScheduler`` with vLLM's async scheduling mixed in.

    Both bases derive from the shared ``Scheduler`` class. For this subclass
    (only), the C3 MRO orders ``ChunkedPrefillSpyreScheduler`` before
    ``AsyncScheduler`` before the common ``Scheduler`` base, so ``super()``
    calls made inside ``ChunkedPrefillSpyreScheduler`` methods dispatch to
    ``AsyncScheduler`` instead of going straight to ``Scheduler``.
    """


__all__ = [
    "AsyncChunkedPrefillSpyreScheduler",
]
Loading
Loading