diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py
index dd9159769..6c11e6a7b 100644
--- a/examples/offline_inference/qwen2_5_omni/end2end.py
+++ b/examples/offline_inference/qwen2_5_omni/end2end.py
@@ -322,7 +322,7 @@ def main(args):
     omni_llm = Omni(
         model=model_name,
         log_stats=args.enable_stats,
-        init_sleep_seconds=args.init_sleep_seconds,
+        stage_init_timeout=args.stage_init_timeout,
         batch_timeout=args.batch_timeout,
         init_timeout=args.init_timeout,
         shm_threshold_bytes=args.shm_threshold_bytes,
@@ -426,10 +426,10 @@ def parse_args():
         help="Enable writing detailed statistics (default: disabled)",
     )
     parser.add_argument(
-        "--init-sleep-seconds",
+        "--stage-init-timeout",
         type=int,
-        default=20,
-        help="Sleep seconds after starting each stage process to allow initialization (default: 20)",
+        default=300,
+        help="Timeout for initializing a single stage in seconds (default: 300)",
     )
     parser.add_argument(
         "--batch-timeout",
diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py
index 4d3ce25cd..a343d6fca 100644
--- a/examples/offline_inference/qwen3_omni/end2end.py
+++ b/examples/offline_inference/qwen3_omni/end2end.py
@@ -224,6 +224,7 @@ def main(args):
         model=model_name,
         stage_configs_path=args.stage_configs_path,
         log_stats=args.enable_stats,
+        stage_init_timeout=args.stage_init_timeout,
     )

     thinker_sampling_params = SamplingParams(
@@ -337,10 +338,10 @@ def parse_args():
         help="Enable writing detailed statistics (default: disabled)",
     )
     parser.add_argument(
-        "--init-sleep-seconds",
+        "--stage-init-timeout",
         type=int,
-        default=20,
-        help="Sleep seconds after starting each stage process to allow initialization (default: 20)",
+        default=300,
+        help="Timeout for initializing a single stage in seconds (default: 300)",
     )
     parser.add_argument(
         "--batch-timeout",
diff --git a/examples/offline_inference/qwen3_omni/run_single_prompt_tp.sh b/examples/offline_inference/qwen3_omni/run_single_prompt_tp.sh
index f49ec4814..0cb459eab 100644
--- a/examples/offline_inference/qwen3_omni/run_single_prompt_tp.sh
+++ b/examples/offline_inference/qwen3_omni/run_single_prompt_tp.sh
@@ -1,5 +1,5 @@
 python end2end.py --output-wav output_audio \
     --query-type use_audio \
-    --init-sleep-seconds 90
+    --stage-init-timeout 300

-# init-sleep-seconds works to avoid two vLLM stages initialized at the same time within a card.
+# stage-init-timeout bounds how long a stage waits for the per-card init lock that keeps two vLLM stages from initializing at the same time on the same card.
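For reference, a minimal sketch of driving the renamed knob from Python rather than the CLI, following the Omni(...) call in the end2end.py hunk above (the model id and values here are illustrative, not taken from this diff):

    from vllm_omni.entrypoints.omni import Omni

    # stage_init_timeout replaces init_sleep_seconds: a bounded per-stage
    # watchdog instead of an unconditional sleep between stage launches.
    omni_llm = Omni(
        model="Qwen/Qwen2.5-Omni-7B",  # hypothetical model id
        stage_init_timeout=300,
        batch_timeout=10,
        init_timeout=300,
    )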
diff --git a/tests/e2e/offline_inference/conftest.py b/tests/e2e/offline_inference/conftest.py
index 6322ce766..e53e0ae3b 100644
--- a/tests/e2e/offline_inference/conftest.py
+++ b/tests/e2e/offline_inference/conftest.py
@@ -26,7 +26,7 @@ def __init__(
         self,
         model_name: str,
         seed: int = 42,
-        init_sleep_seconds: int = 20,
+        stage_init_timeout: int = 300,
         batch_timeout: int = 10,
         init_timeout: int = 300,
         shm_threshold_bytes: int = 65536,
@@ -40,7 +40,7 @@ def __init__(
         Args:
             model_name: The model name or path
             seed: Random seed for reproducibility
-            init_sleep_seconds: Sleep time after starting each stage
+            stage_init_timeout: Timeout for initializing a single stage in seconds
             batch_timeout: Timeout for batching in seconds
             init_timeout: Timeout for initializing stages in seconds
             shm_threshold_bytes: Threshold for using shared memory
@@ -54,7 +54,7 @@ def __init__(
         self.omni = Omni(
             model=model_name,
             log_stats=log_stats,
-            init_sleep_seconds=init_sleep_seconds,
+            stage_init_timeout=stage_init_timeout,
            batch_timeout=batch_timeout,
            init_timeout=init_timeout,
            shm_threshold_bytes=shm_threshold_bytes,
diff --git a/tests/e2e/offline_inference/test_qwen3_omni.py b/tests/e2e/offline_inference/test_qwen3_omni.py
index 945b6eaef..b6f6a7590 100644
--- a/tests/e2e/offline_inference/test_qwen3_omni.py
+++ b/tests/e2e/offline_inference/test_qwen3_omni.py
@@ -27,7 +27,7 @@ def test_video_to_audio(omni_runner: type[OmniRunner], test_config) -> None:
     """Test processing video, generating audio output."""
     model, stage_config_path = test_config

-    with omni_runner(model, seed=42, stage_configs_path=stage_config_path, init_sleep_seconds=90) as runner:
+    with omni_runner(model, seed=42, stage_configs_path=stage_config_path, stage_init_timeout=300) as runner:
         # Prepare inputs
         question = "Describe the video briefly."
         video = VideoAsset(name="baby_reading", num_frames=4).np_ndarrays
diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py
index b90e03555..ed2e05e2b 100644
--- a/tests/e2e/online_serving/test_qwen3_omni.py
+++ b/tests/e2e/online_serving/test_qwen3_omni.py
@@ -130,7 +130,7 @@ def omni_server(request):
     Multi-stage initialization can take 10-20+ minutes.
     """
     model, stage_config_path = request.param
-    with OmniServer(model, ["--stage-configs-path", stage_config_path, "--init-sleep-seconds", "90"]) as server:
+    with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "300"]) as server:
        yield server


diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py
index bcc6d5bff..c77b3b179 100644
--- a/tests/entrypoints/test_omni_llm.py
+++ b/tests/entrypoints/test_omni_llm.py
@@ -5,9 +5,6 @@
 from unittest.mock import MagicMock

 import pytest
-from vllm.sampling_params import SamplingParams
-
-from vllm_omni.entrypoints.stage_utils import _to_dict

 # Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies.
 warnings.filterwarnings(
@@ -71,7 +68,7 @@ def empty(self):
 class _FakeStage:
     """Lightweight Stage stub for multi-process pipeline version with queue support."""

-    def __init__(self, config):
+    def __init__(self, config, stage_init_timeout: int = 300):
         # Handle both dict and object configs
         if isinstance(config, dict):
             config = _FakeStageConfig(config)
@@ -95,9 +92,7 @@ def __init__(self, config):
         self._in_q = None
         self._out_q = None
         self._proc = None  # Mock process reference
-
-        default_sampling_params = getattr(config, "default_sampling_params", {})
-        self.default_sampling_params = SamplingParams(**_to_dict(default_sampling_params))
+        self._stage_init_timeout = max(0, int(stage_init_timeout))

     def attach_queues(self, in_q, out_q):
         """Attach input and output queues."""
@@ -481,7 +476,7 @@ def _fake_loader(model: str):
     # Replace OmniStage
     monkeypatch.setattr(
         "vllm_omni.entrypoints.omni_stage.OmniStage",
-        lambda cfg: _FakeStage(cfg),
+        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
         raising=False,
     )

@@ -490,7 +485,7 @@ def _fake_loader(model: str):

     # Patch the imported function and class in the module
     monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
-    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg))
+    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

     from vllm_omni.entrypoints.omni import Omni

@@ -543,14 +538,14 @@ def _fake_loader(model: str):
     )
     monkeypatch.setattr(
         "vllm_omni.entrypoints.omni_stage.OmniStage",
-        lambda cfg: _FakeStage(cfg),
+        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
         raising=False,
     )

     import vllm_omni.entrypoints.omni as omni_module

     monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
-    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg))
+    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

     from vllm_omni.entrypoints.omni import Omni

@@ -595,14 +590,14 @@ def _fake_loader(model: str):
     )
     monkeypatch.setattr(
         "vllm_omni.entrypoints.omni_stage.OmniStage",
-        lambda cfg: _FakeStage(cfg),
+        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
         raising=False,
     )

     import vllm_omni.entrypoints.omni as omni_module

     monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
-    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg))
+    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

     # Mock uuid.uuid4() to return a predictable value for request ID generation
     test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
@@ -688,14 +683,14 @@ def _fake_loader(model: str):
     )
     monkeypatch.setattr(
         "vllm_omni.entrypoints.omni_stage.OmniStage",
-        lambda cfg: _FakeStage(cfg),
+        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
         raising=False,
     )

     import vllm_omni.entrypoints.omni as omni_module

     monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
-    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg))
+    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

     # Mock uuid.uuid4() to return a predictable value for request ID generation
     test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
@@ -762,14 +757,14 @@ def _fake_loader(model: str):
     )
     monkeypatch.setattr(
         "vllm_omni.entrypoints.omni_stage.OmniStage",
-        lambda cfg: _FakeStage(cfg),
+        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
         raising=False,
     )

     import vllm_omni.entrypoints.omni as omni_module

     monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
-    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg))
+    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

     # Mock uuid.uuid4() to return a predictable value for request ID generation
     test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
@@ -841,14 +836,14 @@ def init_stage_worker(self, *args, **kwargs):

     monkeypatch.setattr(
         "vllm_omni.entrypoints.omni_stage.OmniStage",
-        lambda cfg: _FakeStageNoReady(cfg),
+        lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs),
         raising=False,
     )

     import vllm_omni.entrypoints.omni as omni_module

     monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
-    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStageNoReady(cfg))
+    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs))

     from vllm_omni.entrypoints.omni import Omni

@@ -886,14 +881,14 @@ def _fake_loader(model: str):
     )
     monkeypatch.setattr(
         "vllm_omni.entrypoints.omni_stage.OmniStage",
-        lambda cfg: _FakeStage(cfg),
+        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
         raising=False,
     )

     import vllm_omni.entrypoints.omni as omni_module

     monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
-    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg))
+    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

     # Mock uuid.uuid4() to return a predictable value for request ID generation
     test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
@@ -962,14 +957,14 @@ def _fake_loader(model: str):
     )
     monkeypatch.setattr(
         "vllm_omni.entrypoints.omni_stage.OmniStage",
-        lambda cfg: _FakeStage(cfg),
+        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
         raising=False,
     )

     import vllm_omni.entrypoints.omni as omni_module

     monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
-    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg))
+    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

     from vllm_omni.entrypoints.omni import Omni
diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py
index bd0ebd393..2a7bbd3c4 100644
--- a/vllm_omni/entrypoints/async_omni.py
+++ b/vllm_omni/entrypoints/async_omni.py
@@ -51,8 +51,9 @@ class AsyncOmni(OmniBase):
             configurations. If None, configurations are loaded from the model.
         - log_stats: Whether to enable statistics logging be written to files
             with stage-specific suffixes.
-        - init_sleep_seconds: Number of seconds to sleep between starting
-            each stage process during initialization
+        - stage_init_timeout: Per-stage initialization watchdog in seconds;
+            bounds how long a stage waits for the per-device init lock (held
+            by a previous stage or a prior Omni run sharing the GPU).
         - shm_threshold_bytes: Threshold in bytes for using shared memory for
             IPC. Objects larger than this threshold will use shared memory.
         - worker_backend: Backend for worker processes. Default is "multi_process".
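The watchdog described in the docstring above bounds the wait on a per-device initialization file lock (see the omni_stage.py hunks further down). A self-contained sketch of that pattern, assuming a lock-file path and poll interval that are not taken from this diff:

    import fcntl
    import os
    import time

    def acquire_device_lock(device_id: str, stage_init_timeout: int = 300):
        lock_path = f"/tmp/vllm_omni_device_{device_id}.lock"  # hypothetical path
        wait_start = time.time()
        while True:
            lock_fd = os.open(lock_path, os.O_CREAT | os.O_RDWR)
            try:
                # Exclusive, non-blocking; the lock is released automatically
                # if the holding process dies.
                fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
                return lock_fd  # keep the fd open to hold the lock
            except BlockingIOError:
                os.close(lock_fd)
            if time.time() - wait_start > stage_init_timeout:
                return None  # like the worker: stop waiting and proceed anyway
            time.sleep(1.0)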
diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py
index 0485eab97..9235b1a76 100644
--- a/vllm_omni/entrypoints/cli/serve.py
+++ b/vllm_omni/entrypoints/cli/serve.py
@@ -82,10 +82,10 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
         help="Path to the stage configs file. If not specified, the stage configs will be loaded from the model.",
     )
     serve_parser.add_argument(
-        "--init-sleep-seconds",
+        "--stage-init-timeout",
         type=int,
-        default=30,
-        help="The number of seconds to sleep before initializing the stages.",
+        default=300,
+        help="The timeout for initializing a single stage in seconds (default: 300)",
     )
     serve_parser.add_argument(
         "--init-timeout",
diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py
index f4abb6c96..b9db56c78 100644
--- a/vllm_omni/entrypoints/omni.py
+++ b/vllm_omni/entrypoints/omni.py
@@ -67,8 +67,9 @@ class OmniBase:
             configurations. If None, configurations are loaded from the model.
         - log_stats: Whether to enable statistics logging be written to files
             with stage-specific suffixes.
-        - init_sleep_seconds: Number of seconds to sleep between starting
-            each stage process during initialization
+        - stage_init_timeout: Per-stage initialization watchdog in seconds;
+            bounds how long a stage waits for the per-device init lock (held
+            by a previous stage or a prior Omni run sharing the GPU).
         - shm_threshold_bytes: Threshold in bytes for using shared memory for
             IPC. Objects larger than this threshold will use shared memory.
         - worker_backend: Backend for worker processes. Default is "multi_process".
@@ -172,7 +173,7 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st

     def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None:
         """Initialize stage list management."""
-        init_sleep_seconds = kwargs.get("init_sleep_seconds", 20)
+        stage_init_timeout = kwargs.get("stage_init_timeout", 300)
         shm_threshold_bytes = kwargs.get("shm_threshold_bytes", 65536)
         init_timeout = kwargs.get("init_timeout", 300)
         worker_backend = kwargs.get("worker_backend", "multi_process")
@@ -207,7 +208,7 @@ def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None:
         # Build OmniStage instances in parallel, preserve original order
         def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]:
             idx, cfg = idx_cfg
-            return idx, OmniStage(cfg)
+            return idx, OmniStage(cfg, stage_init_timeout=stage_init_timeout)

         with ThreadPoolExecutor(max_workers=min(len(self.stage_configs), max(1, os.cpu_count() or 1))) as executor:
             futures = [executor.submit(_build_stage, (idx, cfg)) for idx, cfg in enumerate(self.stage_configs)]
@@ -226,7 +227,7 @@ def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]:

         self._ctx = mp.get_context("spawn")
         self._queue_cls = lambda: self._ctx.Queue(maxsize=0)
-        self._init_sleep_seconds = max(0, int(init_sleep_seconds))
+        self._stage_init_timeout = max(0, int(stage_init_timeout))
         self._shm_threshold_bytes = max(0, int(shm_threshold_bytes))
         self._start_stages(model)
         # Wait for all stages to report readiness before seeding
@@ -264,7 +265,6 @@ def _start_stages(self, model: str) -> None:
             )
             logger.debug(f"[{self._name}] Stage-{stage_id} process started")

-            time.sleep(self._init_sleep_seconds)

     def _process_stage_ready(self, stage: OmniStage, stage_id: int, result: dict[str, Any]) -> None:
         self._stages_ready.add(stage_id)
@@ -300,7 +300,7 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None:
                 "Verify GPU/device assignment in config (runtime.devices) is correct.",
                 "Check GPU/host memory availability; reduce model or batch size if needed.",
                 "Check model weights path and network reachability (if loading remotely).",
-                "Increase initialization wait time (init_sleep_seconds or call-site timeout).",
+                "Increase initialization wait time (stage_init_timeout or call-site timeout).",
             ]
         )
         logger.error(
@@ -360,8 +360,9 @@ class Omni(OmniBase):
             configurations. If None, configurations are loaded from the model.
         - log_stats: Whether to enable statistics logging be written to files
             with stage-specific suffixes.
-        - init_sleep_seconds: Number of seconds to sleep between starting
-            each stage process during initialization
+        - stage_init_timeout: Per-stage initialization watchdog in seconds;
+            bounds how long a stage waits for the per-device init lock (held
+            by a previous stage or a prior Omni run sharing the GPU).
         - shm_threshold_bytes: Threshold in bytes for using shared memory for
             IPC. Objects larger than this threshold will use shared memory.
         - worker_backend: Backend for worker processes. Default is "multi_process".
diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py
index 814e26009..b0c4be08e 100644
--- a/vllm_omni/entrypoints/omni_stage.py
+++ b/vllm_omni/entrypoints/omni_stage.py
@@ -101,7 +101,7 @@ class OmniStage:
         runtime settings, and stage-specific parameters
     """

-    def __init__(self, stage_config: Any):
+    def __init__(self, stage_config: Any, stage_init_timeout: int = 300):
         logger.info(f"[OmniStage] stage_config: {stage_config}")
         self.stage_config = stage_config
         self.engine = None
@@ -139,6 +139,7 @@ def __init__(self, stage_config: Any):
         self._out_q: mp.Queue | None = None
         self._proc: mp.Process | None = None
         self._shm_threshold_bytes: int = 65536
+        self._stage_init_timeout: int = stage_init_timeout

     def set_engine(self, engine: LLMEngine) -> None:
         """Set the LLM engine for this stage.
@@ -272,6 +273,7 @@ def init_stage_worker(
                 model=model,
                 stage_payload=stage_payload,
                 batch_timeout=batch_timeout,
+                stage_init_timeout=self._stage_init_timeout,
             )
         else:
             self._ray_actor = start_ray_actor(
@@ -283,6 +285,7 @@ def init_stage_worker(
                 in_q=self._in_q,
                 out_q=self._out_q,
                 batch_timeout=batch_timeout,
+                stage_init_timeout=self._stage_init_timeout,
             )
         else:
             if is_async:
@@ -293,6 +296,7 @@ def init_stage_worker(
                         model,
                         stage_payload,
                         batch_timeout,
+                        self._stage_init_timeout,
                     ),
                 )
             else:
@@ -304,6 +308,7 @@ def init_stage_worker(
                         self._in_q,
                         self._out_q,
                         batch_timeout,
+                        self._stage_init_timeout,
                     ),
                 )
             self._proc.start()
@@ -420,6 +425,7 @@ def _stage_worker(
     in_q: mp.Queue,
     out_q: mp.Queue,
     batch_timeout: int = 10,
+    stage_init_timeout: int = 300,
 ) -> None:
     """Stage worker entry: device setup, LLM init, batching, SHM IPC."""
     # Use local aliases to avoid conflicts with global imports in worker process
@@ -525,7 +531,6 @@ def _stage_worker(

     # Acquire exclusive locks for all devices using fcntl.flock
     # Locks are automatically released when process dies
-    max_wait_time = 300  # 5 minutes max wait
     wait_start = _time.time()
     acquired_lock_fds = []  # Store file descriptors to keep locks alive

@@ -553,7 +558,7 @@ def _stage_worker(
                 _os.close(lock_fd)

                 # Check if we've been waiting too long
-                if _time.time() - wait_start > max_wait_time:
+                if _time.time() - wait_start > stage_init_timeout:
                     logger.warning(
                         "Timeout waiting for device %s initialization lock, proceeding anyway",
                         device_id,
@@ -852,8 +857,9 @@ def _stage_worker_async_entry(
     model: str,
     stage_payload: dict[str, Any],
     batch_timeout: int = 10,
+    stage_init_timeout: int = 300,
 ) -> None:
-    asyncio.run(_stage_worker_async(omni_stage, model, stage_payload, batch_timeout))
+    asyncio.run(_stage_worker_async(omni_stage, model, stage_payload, batch_timeout, stage_init_timeout))


 async def _stage_worker_async(
@@ -861,6 +867,7 @@ async def _stage_worker_async(
     model: str,
     stage_payload: dict[str, Any],
     batch_timeout: int = 10,
+    stage_init_timeout: int = 300,
 ) -> None:
     """Stage worker entry: device setup, LLM init, batching, SHM IPC."""
     # Use local aliases to avoid conflicts with global imports in worker process
@@ -969,7 +976,6 @@ async def _stage_worker_async(

     # Acquire exclusive locks for all devices using fcntl.flock
     # Locks are automatically released when process dies
-    max_wait_time = 300  # 5 minutes max wait
     wait_start = _time.time()
     acquired_lock_fds = []  # Store file descriptors to keep locks alive

@@ -997,10 +1003,12 @@ async def _stage_worker_async(
                 _os.close(lock_fd)

                 # Check if we've been waiting too long
-                if _time.time() - wait_start > max_wait_time:
+                if _time.time() - wait_start > stage_init_timeout:
                     logger.warning(
-                        "Timeout waiting for device %s initialization lock, proceeding anyway",
+                        "Timeout waiting for device %s initialization lock, "
+                        "proceeding anyway with timeout %s",
                         device_id,
+                        stage_init_timeout,
                     )
                     break
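Taken together, the hunks above thread the timeout end to end. A condensed sketch of the flow, using only names that appear in this diff (bodies elided):

    # Omni._initialize_stages reads the kwarg and builds each stage with it:
    stage_init_timeout = kwargs.get("stage_init_timeout", 300)
    stage = OmniStage(cfg, stage_init_timeout=stage_init_timeout)

    # OmniStage.init_stage_worker forwards it into the spawned worker, e.g.:
    #   mp.Process(target=_stage_worker, args=(..., batch_timeout, self._stage_init_timeout))

    # _stage_worker then bounds the device-lock wait with it, replacing the
    # old hard-coded max_wait_time = 300, and warns before proceeding anyway.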