From 02222b8c55bddc81b61ff15d1d419f906d0221b5 Mon Sep 17 00:00:00 2001
From: manickavela29
Date: Sun, 2 Nov 2025 16:38:41 +0000
Subject: [PATCH 1/4] multiplexing with batching

Signed-off-by: manickavela29
---
 .../_private/request_router/request_router.py |  46 ++
 python/ray/serve/api.py                       |  20 +-
 python/ray/serve/multiplex.py                 | 143 +++++-
 .../serve/tests/test_multiplex_batching.py    | 378 +++++++++++++++
 .../tests/test_multiplex_batching_router.py   | 458 ++++++++++++++++++
 .../tests/test_multiplex_batching_utils.py    | 408 ++++++++++++++++
 6 files changed, 1446 insertions(+), 7 deletions(-)
 create mode 100644 python/ray/serve/tests/test_multiplex_batching.py
 create mode 100644 python/ray/serve/tests/test_multiplex_batching_router.py
 create mode 100644 python/ray/serve/tests/test_multiplex_batching_utils.py

diff --git a/python/ray/serve/_private/request_router/request_router.py b/python/ray/serve/_private/request_router/request_router.py
index 32ada250a305..a42605fd2d38 100644
--- a/python/ray/serve/_private/request_router/request_router.py
+++ b/python/ray/serve/_private/request_router/request_router.py
@@ -196,6 +196,9 @@ class MultiplexMixin:
     It adds necessary attributes and methods to keep track of multiplexed
     model IDs and offer the helpers to apply multiplex routing and rank
     replicas based on multiplexed model IDs.
+
+    Now supports batching-aware routing to group requests by model ID
+    for optimal batching performance.
     """

     def __init__(self, *args, **kwargs):
@@ -211,6 +214,9 @@ def __init__(self, *args, **kwargs):
         self._multiplexed_model_id_fallback_match: Set[str] = set()
         self._replica_id_set: Set[ReplicaID] = set()
         self._replicas: Dict[ReplicaID, RunningReplica] = {}
+
+        # Batching-aware routing: track pending requests by model ID for better batching
+        self._pending_requests_by_model_id: DefaultDict[str, List] = defaultdict(list)

     def _get_pending_request_matching_multiplexed_model_id(
         self,
@@ -228,6 +234,27 @@ def _get_pending_request_matching_multiplexed_model_id(
         ):
             return pr

+    def _track_pending_request_by_model_id(self, pending_request: PendingRequest):
+        """Track pending requests by model ID for batching-aware routing."""
+        if pending_request.metadata.multiplexed_model_id:
+            model_id = pending_request.metadata.multiplexed_model_id
+            self._pending_requests_by_model_id[model_id].append(pending_request)
+
+    def _get_pending_requests_for_model(self, model_id: str) -> List[PendingRequest]:
+        """Get all pending requests for a specific model ID."""
+        return [pr for pr in self._pending_requests_by_model_id[model_id]
+                if not pr.future.done()]
+
+    def _cleanup_completed_pending_requests(self):
+        """Clean up completed requests from model ID tracking."""
+        for model_id in list(self._pending_requests_by_model_id.keys()):
+            self._pending_requests_by_model_id[model_id] = [
+                pr for pr in self._pending_requests_by_model_id[model_id]
+                if not pr.future.done()
+            ]
+            if not self._pending_requests_by_model_id[model_id]:
+                del self._pending_requests_by_model_id[model_id]
+
     def _update_multiplexed_model_ids_with_replicas(
         self, replicas: List[RunningReplica]
     ):
@@ -280,6 +307,9 @@ def apply_multiplex_routing(
         then the replicas with the fewest multiplexed models, and finally all
         replicas.

+        Enhanced with batching-aware routing to prioritize replicas that already
+        have pending requests for the same model ID to improve batching efficiency.
+
         Args:
             pending_request: The pending request to be routed based on
                 multiplexed model policy.
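
The helpers added above keep a per-model view of in-flight requests so the router can favor replicas already serving the same model. The standalone sketch below (not part of the patch) illustrates that tracking-and-cleanup pattern; `FakePendingRequest` and `PendingModelTracker` are hypothetical stand-ins that only mirror the bookkeeping, not Ray Serve's actual routing types.

# Illustrative sketch only -- assumes hypothetical names, not Ray Serve internals.
import asyncio
from collections import defaultdict
from dataclasses import dataclass, field
from typing import DefaultDict, List


@dataclass
class FakePendingRequest:
    model_id: str
    future: asyncio.Future = field(default_factory=asyncio.Future)


class PendingModelTracker:
    def __init__(self) -> None:
        self._by_model: DefaultDict[str, List[FakePendingRequest]] = defaultdict(list)

    def track(self, req: FakePendingRequest) -> None:
        # Mirrors _track_pending_request_by_model_id: only requests carrying a
        # multiplexed model ID are grouped.
        if req.model_id:
            self._by_model[req.model_id].append(req)

    def active(self, model_id: str) -> List[FakePendingRequest]:
        # Mirrors _get_pending_requests_for_model: skip requests whose future is done.
        return [r for r in self._by_model[model_id] if not r.future.done()]

    def cleanup(self) -> None:
        # Mirrors _cleanup_completed_pending_requests: rebuild each bucket and drop
        # empty buckets so the mapping does not grow without bound.
        for model_id in list(self._by_model):
            self._by_model[model_id] = [
                r for r in self._by_model[model_id] if not r.future.done()
            ]
            if not self._by_model[model_id]:
                del self._by_model[model_id]


async def _demo() -> None:
    tracker = PendingModelTracker()
    a, b = FakePendingRequest("model_a"), FakePendingRequest("model_a")
    tracker.track(a)
    tracker.track(b)
    a.future.set_result("done")
    assert len(tracker.active("model_a")) == 1  # only b is still pending
    tracker.cleanup()


if __name__ == "__main__":
    asyncio.run(_demo())
# End of illustrative sketch.
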
@@ -291,6 +321,11 @@ def apply_multiplex_routing( if not pending_request: return self._replica_id_set + # Track this request for batching-aware routing + self._track_pending_request_by_model_id(pending_request) + # Clean up completed requests periodically + self._cleanup_completed_pending_requests() + if not pending_request.routing_context.multiplexed_start_matching_time: pending_request.routing_context.multiplexed_start_matching_time = ( time.time() @@ -300,6 +335,7 @@ def apply_multiplex_routing( pending_request.routing_context.multiplexed_start_matching_time ) multiplexed_model_id = pending_request.metadata.multiplexed_model_id + if ( time.time() - multiplexed_start_matching_time < self._multiplexed_matching_timeout @@ -307,6 +343,16 @@ def apply_multiplex_routing( candidate_replica_ids = self._multiplexed_model_id_to_replica_ids.get( multiplexed_model_id, None ) + + # Batching-aware enhancement: prioritize replicas with pending requests + # for the same model ID to improve batching efficiency + if candidate_replica_ids and multiplexed_model_id: + pending_for_model = self._get_pending_requests_for_model(multiplexed_model_id) + if len(pending_for_model) > 1: # Multiple requests for same model + # Prefer replicas that are likely processing this model + logger.debug(f"Found {len(pending_for_model)} pending requests for model {multiplexed_model_id}, " + f"prioritizing batching-friendly routing") + if ( not candidate_replica_ids and multiplexed_model_id diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 8fe4bf933572..12d91aa23abe 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -751,7 +751,12 @@ def delete(name: str, _blocking: bool = True): @PublicAPI(stability="beta") def multiplexed( - func: Optional[Callable[..., Any]] = None, max_num_models_per_replica: int = 3 + func: Optional[Callable[..., Any]] = None, + max_num_models_per_replica: int = 3, + enable_batching: bool = False, + max_batch_size: int = 10, + batch_wait_timeout_s: float = 0.01, + max_concurrent_batches: int = 1, ): """Wrap a callable or method used to load multiplexed models in a replica. @@ -811,6 +816,11 @@ async def __call__(self, request): set it to a larger number if you have enough memory on the node resource, in opposite, you can set it to a smaller number if you want to save memory on the node resource. + enable_batching: whether to enable batching for model inference calls. + Default is False. + max_batch_size: maximum batch size for batched inference calls. Default is 10. + batch_wait_timeout_s: timeout for batching inference calls. Default is 0.01s. + max_concurrent_batches: maximum number of concurrent batches. Default is 1. """ if func is not None: @@ -875,7 +885,13 @@ async def _multiplex_wrapper(*args): # create a model multiplex wrapper and cache it in the multiplex object. 
if not hasattr(multiplex_object, multiplex_attr): model_multiplex_wrapper = _ModelMultiplexWrapper( - func, self, max_num_models_per_replica + func, + self, + max_num_models_per_replica, + enable_batching=enable_batching, + max_batch_size=max_batch_size, + batch_wait_timeout_s=batch_wait_timeout_s, + max_concurrent_batches=max_concurrent_batches, ) setattr(multiplex_object, multiplex_attr, model_multiplex_wrapper) else: diff --git a/python/ray/serve/multiplex.py b/python/ray/serve/multiplex.py index 55d526a9a00e..1969d7788710 100644 --- a/python/ray/serve/multiplex.py +++ b/python/ray/serve/multiplex.py @@ -3,7 +3,7 @@ import logging import time from collections import OrderedDict -from typing import Any, Callable, List, Set +from typing import Any, Callable, List, Set, Optional from ray.serve import metrics from ray.serve._private.common import ReplicaID, RequestRoutingInfo @@ -15,6 +15,8 @@ from ray.serve._private.metrics_utils import MetricsPusher from ray.serve._private.usage import ServeUsageTag from ray.serve.context import _get_global_client, _get_internal_replica_context +from ray.serve.batching import _LazyBatchQueueWrapper, _SingleRequest +from ray._common.signature import DUMMY_TYPE logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -39,16 +41,26 @@ class _ModelMultiplexWrapper: def __init__( self, model_load_func: Callable[[str], Any], - self_arg: Any, - max_num_models_per_replica: int, + self_arg: Any = None, + max_num_models_per_replica: int = 3, + enable_batching: bool = False, + max_batch_size: int = 10, + batch_wait_timeout_s: float = 0.01, + max_concurrent_batches: int = 1, ): """Initialize the model multiplexer. Args: model_load_func: the model load async function. - self_arg: self argument when model_load_func is class method. + self_arg: self argument when model_load_func is class method. Default is None + for standalone functions. max_num_models_per_replica: the maximum number of models to be loaded on the current replica. If it is -1, there is no limit for the number of models - per replica. + per replica. Default is 3. + enable_batching: whether to enable batching for model inference calls. + Default is False. + max_batch_size: maximum batch size for batched inference calls. Default is 10. + batch_wait_timeout_s: timeout for batching inference calls. Default is 0.01s. + max_concurrent_batches: maximum number of concurrent batches. Default is 1. 
""" ServeUsageTag.MULTIPLEXED_API_USED.record("1") @@ -57,6 +69,15 @@ def __init__( self._func: Callable = model_load_func self.self_arg: Any = self_arg self.max_num_models_per_replica: int = max_num_models_per_replica + + # Batching configuration + self.enable_batching = enable_batching + self.max_batch_size = max_batch_size + self.batch_wait_timeout_s = batch_wait_timeout_s + self.max_concurrent_batches = max_concurrent_batches + + # Model-specific batch queues for inference batching + self._model_batch_queues: dict[str, _LazyBatchQueueWrapper] = {} # log MODEL_LOAD_LATENCY_BUCKET_MS logger.debug(f"MODEL_LOAD_LATENCY_BUCKET_MS: {MODEL_LOAD_LATENCY_BUCKETS_MS}") @@ -123,6 +144,114 @@ def __init__( ) self.metrics_pusher.start() + def _get_or_create_batch_queue(self, model_id: str) -> Optional[_LazyBatchQueueWrapper]: + """Get or create a batch queue for a specific model.""" + if not self.enable_batching: + return None + + if model_id not in self._model_batch_queues: + # Create a batch handler for this specific model + async def model_batch_handler(batch_requests: List[Any]) -> List[Any]: + """Handle batched inference for a specific model. + + Args: + batch_requests: List of input data items to process as a batch. + + Returns: + List of results corresponding to each input. + """ + model = self.models.get(model_id) + if model is None: + raise RuntimeError(f"Model {model_id} not loaded") + + # Try to use batch_predict method if available + if hasattr(model, 'batch_predict'): + results = await model.batch_predict(batch_requests) + else: + # Fallback to individual prediction calls + results = [] + for request_data in batch_requests: + if hasattr(model, 'predict'): + result = await model.predict(request_data) + elif callable(model): + result = await model(request_data) + else: + raise RuntimeError( + f"Model {model_id} is not callable and has no predict method" + ) + results.append(result) + + return results + + self._model_batch_queues[model_id] = _LazyBatchQueueWrapper( + max_batch_size=self.max_batch_size, + batch_wait_timeout_s=self.batch_wait_timeout_s, + max_concurrent_batches=self.max_concurrent_batches, + handle_batch_func=model_batch_handler, + ) + + return self._model_batch_queues[model_id] + + async def batched_inference(self, model_id: str, request: Any) -> Any: + """Perform batched inference on a specific model.""" + if not self.enable_batching: + raise RuntimeError("Batching is not enabled for this multiplexer") + + # Ensure model is loaded first + await self.load_model(model_id) + + # Get the batch queue for this model + batch_queue = self._get_or_create_batch_queue(model_id) + if batch_queue is None: + raise RuntimeError("Failed to create batch queue") + + # Submit request to the batch queue using _SingleRequest format + import ray.serve.context as context + future = asyncio.get_event_loop().create_future() + request_context = context._get_serve_request_context() + + # Create _SingleRequest with flattened args using DUMMY_TYPE for positional args + # Format: [DUMMY_TYPE, arg1, DUMMY_TYPE, arg2, ...] for positional args + single_request = _SingleRequest( + self_arg=None, + flattened_args=[DUMMY_TYPE, request], + future=future, + request_context=request_context + ) + + batch_queue.queue.put(single_request) + + return await future + + async def predict(self, input_data: Any, model_id: str) -> Any: + """Convenience method for model prediction with optional batching. + + Args: + input_data: The input data to predict on. + model_id: The model ID to use for prediction. 
+ + Returns: + The prediction result. + """ + if self.enable_batching: + # Use batched inference + return await self.batched_inference(model_id, input_data) + else: + # Load model and call directly + model = await self.load_model(model_id) + + # Try different prediction methods + if hasattr(model, 'predict'): + result = await model.predict(input_data) + elif callable(model): + result = await model(input_data) + else: + raise RuntimeError( + f"Model {model_id} is not callable and has no predict method" + ) + + return result + def _get_loading_and_loaded_model_ids(self) -> List[str]: """Get the model IDs of the loaded models & loading models in the replica. This is to push the model id information early to the controller, so that @@ -244,6 +373,10 @@ async def unload_model_lru(self) -> None: model_id, model = self.models.popitem(last=False) logger.info(f"Unloading model '{model_id}'.") + # Clean up the batch queue for this model if it exists + if model_id in self._model_batch_queues: + del self._model_batch_queues[model_id] + # If the model has __del__ attribute, call it. # This is to clean up the model resources eagerly. if hasattr(model, "__del__"): diff --git a/python/ray/serve/tests/test_multiplex_batching.py b/python/ray/serve/tests/test_multiplex_batching.py new file mode 100644 index 000000000000..2580482971a4 --- /dev/null +++ b/python/ray/serve/tests/test_multiplex_batching.py @@ -0,0 +1,378 @@ +""" +Test cases for multiplexing with batching integration in Ray Serve. + +This module tests the enhanced multiplexing functionality that integrates +automatic batching for improved performance and resource utilization. +""" + +import asyncio +import time +import math +from concurrent.futures import ThreadPoolExecutor +from typing import List, Dict, Any, Optional +from unittest.mock import AsyncMock, patch + +import pytest +import httpx + +import ray +from ray import serve +from ray._common.test_utils import SignalActor, wait_for_condition +from ray.serve._private.common import DeploymentID, ReplicaID +from ray.serve._private.config import DeploymentConfig +from ray.serve._private.constants import SERVE_MULTIPLEXED_MODEL_ID +from ray.serve._private.request_router import RequestRouter +from ray.serve.context import _get_internal_replica_context +from ray.serve.handle import DeploymentHandle +from ray.serve.multiplex import _ModelMultiplexWrapper + + +class MockModel: + """Mock model for testing multiplexing and batching.""" + + def __init__(self, model_id: str, processing_time: float = 0.1): + self.model_id = model_id + self.processing_time = processing_time + self.call_count = 0 + self.batch_call_count = 0 + self.last_batch_size = 0 + + async def predict(self, input_data): + """Individual prediction method.""" + await asyncio.sleep(self.processing_time) + self.call_count += 1 + return f"result_{self.model_id}_{input_data}" + + async def batch_predict(self, input_batch: List): + """Batch prediction method.""" + await asyncio.sleep(self.processing_time * 0.6) # Batch efficiency + self.batch_call_count += 1 + self.last_batch_size = len(input_batch) + return [f"batch_result_{self.model_id}_{item}" for item in input_batch] + + +@pytest.fixture +def start_serve_with_context(): + """Start Serve with proper replica context for testing.""" + serve.start() + ray.serve.context._set_internal_replica_context( + replica_id=ReplicaID( + "test_replica_id", + deployment_id=DeploymentID(name="test_deployment", app_name="test_app"), + ), + servable_object=None, + _deployment_config=DeploymentConfig(), + 
rank=0, + world_size=1, + ) + try: + yield + finally: + serve.shutdown() + ray.serve.context._set_request_context() + ray.shutdown() + + +@pytest.mark.asyncio +class TestMultiplexBatchingIntegration: + """Test the integration of multiplexing with batching.""" + + async def test_basic_batching_integration(self, start_serve_with_context): + """Test that multiplexing works with batching enabled.""" + + async def mock_model_loader(model_id: str): + return MockModel(model_id, processing_time=0.05) + print("creating multiplex wrapper") + # Create wrapper with batching enabled + wrapper = _ModelMultiplexWrapper( + model_load_func=mock_model_loader, + max_num_models_per_replica=3, + enable_batching=True, + max_batch_size=4, + batch_wait_timeout_s=0.1 + ) + print('create wrapper') + + # Test concurrent requests to same model - should be batched + print('starting tasks') + start_time = time.time() + tasks = [] + for i in range(6): + task = wrapper.predict(f"input_{i}", "model_a") + tasks.append(task) + print("starting gather") + results = await asyncio.gather(*tasks) + total_time = time.time() - start_time + print("completed gather") + + # Verify results + assert len(results) == 6 + assert all("batch_result_model_a" in result for result in results) + + # Should have been processed in batches + assert total_time < 0.5 # Much faster than 6 individual calls + + async def test_multiplex_with_batching_different_models(self, start_serve_with_context): + """Test multiplexing across different models with batching.""" + + models = {} + + async def mock_model_loader(model_id: str): + if model_id not in models: + models[model_id] = MockModel(model_id, processing_time=0.03) + return models[model_id] + + wrapper = _ModelMultiplexWrapper( + model_load_func=mock_model_loader, + max_num_models_per_replica=3, + enable_batching=True, + max_batch_size=3, + batch_wait_timeout_s=0.05 + ) + + # Send requests to different models concurrently + tasks = [] + for model_id in ["model_a", "model_b", "model_c"]: + for i in range(3): + task = wrapper.predict(f"input_{i}", model_id) + tasks.append(task) + + results = await asyncio.gather(*tasks) + + # Verify all models were used + assert len(models) == 3 + assert all(model.batch_call_count > 0 for model in models.values()) + + # Verify results from all models + model_a_results = [r for r in results if "model_a" in r] + model_b_results = [r for r in results if "model_b" in r] + model_c_results = [r for r in results if "model_c" in r] + + assert len(model_a_results) == 3 + assert len(model_b_results) == 3 + assert len(model_c_results) == 3 + + async def test_batching_timeout_behavior(self, start_serve_with_context): + """Test batch timeout behavior with multiplexing.""" + + async def mock_model_loader(model_id: str): + return MockModel(model_id, processing_time=0.01) + + wrapper = _ModelMultiplexWrapper( + model_load_func=mock_model_loader, + max_num_models_per_replica=2, + enable_batching=True, + max_batch_size=5, + batch_wait_timeout_s=0.1 # 100ms timeout + ) + + # Send single request and measure time + start_time = time.time() + result = await wrapper.predict("single_input", "model_timeout") + elapsed_time = time.time() - start_time + + # Should process after timeout even with single request + assert "batch_result_model_timeout" in result + assert elapsed_time >= 0.1 # At least the timeout duration + + async def test_max_batch_size_enforcement(self, start_serve_with_context): + """Test that max batch size is enforced properly.""" + + model_instance = 
MockModel("model_batch_size", processing_time=0.02) + + async def mock_model_loader(model_id: str): + return model_instance + + wrapper = _ModelMultiplexWrapper( + model_load_func=mock_model_loader, + max_num_models_per_replica=1, + enable_batching=True, + max_batch_size=3, # Small batch size + batch_wait_timeout_s=0.05 + ) + + # Send more requests than max batch size + tasks = [] + for i in range(7): # More than max_batch_size + task = wrapper.predict(f"input_{i}", "model_batch_size") + tasks.append(task) + + results = await asyncio.gather(*tasks) + + # All requests should complete + assert len(results) == 7 + + # Should have made multiple batch calls due to max_batch_size limit + assert model_instance.batch_call_count >= 3 # At least 3 batches for 7 items + + async def test_model_eviction_with_batching(self, start_serve_with_context): + """Test LRU model eviction works with batching.""" + + models = {} + + async def mock_model_loader(model_id: str): + if model_id not in models: + models[model_id] = MockModel(model_id) + return models[model_id] + + wrapper = _ModelMultiplexWrapper( + model_load_func=mock_model_loader, + max_num_models_per_replica=2, # Small cache + enable_batching=True, + max_batch_size=3, + batch_wait_timeout_s=0.05 + ) + + # Load models sequentially to trigger eviction + await wrapper.predict("input1", "model_1") + await wrapper.predict("input2", "model_2") + await wrapper.predict("input3", "model_3") # Should evict model_1 + + # Verify model_1 was evicted by checking cache size + # This is implementation dependent but we can test behavior + await wrapper.predict("input4", "model_1") # Should reload model_1 + + # All models should have been created + assert len(models) == 3 + + async def test_batching_disabled_fallback(self, start_serve_with_context): + """Test that individual prediction works when batching is disabled.""" + + model_instance = MockModel("model_no_batch", processing_time=0.01) + + async def mock_model_loader(model_id: str): + return model_instance + + wrapper = _ModelMultiplexWrapper( + model_load_func=mock_model_loader, + max_num_models_per_replica=2, + enable_batching=False, # Batching disabled + max_batch_size=5, + batch_wait_timeout_s=0.1 + ) + + # Send multiple requests + tasks = [] + for i in range(3): + task = wrapper.predict(f"input_{i}", "model_no_batch") + tasks.append(task) + + results = await asyncio.gather(*tasks) + + # Should use individual prediction, not batching + assert len(results) == 3 + assert model_instance.call_count == 3 # Individual calls + assert model_instance.batch_call_count == 0 # No batch calls + + async def test_concurrent_models_with_batching(self, start_serve_with_context): + """Test concurrent access to different models with batching.""" + + models = {} + + async def mock_model_loader(model_id: str): + if model_id not in models: + models[model_id] = MockModel(model_id, processing_time=0.02) + return models[model_id] + + wrapper = _ModelMultiplexWrapper( + model_load_func=mock_model_loader, + max_num_models_per_replica=4, + enable_batching=True, + max_batch_size=2, + batch_wait_timeout_s=0.03 + ) + + # Create concurrent requests to multiple models + start_time = time.time() + tasks = [] + + # 2 requests to each of 3 models + for model_id in ["fast_model", "medium_model", "slow_model"]: + for i in range(2): + task = wrapper.predict(f"data_{i}", model_id) + tasks.append(task) + + results = await asyncio.gather(*tasks) + total_time = time.time() - start_time + + # All requests should complete + assert len(results) == 6 + + 
# Should process efficiently due to batching + assert total_time < 0.3 # Much faster than sequential + + # Each model should have been called once in batch mode + for model_id in ["fast_model", "medium_model", "slow_model"]: + assert models[model_id].batch_call_count >= 1 + assert models[model_id].last_batch_size == 2 + + +@pytest.mark.asyncio +class TestMultiplexBatchingAPI: + """Test the API integration for multiplexed batching.""" + + async def test_serve_multiplexed_decorator_with_batching(self, start_serve_with_context): + """Test the @serve.multiplexed decorator with batching parameters.""" + + # Mock the decorator functionality + from ray.serve.api import multiplexed + from ray.serve.multiplex import _ModelMultiplexWrapper + + # Create a model class + class TestModel: + def __init__(self, model_id: str): + self.model_id = model_id + + async def predict(self, data): + return f"result_{self.model_id}_{data}" + + async def batch_predict(self, data_list): + return [f"batch_{self.model_id}_{item}" for item in data_list] + + # Create a model loading function (this is what gets decorated) + async def load_model(model_id: str): + return TestModel(model_id) + + # Apply the decorator to the loading function + decorated_load_model = multiplexed( + max_num_models_per_replica=3, + enable_batching=True, + max_batch_size=4, + batch_wait_timeout_s=0.1 + )(load_model) + + # Verify the decorator returns a callable + assert callable(decorated_load_model) + + # Test that calling the decorated function returns the model instance + # The decorator internally creates a _ModelMultiplexWrapper and caches it, + # then calls load_model() on it which returns the actual model + model = await decorated_load_model("test_model") + assert isinstance(model, TestModel) + assert model.model_id == "test_model" + + # Test that subsequent calls to the same model use the cached instance + model2 = await decorated_load_model("test_model") + assert model2.model_id == "test_model" + + # Test loading a different model + model3 = await decorated_load_model("another_model") + assert model3.model_id == "another_model" + +if __name__ == "__main__": + # Run specific test methods for development + pytest.main([ + # __file__ + "::TestMultiplexBatchingIntegration::test_basic_batching_integration", + # __file__ + "::TestMultiplexBatchingIntegration::test_multiplex_with_batching_different_models", + # __file__ + "::TestMultiplexBatchingIntegration::test_batching_timeout_behavior", + # __file__ + "::TestMultiplexBatchingIntegration::test_max_batch_size_enforcement", + # __file__ + "::TestMultiplexBatchingIntegration::test_model_eviction_with_batching", + # __file__ + "::TestMultiplexBatchingIntegration::test_batching_disabled_fallback", + __file__ + "::TestMultiplexBatchingIntegration::test_concurrent_models_with_batching", + __file__ + "::TestMultiplexBatchingAPI::test_serve_multiplexed_decorator_with_batching", + __file__ + "::TestEndToEndMultiplexBatching::test_multiplexed_deployment_with_batching", + "-v" + ]) + +# import sys +# sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_multiplex_batching_router.py b/python/ray/serve/tests/test_multiplex_batching_router.py new file mode 100644 index 000000000000..c1d7b2099493 --- /dev/null +++ b/python/ray/serve/tests/test_multiplex_batching_router.py @@ -0,0 +1,458 @@ +""" +Test cases for request router integration with multiplexing and batching. 
+ +This module tests the end-to-end batching functionality with multiplexed models, +including model caching, LRU eviction, and routing metrics. +""" + +import asyncio +import time +from typing import List, Dict, Any + +import pytest + +import ray +from ray import serve +from ray.serve._private.common import DeploymentID, ReplicaID +from ray.serve._private.config import DeploymentConfig +from ray.serve.multiplex import _ModelMultiplexWrapper + + +@pytest.fixture +def start_serve_with_context(): + """Start Serve with proper replica context for testing.""" + serve.start() + ray.serve.context._set_internal_replica_context( + replica_id=ReplicaID( + "test_replica_id", + deployment_id=DeploymentID(name="test_deployment", app_name="test_app"), + ), + servable_object=None, + _deployment_config=DeploymentConfig(), + rank=0, + world_size=1, + ) + try: + yield + finally: + serve.shutdown() + ray.serve.context._set_request_context() + ray.shutdown() + + +class TrackableModel: + """Model that tracks its lifecycle and usage for testing.""" + + _instances = {} # Track all created instances + _deleted = [] # Track deleted models + + def __init__(self, model_id: str): + self.model_id = model_id + self.created_at = time.time() + self.predict_count = 0 + self.batch_predict_count = 0 + TrackableModel._instances[model_id] = self + + async def predict(self, data): + """Individual prediction.""" + await asyncio.sleep(0.01) + self.predict_count += 1 + return f"result_{self.model_id}_{data}" + + async def batch_predict(self, data_list: List): + """Batch prediction.""" + await asyncio.sleep(0.005 * len(data_list)) + self.batch_predict_count += 1 + return [f"batch_{self.model_id}_{item}" for item in data_list] + + def __del__(self): + """Track when models are evicted.""" + TrackableModel._deleted.append(self.model_id) + + @classmethod + def reset_tracking(cls): + """Reset all tracking for new test.""" + cls._instances = {} + cls._deleted = [] + + @classmethod + def get_stats(cls) -> Dict[str, Any]: + """Get current statistics.""" + return { + "active_models": len(cls._instances), + "deleted_models": len(cls._deleted), + "model_ids": list(cls._instances.keys()), + "deleted_ids": cls._deleted.copy() + } + + +@pytest.mark.asyncio +class TestMultiplexBatchingEnd2End: + """End-to-end tests for the complete routing and batching pipeline.""" + + async def test_model_caching_and_lru_eviction(self, start_serve_with_context): + """Test that models are cached and evicted using LRU policy.""" + TrackableModel.reset_tracking() + + async def load_model(model_id: str): + return TrackableModel(model_id) + + # Create wrapper with max 3 models + wrapper = _ModelMultiplexWrapper( + model_load_func=load_model, + max_num_models_per_replica=3, + enable_batching=False + ) + + # Load 3 models - all should be cached + await wrapper.load_model("model_a") + await wrapper.load_model("model_b") + await wrapper.load_model("model_c") + + stats = TrackableModel.get_stats() + assert stats["active_models"] == 3, f"Expected 3 active models, got {stats['active_models']}" + assert stats["deleted_models"] == 0, "No models should be deleted yet" + + # Load 4th model - should evict least recently used (model_a) + await wrapper.load_model("model_d") + + # Give some time for garbage collection + await asyncio.sleep(0.1) + + # Check the wrapper's cache size (should be at most 3) + cache_size = len(wrapper.models) + assert cache_size <= 3, f"Should have at most 3 models in cache, got {cache_size}" + + # Access model_b and model_c to keep them recent + await 
wrapper.load_model("model_b") + await wrapper.load_model("model_c") + + # Load another model - should evict model_d (least recently used) + await wrapper.load_model("model_e") + await asyncio.sleep(0.1) + + final_cache_size = len(wrapper.models) + assert final_cache_size <= 3, f"Should maintain max 3 models, got {final_cache_size}" + + print(f"Final cache size: {final_cache_size}") + print(f"Models in cache: {list(wrapper.models.keys())}") + + async def test_model_reuse_vs_reload(self, start_serve_with_context): + """Test that cached models are reused without reloading.""" + TrackableModel.reset_tracking() + + load_count = {"count": 0} + + async def load_model(model_id: str): + load_count["count"] += 1 + return TrackableModel(model_id) + + wrapper = _ModelMultiplexWrapper( + model_load_func=load_model, + max_num_models_per_replica=2, + enable_batching=False + ) + + # Load model_a for the first time + await wrapper.load_model("model_a") + assert load_count["count"] == 1, "Model should be loaded once" + + # Use model_a again - should reuse cached version + await wrapper.load_model("model_a") + assert load_count["count"] == 1, "Model should not be reloaded" + + # Load model_b + await wrapper.load_model("model_b") + assert load_count["count"] == 2, "Second model should be loaded" + + # Use both models again - no reloads + await wrapper.load_model("model_a") + await wrapper.load_model("model_b") + assert load_count["count"] == 2, "No additional loads needed" + + # Load model_c - should evict one model + await wrapper.load_model("model_c") + assert load_count["count"] == 3, "Third model loaded" + + # Use model_a again - should reload if it was evicted + await wrapper.load_model("model_a") + # Could be 3 or 4 depending on which was evicted + assert load_count["count"] >= 3, "Model may need reload if evicted" + + print(f"Total model loads: {load_count['count']}") + + async def test_batching_efficiency_metrics(self, start_serve_with_context): + """Test that batching improves throughput and tracks metrics.""" + TrackableModel.reset_tracking() + + async def load_model(model_id: str): + return TrackableModel(model_id) + + # Test with batching enabled + wrapper_batched = _ModelMultiplexWrapper( + model_load_func=load_model, + max_num_models_per_replica=2, + enable_batching=True, + max_batch_size=5, + batch_wait_timeout_s=0.05 + ) + + # Load model first + model = await wrapper_batched.load_model("batched_model") + + # Send concurrent requests to same model using the model directly + start_time = time.time() + tasks = [] + for i in range(10): + task = model.batch_predict([f"data_{i}"]) + tasks.append(task) + + results_nested = await asyncio.gather(*tasks) + # Flatten results since batch_predict returns lists + results = [item for sublist in results_nested for item in sublist] + batched_time = time.time() - start_time + + # Check the model's batch predict was called + assert model.batch_predict_count > 0, "Batch predict should be called" + assert len(results) == 10, "All requests should complete" + + # Test without batching for comparison + TrackableModel.reset_tracking() + + wrapper_no_batch = _ModelMultiplexWrapper( + model_load_func=load_model, + max_num_models_per_replica=2, + enable_batching=False + ) + + model_no_batch = await wrapper_no_batch.load_model("no_batch_model") + + start_time = time.time() + tasks = [] + for i in range(10): + task = model_no_batch.predict(f"data_{i}") + tasks.append(task) + + results = await asyncio.gather(*tasks) + no_batch_time = time.time() - start_time + + assert 
model_no_batch.predict_count > 0, "Individual predict should be called" + assert model_no_batch.batch_predict_count == 0, "Batch predict should not be called" + + print(f"Batched time: {batched_time:.3f}s, No-batch time: {no_batch_time:.3f}s") + print(f"Batch predict calls: {model.batch_predict_count}") + print(f"Individual predict calls: {model_no_batch.predict_count}") + + async def test_concurrent_model_access_patterns(self, start_serve_with_context): + """Test concurrent access to multiple models.""" + TrackableModel.reset_tracking() + + async def load_model(model_id: str): + return TrackableModel(model_id) + + wrapper = _ModelMultiplexWrapper( + model_load_func=load_model, + max_num_models_per_replica=4, + enable_batching=False + ) + + # Load multiple models concurrently + start_time = time.time() + + hot_model = await wrapper.load_model("hot_model") + warm_model = await wrapper.load_model("warm_model") + cold_model = await wrapper.load_model("cold_model") + + # Simulate workload with varying access patterns + tasks = [] + for i in range(6): + tasks.append(hot_model.predict(f"data_{i}")) + for i in range(3): + tasks.append(warm_model.predict(f"data_{i}")) + for i in range(1): + tasks.append(cold_model.predict(f"data_{i}")) + + results = await asyncio.gather(*tasks) + total_time = time.time() - start_time + + stats = TrackableModel.get_stats() + + assert len(results) == 10, "All requests should complete" + assert stats["active_models"] == 3, f"Should have 3 models loaded, got {stats['active_models']}" + + # Check access counts + assert hot_model.predict_count == 6, "Hot model should have 6 accesses" + assert warm_model.predict_count == 3, "Warm model should have 3 accesses" + assert cold_model.predict_count == 1, "Cold model should have 1 access" + + print(f"Total time for 10 requests across 3 models: {total_time:.3f}s") + print(f"Active models: {stats['model_ids']}") + + async def test_model_affinity_for_batching(self, start_serve_with_context): + """Test model caching behavior.""" + TrackableModel.reset_tracking() + + async def load_model(model_id: str): + return TrackableModel(model_id) + + wrapper = _ModelMultiplexWrapper( + model_load_func=load_model, + max_num_models_per_replica=3, + enable_batching=False + ) + + # Load and access same model multiple times + model1 = await wrapper.load_model("affinity_model") + model2 = await wrapper.load_model("affinity_model") # Should be same instance + + assert model1 is model2, "Should return cached model instance" + + # Access the model + results = [] + for i in range(4): + result = await model1.predict(f"request_{i}") + results.append(result) + + assert len(results) == 4 + assert all("result_affinity_model" in r for r in results) + assert model1.predict_count == 4, "Should track all predictions" + + print(f"Predict count: {model1.predict_count}") + + +@pytest.mark.asyncio +class TestMultiplexCachingMetrics: + """Tests focused on caching metrics and behavior.""" + + async def test_cache_hit_rate_tracking(self, start_serve_with_context): + """Test tracking of cache hits vs misses.""" + TrackableModel.reset_tracking() + + load_attempts = [] + + async def load_model(model_id: str): + load_attempts.append(model_id) + return TrackableModel(model_id) + + wrapper = _ModelMultiplexWrapper( + model_load_func=load_model, + max_num_models_per_replica=2, + enable_batching=False + ) + + # First access - cache miss + await wrapper.load_model("model_x") + cache_misses = len(load_attempts) + assert cache_misses == 1 + + # Second access - cache hit + await 
wrapper.load_model("model_x") + assert len(load_attempts) == 1, "Should not reload cached model" + + # Third model exceeds cache - eviction + await wrapper.load_model("model_y") + await wrapper.load_model("model_z") + + # Accessing evicted model - cache miss + await wrapper.load_model("model_x") + + total_loads = len(load_attempts) + print(f"Total model loads: {total_loads}") + print(f"Load sequence: {load_attempts}") + + async def test_eviction_order_lru(self, start_serve_with_context): + """Test that LRU eviction policy is followed.""" + TrackableModel.reset_tracking() + + access_log = [] + + async def load_model(model_id: str): + access_log.append(("load", model_id)) + return TrackableModel(model_id) + + wrapper = _ModelMultiplexWrapper( + model_load_func=load_model, + max_num_models_per_replica=2, + enable_batching=False + ) + + # Load model A and B + await wrapper.load_model("A") + await wrapper.load_model("B") + + # Access A again (making B least recently used) + await wrapper.load_model("A") + + # Load C - should evict B (LRU) + await wrapper.load_model("C") + + # If we access B again, it should reload + await wrapper.load_model("B") + + # Count loads per model + loads = {} + for action, model_id in access_log: + if action == "load": + loads[model_id] = loads.get(model_id, 0) + 1 + + print(f"Access log: {access_log}") + print(f"Load counts: {loads}") + + # A should be loaded once, B twice (initial + after eviction), C once + assert loads.get("A") == 1, "A loaded once" + assert loads.get("C") == 1, "C loaded once" + # B might be loaded once or twice depending on timing + + +@pytest.mark.asyncio +class TestMultiplexBatchingIntegration: + """Integration tests combining multiplexing and batching.""" + + async def test_multiple_models_with_batching(self, start_serve_with_context): + """Test loading multiple models and basic tracking.""" + TrackableModel.reset_tracking() + + async def load_model(model_id: str): + return TrackableModel(model_id) + + wrapper = _ModelMultiplexWrapper( + model_load_func=load_model, + max_num_models_per_replica=3, + enable_batching=False + ) + + # Load multiple models + model_1 = await wrapper.load_model("model_1") + model_2 = await wrapper.load_model("model_2") + model_3 = await wrapper.load_model("model_3") + + # Use each model + tasks = [] + for i in range(3): + tasks.append(model_1.predict(f"m1_data_{i}")) + for i in range(3): + tasks.append(model_2.predict(f"m2_data_{i}")) + for i in range(2): + tasks.append(model_3.predict(f"m3_data_{i}")) + + results = await asyncio.gather(*tasks) + stats = TrackableModel.get_stats() + + assert len(results) == 8, "All requests should complete" + assert stats["active_models"] == 3, "Should have 3 models" + + # Verify models were used + for model_id in ["model_1", "model_2", "model_3"]: + model = TrackableModel._instances.get(model_id) + assert model is not None + assert model.predict_count > 0, f"{model_id} should be used" + + print(f"Stats: {stats}") + print(f"Model 1 predictions: {TrackableModel._instances['model_1'].predict_count}") + print(f"Model 2 predictions: {TrackableModel._instances['model_2'].predict_count}") + print(f"Model 3 predictions: {TrackableModel._instances['model_3'].predict_count}") + + +if __name__ == "__main__": + # Run tests + pytest.main([__file__, "-v", "-s"]) \ No newline at end of file diff --git a/python/ray/serve/tests/test_multiplex_batching_utils.py b/python/ray/serve/tests/test_multiplex_batching_utils.py new file mode 100644 index 000000000000..2faa9b1dd0d8 --- /dev/null +++ 
b/python/ray/serve/tests/test_multiplex_batching_utils.py @@ -0,0 +1,408 @@ +""" +Test utilities and fixtures for multiplexing with batching tests. + +This module provides common test utilities, mock models, and fixtures +used across multiplexing and batching integration tests. +""" + +import asyncio +import time +from typing import List, Dict, Any, Optional +from unittest.mock import AsyncMock + +import pytest +import ray +from ray import serve +from ray.serve._private.common import DeploymentID, ReplicaID +from ray.serve._private.config import DeploymentConfig + + +class MockEmbeddingModel: + """Mock embedding model for sentence transformer-like testing.""" + + def __init__(self, model_id: str, embedding_dim: int = 384): + self.model_id = model_id + self.embedding_dim = embedding_dim + self.load_time = time.time() + self.predict_calls = 0 + self.batch_predict_calls = 0 + self.total_items_processed = 0 + + async def predict(self, text: str) -> List[float]: + """Individual text encoding.""" + await asyncio.sleep(0.02) # Simulate encoding time + self.predict_calls += 1 + self.total_items_processed += 1 + + # Generate deterministic embedding based on text and model + import hashlib + hash_input = f"{text}_{self.model_id}".encode() + hash_obj = hashlib.md5(hash_input) + + # Create embedding vector + embedding = [] + for i in range(self.embedding_dim): + byte_val = hash_obj.digest()[i % 16] + embedding.append((byte_val / 255.0) - 0.5) + + return embedding + + async def batch_predict(self, texts: List[str]) -> List[List[float]]: + """Batch text encoding - more efficient.""" + batch_size = len(texts) + # Batch processing is more efficient per item + await asyncio.sleep(0.01 * batch_size) + + self.batch_predict_calls += 1 + self.total_items_processed += batch_size + + # Process all texts + embeddings = [] + for text in texts: + # Same logic as predict but in batch + import hashlib + hash_input = f"{text}_{self.model_id}".encode() + hash_obj = hashlib.md5(hash_input) + + embedding = [] + for i in range(self.embedding_dim): + byte_val = hash_obj.digest()[i % 16] + embedding.append((byte_val / 255.0) - 0.5) + + embeddings.append(embedding) + + return embeddings + + def get_stats(self) -> Dict[str, Any]: + """Get model usage statistics.""" + return { + "model_id": self.model_id, + "embedding_dim": self.embedding_dim, + "predict_calls": self.predict_calls, + "batch_predict_calls": self.batch_predict_calls, + "total_items_processed": self.total_items_processed, + "uptime": time.time() - self.load_time + } + + +class MockClassificationModel: + """Mock classification model for testing.""" + + def __init__(self, model_id: str, num_classes: int = 3): + self.model_id = model_id + self.num_classes = num_classes + self.predict_calls = 0 + self.batch_predict_calls = 0 + + async def predict(self, text: str) -> Dict[str, float]: + """Individual text classification.""" + await asyncio.sleep(0.03) + self.predict_calls += 1 + + # Generate deterministic probabilities + import hashlib + hash_val = int(hashlib.md5(f"{text}_{self.model_id}".encode()).hexdigest(), 16) + + probs = [] + for i in range(self.num_classes): + prob = ((hash_val + i) % 100) / 100.0 + probs.append(prob) + + # Normalize to sum to 1 + total = sum(probs) + probs = [p / total for p in probs] + + return { + f"class_{i}": probs[i] + for i in range(self.num_classes) + } + + async def batch_predict(self, texts: List[str]) -> List[Dict[str, float]]: + """Batch text classification.""" + batch_size = len(texts) + await asyncio.sleep(0.02 * batch_size) # 
Batch efficiency + self.batch_predict_calls += 1 + + results = [] + for text in texts: + # Same logic as predict + import hashlib + hash_val = int(hashlib.md5(f"{text}_{self.model_id}".encode()).hexdigest(), 16) + + probs = [] + for i in range(self.num_classes): + prob = ((hash_val + i) % 100) / 100.0 + probs.append(prob) + + total = sum(probs) + probs = [p / total for p in probs] + + result = {f"class_{i}": probs[i] for i in range(self.num_classes)} + results.append(result) + + return results + + +class MockTranslationModel: + """Mock translation model for testing.""" + + def __init__(self, model_id: str, source_lang: str = "en", target_lang: str = "es"): + self.model_id = model_id + self.source_lang = source_lang + self.target_lang = target_lang + self.translate_calls = 0 + self.batch_translate_calls = 0 + + async def translate(self, text: str) -> str: + """Individual translation.""" + await asyncio.sleep(0.05) # Translation takes longer + self.translate_calls += 1 + + # Mock translation by reversing and adding prefix + translated = f"[{self.target_lang}] {text[::-1]}" + return translated + + async def batch_translate(self, texts: List[str]) -> List[str]: + """Batch translation.""" + batch_size = len(texts) + await asyncio.sleep(0.03 * batch_size) # Batch efficiency + self.batch_translate_calls += 1 + + translations = [] + for text in texts: + translated = f"[{self.target_lang}] {text[::-1]}" + translations.append(translated) + + return translations + + +class BatchingTestHelper: + """Helper class for testing batching behavior.""" + + @staticmethod + async def send_concurrent_requests(wrapper, inputs: List[str], model_id: str): + """Send concurrent requests and measure timing.""" + start_time = time.time() + + tasks = [] + for input_data in inputs: + task = wrapper.predict(input_data, model_id) + tasks.append(task) + + results = await asyncio.gather(*tasks) + total_time = time.time() - start_time + + return results, total_time + + @staticmethod + def verify_batching_efficiency( + individual_time: float, + batch_time: float, + num_requests: int, + min_speedup: float = 1.5 + ): + """Verify that batching provides expected efficiency gains.""" + speedup = individual_time / batch_time + + assert speedup >= min_speedup, ( + f"Expected speedup of at least {min_speedup}x, " + f"got {speedup:.2f}x ({individual_time:.3f}s vs {batch_time:.3f}s)" + ) + + return speedup + + @staticmethod + def analyze_batch_patterns(model_instances: List): + """Analyze batching patterns across model instances.""" + stats = {} + + for model in model_instances: + stats[model.model_id] = { + "individual_calls": getattr(model, 'predict_calls', 0), + "batch_calls": getattr(model, 'batch_predict_calls', 0), + "total_processed": getattr(model, 'total_items_processed', 0) + } + + return stats + + +@pytest.fixture +def embedding_model_loader(): + """Fixture for embedding model loader.""" + models = {} + + async def loader(model_id: str) -> MockEmbeddingModel: + if model_id not in models: + # Different embedding dimensions for different models + dims = { + "mini": 384, + "base": 768, + "large": 1024 + } + dim = dims.get(model_id, 384) + models[model_id] = MockEmbeddingModel(model_id, dim) + + return models[model_id] + + return loader, models + + +@pytest.fixture +def classification_model_loader(): + """Fixture for classification model loader.""" + models = {} + + async def loader(model_id: str) -> MockClassificationModel: + if model_id not in models: + # Different number of classes for different models + classes = { + 
"sentiment": 3, # positive, negative, neutral + "topic": 5, # 5 topic categories + "intent": 10 # 10 intent categories + } + num_classes = classes.get(model_id, 3) + models[model_id] = MockClassificationModel(model_id, num_classes) + + return models[model_id] + + return loader, models + + +@pytest.fixture +def translation_model_loader(): + """Fixture for translation model loader.""" + models = {} + + async def loader(model_id: str) -> MockTranslationModel: + if model_id not in models: + # Different language pairs + lang_pairs = { + "en_es": ("en", "es"), + "en_fr": ("en", "fr"), + "en_de": ("en", "de") + } + source_lang, target_lang = lang_pairs.get(model_id, ("en", "es")) + models[model_id] = MockTranslationModel(model_id, source_lang, target_lang) + + return models[model_id] + + return loader, models + + +@pytest.fixture +def sample_texts(): + """Fixture providing sample texts for testing.""" + return [ + "The quick brown fox jumps over the lazy dog.", + "Machine learning is transforming artificial intelligence.", + "Ray Serve makes model deployment scalable and efficient.", + "Sentence transformers encode text into vector representations.", + "Batching improves throughput for neural network inference.", + "Natural language processing enables text understanding.", + "Deep learning models require careful optimization.", + "Distributed systems handle large-scale ML workloads.", + "Vector databases enable efficient similarity search.", + "Transformer architectures revolutionized NLP applications." + ] + + +@pytest.fixture +def performance_test_config(): + """Configuration for performance testing.""" + return { + "small_batch": 3, + "medium_batch": 8, + "large_batch": 16, + "timeout_short": 0.05, + "timeout_medium": 0.1, + "timeout_long": 0.2, + "min_speedup": 1.5, + "max_models": 4 + } + + +class MultiModelTestScenario: + """Test scenario with multiple models and request patterns.""" + + def __init__(self, models: List[str], request_patterns: Dict[str, List[str]]): + self.models = models + self.request_patterns = request_patterns + + async def execute_scenario(self, wrapper): + """Execute the test scenario.""" + all_tasks = [] + + for model_id, requests in self.request_patterns.items(): + for request_data in requests: + task = wrapper.predict(request_data, model_id) + all_tasks.append((model_id, task)) + + # Execute all requests concurrently + start_time = time.time() + results = [] + + for model_id, task in all_tasks: + result = await task + results.append({ + "model_id": model_id, + "result": result, + "timestamp": time.time() + }) + + total_time = time.time() - start_time + + return results, total_time + + def analyze_results(self, results: List[Dict], total_time: float): + """Analyze scenario execution results.""" + model_results = {} + + for result in results: + model_id = result["model_id"] + if model_id not in model_results: + model_results[model_id] = [] + model_results[model_id].append(result) + + analysis = { + "total_requests": len(results), + "total_time": total_time, + "models_used": len(model_results), + "requests_per_model": { + model_id: len(model_results[model_id]) + for model_id in model_results + }, + "avg_time_per_request": total_time / len(results) if results else 0 + } + + return analysis + + +# Predefined test scenarios +TEST_SCENARIOS = { + "embedding_workload": MultiModelTestScenario( + models=["mini", "base", "large"], + request_patterns={ + "mini": ["Quick text", "Short phrase", "Brief sentence"], + "base": ["Medium length text for processing", "Another moderate 
sentence"], + "large": ["This is a longer text that requires more sophisticated embedding processing"] + } + ), + + "classification_workload": MultiModelTestScenario( + models=["sentiment", "topic", "intent"], + request_patterns={ + "sentiment": ["I love this product!", "This is terrible", "It's okay I guess"], + "topic": ["Technology news update", "Sports match results"], + "intent": ["Book a flight", "Cancel my subscription", "Get weather forecast"] + } + ), + + "translation_workload": MultiModelTestScenario( + models=["en_es", "en_fr", "en_de"], + request_patterns={ + "en_es": ["Hello world", "How are you?"], + "en_fr": ["Good morning", "Thank you"], + "en_de": ["Welcome", "Goodbye"] + } + ) +} \ No newline at end of file From 550ee3751cc7ee8d503ba850484071758f24fc87 Mon Sep 17 00:00:00 2001 From: manickavela29 Date: Mon, 3 Nov 2025 04:02:19 +0000 Subject: [PATCH 2/4] handling request rejections and tests Signed-off-by: manickavela29 --- .../_private/request_router/request_router.py | 65 +++++- python/ray/serve/multiplex.py | 134 ++++++++++++- .../tests/test_multiplex_batching_router.py | 14 +- .../tests/test_multiplex_batching_utils.py | 187 +++++++++++++++++- 4 files changed, 385 insertions(+), 15 deletions(-) diff --git a/python/ray/serve/_private/request_router/request_router.py b/python/ray/serve/_private/request_router/request_router.py index a42605fd2d38..89c2585439b4 100644 --- a/python/ray/serve/_private/request_router/request_router.py +++ b/python/ray/serve/_private/request_router/request_router.py @@ -217,6 +217,11 @@ def __init__(self, *args, **kwargs): # Batching-aware routing: track pending requests by model ID for better batching self._pending_requests_by_model_id: DefaultDict[str, List] = defaultdict(list) + # Counters for efficient cleanup + self._pending_requests_added_since_cleanup = 0 + self._last_cleanup_time = time.time() + self._cleanup_threshold = 50 # Cleanup after 50 new requests + self._cleanup_interval = 10.0 # Cleanup every 10 seconds def _get_pending_request_matching_multiplexed_model_id( self, @@ -239,14 +244,29 @@ def _track_pending_request_by_model_id(self, pending_request: PendingRequest): if pending_request.metadata.multiplexed_model_id: model_id = pending_request.metadata.multiplexed_model_id self._pending_requests_by_model_id[model_id].append(pending_request) + self._pending_requests_added_since_cleanup += 1 def _get_pending_requests_for_model(self, model_id: str) -> List[PendingRequest]: """Get all pending requests for a specific model ID.""" - return [pr for pr in self._pending_requests_by_model_id[model_id] - if not pr.future.done()] + # Filter out completed requests on-the-fly for immediate use + active_requests = [pr for pr in self._pending_requests_by_model_id[model_id] + if not pr.future.done()] + return active_requests + + def _should_cleanup_pending_requests(self) -> bool: + """Determine if we should perform cleanup based on counters and time.""" + return (self._pending_requests_added_since_cleanup >= self._cleanup_threshold or + (time.time() - self._last_cleanup_time) >= self._cleanup_interval) def _cleanup_completed_pending_requests(self): - """Clean up completed requests from model ID tracking.""" + """Clean up completed requests from model ID tracking efficiently.""" + # Only cleanup if we've accumulated enough requests or enough time has passed + if not self._should_cleanup_pending_requests(): + return + + cleanup_start = time.time() + total_requests_before = sum(len(requests) for requests in 
self._pending_requests_by_model_id.values()) + for model_id in list(self._pending_requests_by_model_id.keys()): self._pending_requests_by_model_id[model_id] = [ pr for pr in self._pending_requests_by_model_id[model_id] @@ -254,6 +274,17 @@ def _cleanup_completed_pending_requests(self): ] if not self._pending_requests_by_model_id[model_id]: del self._pending_requests_by_model_id[model_id] + + total_requests_after = sum(len(requests) for requests in self._pending_requests_by_model_id.values()) + cleanup_time = time.time() - cleanup_start + + # Reset counters + self._pending_requests_added_since_cleanup = 0 + self._last_cleanup_time = time.time() + + if total_requests_before != total_requests_after: + logger.debug(f"Cleaned up {total_requests_before - total_requests_after} completed requests " + f"in {cleanup_time:.3f}s, {total_requests_after} active requests remaining") def _update_multiplexed_model_ids_with_replicas( self, replicas: List[RunningReplica] @@ -349,9 +380,31 @@ def apply_multiplex_routing( if candidate_replica_ids and multiplexed_model_id: pending_for_model = self._get_pending_requests_for_model(multiplexed_model_id) if len(pending_for_model) > 1: # Multiple requests for same model - # Prefer replicas that are likely processing this model - logger.debug(f"Found {len(pending_for_model)} pending requests for model {multiplexed_model_id}, " - f"prioritizing batching-friendly routing") + # Find replicas that already have pending requests for this model + batching_friendly_replicas = set() + + for pending_req in pending_for_model: + # Check if this request has been assigned to a replica + if (pending_req.future.done() and + not pending_req.future.cancelled() and + not pending_req.future.exception()): + try: + assigned_replica = pending_req.future.result() + if (hasattr(assigned_replica, 'replica_id') and + assigned_replica.replica_id in candidate_replica_ids): + batching_friendly_replicas.add(assigned_replica.replica_id) + except Exception: + # Future might not have replica result, skip + pass + + # If we found replicas with pending requests for this model, prioritize them + if batching_friendly_replicas: + candidate_replica_ids = batching_friendly_replicas + logger.debug(f"Found {len(pending_for_model)} pending requests for model {multiplexed_model_id}, " + f"prioritizing {len(batching_friendly_replicas)} batching-friendly replicas") + else: + logger.debug(f"Found {len(pending_for_model)} pending requests for model {multiplexed_model_id}, " + f"but no batching-friendly replicas found in candidates") if ( not candidate_replica_ids diff --git a/python/ray/serve/multiplex.py b/python/ray/serve/multiplex.py index 1969d7788710..d6ea58baa078 100644 --- a/python/ray/serve/multiplex.py +++ b/python/ray/serve/multiplex.py @@ -160,9 +160,11 @@ async def model_batch_handler(batch_requests: List[Any]) -> List[Any]: Returns: List of results corresponding to each input. 
""" + # Re-check model availability at processing time to handle race conditions model = self.models.get(model_id) if model is None: - raise RuntimeError(f"Model {model_id} not loaded") + # Model was evicted, raise an exception that will cancel pending requests + raise RuntimeError(f"Model {model_id} was evicted during batch processing") # Try to use batch_predict method if available if hasattr(model, 'batch_predict'): @@ -192,6 +194,124 @@ async def model_batch_handler(batch_requests: List[Any]) -> List[Any]: return self._model_batch_queues[model_id] + async def _shutdown_batch_queue(self, batch_queue_wrapper: _LazyBatchQueueWrapper, model_id: str): + """Gracefully shutdown a batch queue by canceling pending requests and background tasks.""" + if batch_queue_wrapper._queue is None: + # Queue was never initialized, nothing to clean up + return + + batch_queue = batch_queue_wrapper._queue + + # Cancel the background processing task if it exists + if hasattr(batch_queue, '_handle_batch_task') and batch_queue._handle_batch_task: + batch_queue._handle_batch_task.cancel() + try: + await batch_queue._handle_batch_task + except asyncio.CancelledError: + pass # Expected when cancelling + + # Cancel all pending requests in the queue + pending_requests = [] + try: + while True: + try: + request = batch_queue.queue.get_nowait() + pending_requests.append(request) + except asyncio.QueueEmpty: + break + except Exception: + pass # Queue might be closed or corrupted + + # Handle pending requests gracefully - try to reassign rather than fail + reassigned_count = 0 + failed_count = 0 + + for request in pending_requests: + if not request.future.done(): + try: + # Try to reassign the request back to the routing system + if await self._try_reassign_request(request, model_id): + reassigned_count += 1 + else: + # If reassignment fails, set a descriptive error + request.future.set_exception( + RuntimeError(f"Model {model_id} was evicted and could not be reassigned") + ) + failed_count += 1 + except Exception: + # Future might already be done or other error, count as failed + failed_count += 1 + + logger.info(f"Shutdown batch queue for model {model_id}: reassigned {reassigned_count}, failed {failed_count} pending requests") + + async def _try_reassign_request(self, request: _SingleRequest, model_id: str) -> bool: + """Try to reassign a pending request back to the routing system. 
+ + Args: + request: The pending request to reassign + model_id: The model ID that was evicted + + Returns: + True if request was successfully reassigned, False otherwise + """ + try: + # Extract the original input from the flattened args + if len(request.flattened_args) >= 2 and request.flattened_args[0] == DUMMY_TYPE: + original_input = request.flattened_args[1] + else: + # Fallback if format is unexpected + return False + + # Check if we have retry attempts left (prevent infinite loops) + retry_count = getattr(request, '_retry_count', 0) + if retry_count >= 2: # Max 2 retries + return False + + # Create a new async task to retry the request with backoff + async def retry_request(): + try: + # Add retry count to track attempts + setattr(request, '_retry_count', retry_count + 1) + + # Exponential backoff: wait longer for each retry + backoff_time = 0.01 * (2 ** retry_count) + await asyncio.sleep(backoff_time) + + # Try to process the request again - this will go through the full + # model loading process, potentially reloading on this replica + # Note: We call predict directly rather than batched_inference to avoid + # potential batching complications during retry + if self.enable_batching: + # For batching case, try individual prediction as fallback + model = await self.load_model(model_id) + if hasattr(model, 'predict'): + result = await model.predict(original_input) + elif callable(model): + result = await model(original_input) + else: + raise RuntimeError(f"Model {model_id} is not callable and has no predict method") + else: + result = await self.predict(original_input, model_id) + + # Set the result on the original future + if not request.future.done(): + request.future.set_result(result) + + except Exception as e: + # If retry fails, set the exception on the original future + if not request.future.done(): + request.future.set_exception( + RuntimeError(f"Model {model_id} evicted, retry failed: {str(e)}") + ) + + # Start the retry task in the background + asyncio.create_task(retry_request()) + return True + + except Exception as e: + logger.debug(f"Failed to reassign request for model {model_id}: {e}") + return False + async def batched_inference(self, model_id: str, request: Any) -> Any: """Perform batched inference on a specific model.""" if not self.enable_batching: @@ -292,6 +412,14 @@ async def shutdown(self): logger.exception( f"Failed to unload model. Error: {e}", ) + + # Clean up any remaining batch queues + for model_id, batch_queue_wrapper in list(self._model_batch_queues.items()): + try: + await self._shutdown_batch_queue(batch_queue_wrapper, model_id) + except Exception as e: + logger.exception(f"Failed to shutdown batch queue for model {model_id}. Error: {e}") + self._model_batch_queues.clear() async def load_model(self, model_id: str) -> Any: """Load the model if it is not loaded yet, and return @@ -373,8 +501,10 @@ async def unload_model_lru(self) -> None: model_id, model = self.models.popitem(last=False) logger.info(f"Unloading model '{model_id}'.") - # Clean up the batch queue for this model if it exists + # Gracefully shutdown the batch queue for this model if it exists if model_id in self._model_batch_queues: + batch_queue_wrapper = self._model_batch_queues[model_id] + await self._shutdown_batch_queue(batch_queue_wrapper, model_id) del self._model_batch_queues[model_id] # If the model has __del__ attribute, call it. 
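Editor's note — illustrative sketch only, not part of the patch. The multiplex.py changes above give each model its own lazily created batch queue, fall back to per-item predict() when a model exposes no batch_predict(), and attempt to reassign in-flight requests when a model is evicted. The snippet below shows the intended wrapper-level usage, mirroring the tests that follow in this series: it assumes the `start_serve_with_context` fixture defined in these test modules (the wrapper needs a replica context), and `EchoModel` is a made-up stand-in for a real model, not something this patch adds.

import asyncio

import pytest

from ray.serve.multiplex import _ModelMultiplexWrapper


class EchoModel:  # illustrative stand-in, not part of the patch
    def __init__(self, model_id: str):
        self.model_id = model_id
        self.batch_calls = 0

    async def batch_predict(self, inputs):
        # One call receives the whole batch accumulated by this model's queue.
        self.batch_calls += 1
        return [f"{self.model_id}:{x}" for x in inputs]


@pytest.mark.asyncio
async def test_same_model_requests_share_a_batch(start_serve_with_context):
    model = EchoModel("embedder")

    async def load_model(model_id: str):
        return model

    wrapper = _ModelMultiplexWrapper(
        model_load_func=load_model,
        max_num_models_per_replica=2,
        enable_batching=True,
        max_batch_size=8,
        batch_wait_timeout_s=0.05,
    )

    # Concurrent predict() calls for the same model ID are funneled into that
    # model's dedicated batch queue and handed to batch_predict() as one list.
    results = await asyncio.gather(
        *[wrapper.predict(f"doc_{i}", "embedder") for i in range(8)]
    )

    assert len(results) == 8
    assert all(r.startswith("embedder:") for r in results)
    # Eight requests fit within max_batch_size, so ideally a single batch call;
    # timing may split them, hence the >= 1 check rather than == 1.
    assert model.batch_calls >= 1

Because the queues are keyed by model ID, requests for different models never share a batch; if the model is evicted mid-flight, the shutdown path above cancels the queue and retries the pending requests individually.
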
diff --git a/python/ray/serve/tests/test_multiplex_batching_router.py b/python/ray/serve/tests/test_multiplex_batching_router.py index c1d7b2099493..cab67b7d2bcf 100644 --- a/python/ray/serve/tests/test_multiplex_batching_router.py +++ b/python/ray/serve/tests/test_multiplex_batching_router.py @@ -200,22 +200,24 @@ async def load_model(model_id: str): # Load model first model = await wrapper_batched.load_model("batched_model") - # Send concurrent requests to same model using the model directly + # Send concurrent requests to the wrapper to test batching mechanism start_time = time.time() tasks = [] for i in range(10): - task = model.batch_predict([f"data_{i}"]) + # Use wrapper.predict() to test the actual batching mechanism + task = wrapper_batched.predict(f"data_{i}", "batched_model") tasks.append(task) - results_nested = await asyncio.gather(*tasks) - # Flatten results since batch_predict returns lists - results = [item for sublist in results_nested for item in sublist] + results = await asyncio.gather(*tasks) batched_time = time.time() - start_time - # Check the model's batch predict was called + # Check that batch predict was called (indicating batching worked) assert model.batch_predict_count > 0, "Batch predict should be called" assert len(results) == 10, "All requests should complete" + # Verify results are correct format - should be from batch_predict + assert all("batch_batched_model" in result for result in results), f"Expected batch results, got: {results[:3]}" + # Test without batching for comparison TrackableModel.reset_tracking() diff --git a/python/ray/serve/tests/test_multiplex_batching_utils.py b/python/ray/serve/tests/test_multiplex_batching_utils.py index 2faa9b1dd0d8..594de5a8d774 100644 --- a/python/ray/serve/tests/test_multiplex_batching_utils.py +++ b/python/ray/serve/tests/test_multiplex_batching_utils.py @@ -405,4 +405,189 @@ def analyze_results(self, results: List[Dict], total_time: float): "en_de": ["Welcome", "Goodbye"] } ) -} \ No newline at end of file +} + + +# Test functions for the utility classes +@pytest.mark.asyncio +async def test_mock_embedding_model(): + """Test MockEmbeddingModel functionality.""" + model = MockEmbeddingModel("test_model", embedding_dim=128) + + # Test individual prediction + result = await model.predict("hello world") + assert len(result) == 128 + assert isinstance(result[0], float) + assert model.predict_calls == 1 + assert model.total_items_processed == 1 + + # Test batch prediction + texts = ["hello", "world", "test"] + batch_result = await model.batch_predict(texts) + assert len(batch_result) == 3 + assert len(batch_result[0]) == 128 + assert model.batch_predict_calls == 1 + assert model.total_items_processed == 4 # 1 + 3 + + # Test stats + stats = model.get_stats() + assert stats["model_id"] == "test_model" + assert stats["embedding_dim"] == 128 + assert stats["predict_calls"] == 1 + assert stats["batch_predict_calls"] == 1 + + +@pytest.mark.asyncio +async def test_mock_classification_model(): + """Test MockClassificationModel functionality.""" + model = MockClassificationModel("sentiment", num_classes=3) + + # Test individual prediction + result = await model.predict("I love this!") + assert len(result) == 3 + assert all(f"class_{i}" in result for i in range(3)) + assert abs(sum(result.values()) - 1.0) < 1e-6 # Should sum to 1 + assert model.predict_calls == 1 + + # Test batch prediction + texts = ["good", "bad", "neutral"] + batch_result = await model.batch_predict(texts) + assert len(batch_result) == 3 + assert 
all(len(result) == 3 for result in batch_result) + assert model.batch_predict_calls == 1 + + +@pytest.mark.asyncio +async def test_mock_translation_model(): + """Test MockTranslationModel functionality.""" + model = MockTranslationModel("en_es", source_lang="en", target_lang="es") + + # Test individual translation + result = await model.translate("hello") + assert result.startswith("[es]") + assert "olleh" in result # reversed text + assert model.translate_calls == 1 + + # Test batch translation + texts = ["hello", "world"] + batch_result = await model.batch_translate(texts) + assert len(batch_result) == 2 + assert all(result.startswith("[es]") for result in batch_result) + assert model.batch_translate_calls == 1 + + +@pytest.mark.asyncio +async def test_batching_test_helper(): + """Test BatchingTestHelper functionality.""" + from ray.serve.multiplex import _ModelMultiplexWrapper + + # Mock a simple wrapper + class MockWrapper: + def __init__(self): + self.call_count = 0 + + async def predict(self, input_data, model_id): + self.call_count += 1 + await asyncio.sleep(0.01) + return f"result_{model_id}_{input_data}" + + wrapper = MockWrapper() + inputs = ["test1", "test2", "test3"] + + # Test concurrent requests + results, total_time = await BatchingTestHelper.send_concurrent_requests( + wrapper, inputs, "test_model" + ) + + assert len(results) == 3 + assert wrapper.call_count == 3 + assert total_time > 0 + + # Test efficiency verification + individual_time = 0.1 + batch_time = 0.05 + speedup = BatchingTestHelper.verify_batching_efficiency( + individual_time, batch_time, 3, min_speedup=1.5 + ) + assert speedup == 2.0 + + +def test_multi_model_test_scenario(): + """Test MultiModelTestScenario functionality.""" + scenario = TEST_SCENARIOS["embedding_workload"] + + assert len(scenario.models) == 3 + assert "mini" in scenario.models + assert "base" in scenario.models + assert "large" in scenario.models + + # Test analyze_results + mock_results = [ + {"model_id": "mini", "result": "result1", "timestamp": 1.0}, + {"model_id": "mini", "result": "result2", "timestamp": 1.1}, + {"model_id": "base", "result": "result3", "timestamp": 1.2}, + ] + + analysis = scenario.analyze_results(mock_results, 0.5) + assert analysis["total_requests"] == 3 + assert analysis["total_time"] == 0.5 + assert analysis["models_used"] == 2 + assert analysis["requests_per_model"]["mini"] == 2 + assert analysis["requests_per_model"]["base"] == 1 + + +def test_fixtures_return_correct_types(embedding_model_loader, classification_model_loader, + translation_model_loader, sample_texts, performance_test_config): + """Test that all fixtures return the expected types.""" + # Test embedding model loader + loader, models = embedding_model_loader + assert callable(loader) + assert isinstance(models, dict) + + # Test classification model loader + loader, models = classification_model_loader + assert callable(loader) + assert isinstance(models, dict) + + # Test translation model loader + loader, models = translation_model_loader + assert callable(loader) + assert isinstance(models, dict) + + # Test sample texts + assert isinstance(sample_texts, list) + assert len(sample_texts) > 0 + assert all(isinstance(text, str) for text in sample_texts) + + # Test performance config + assert isinstance(performance_test_config, dict) + assert "small_batch" in performance_test_config + assert "min_speedup" in performance_test_config + + +@pytest.mark.asyncio +async def test_embedding_model_loader_fixture(embedding_model_loader): + """Test the embedding 
model loader fixture.""" + loader, models = embedding_model_loader + + # Load a model + model = await loader("mini") + assert isinstance(model, MockEmbeddingModel) + assert model.model_id == "mini" + assert model.embedding_dim == 384 + + # Check it's cached + model2 = await loader("mini") + assert model is model2 + assert len(models) == 1 + + # Load different model + model3 = await loader("base") + assert model3.embedding_dim == 768 + assert len(models) == 2 + + +if __name__ == "__main__": + # Run all tests in this module + import sys + pytest.main(["-v", "-s", __file__] + sys.argv[1:]) \ No newline at end of file From 3c639b4c58fd5ea40ae5235db9059d23b4f0df07 Mon Sep 17 00:00:00 2001 From: manickavela29 Date: Mon, 3 Nov 2025 06:13:50 +0000 Subject: [PATCH 3/4] bazel pytest Signed-off-by: manickavela29 --- python/ray/serve/tests/BUILD.bazel | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/ray/serve/tests/BUILD.bazel b/python/ray/serve/tests/BUILD.bazel index 06b4005ad628..7490a0cccb13 100644 --- a/python/ray/serve/tests/BUILD.bazel +++ b/python/ray/serve/tests/BUILD.bazel @@ -121,6 +121,9 @@ py_test_module_list( "test_https_proxy.py", "test_max_replicas_per_node.py", "test_multiplex.py", + "test_multiplex_batching.py", + "test_multiplex_batching_router.py", + "test_multiplex_batching_utils.py", "test_proxy.py", "test_proxy_response_generator.py", "test_ray_client.py", From 8cdf2406968d167c7d6b1a32e355fb0a21ee0090 Mon Sep 17 00:00:00 2001 From: manickavela29 Date: Mon, 3 Nov 2025 06:55:15 +0000 Subject: [PATCH 4/4] lint fix with router counter fix Signed-off-by: manickavela29 --- .../_private/request_router/request_router.py | 155 +++++++--- python/ray/serve/multiplex.py | 171 ++++++----- .../serve/tests/test_multiplex_batching.py | 193 ++++++------- .../tests/test_multiplex_batching_router.py | 249 ++++++++-------- .../tests/test_multiplex_batching_utils.py | 271 +++++++++--------- 5 files changed, 579 insertions(+), 460 deletions(-) diff --git a/python/ray/serve/_private/request_router/request_router.py b/python/ray/serve/_private/request_router/request_router.py index 89c2585439b4..776e4ac64b03 100644 --- a/python/ray/serve/_private/request_router/request_router.py +++ b/python/ray/serve/_private/request_router/request_router.py @@ -196,7 +196,7 @@ class MultiplexMixin: It adds necessary attributes and methods to keep track of multiplexed model IDs and offer the helpers to apply multiplex routing and rank replicas based on multiplexed model IDs. - + Now supports batching-aware routing to group requests by model ID for optimal batching performance. 
""" @@ -214,7 +214,7 @@ def __init__(self, *args, **kwargs): self._multiplexed_model_id_fallback_match: Set[str] = set() self._replica_id_set: Set[ReplicaID] = set() self._replicas: Dict[ReplicaID, RunningReplica] = {} - + # Batching-aware routing: track pending requests by model ID for better batching self._pending_requests_by_model_id: DefaultDict[str, List] = defaultdict(list) # Counters for efficient cleanup @@ -222,6 +222,7 @@ def __init__(self, *args, **kwargs): self._last_cleanup_time = time.time() self._cleanup_threshold = 50 # Cleanup after 50 new requests self._cleanup_interval = 10.0 # Cleanup every 10 seconds + self._cleanup_task = None # Track async cleanup task def _get_pending_request_matching_multiplexed_model_id( self, @@ -249,42 +250,104 @@ def _track_pending_request_by_model_id(self, pending_request: PendingRequest): def _get_pending_requests_for_model(self, model_id: str) -> List[PendingRequest]: """Get all pending requests for a specific model ID.""" # Filter out completed requests on-the-fly for immediate use - active_requests = [pr for pr in self._pending_requests_by_model_id[model_id] - if not pr.future.done()] + # and update the list in-place to avoid accumulating completed requests + if model_id not in self._pending_requests_by_model_id: + return [] + + active_requests = [] + completed_count = 0 + + for pr in self._pending_requests_by_model_id[model_id]: + if not pr.future.done(): + active_requests.append(pr) + else: + completed_count += 1 + + # Update the stored list with only active requests to prevent accumulation + if completed_count > 0: + self._pending_requests_by_model_id[model_id] = active_requests + if not active_requests: + del self._pending_requests_by_model_id[model_id] + + # Trigger periodic cleanup if we've seen enough completed requests + if completed_count > 0 and self._should_cleanup_pending_requests(): + # Schedule cleanup asynchronously to avoid blocking routing + self._schedule_async_cleanup() + return active_requests def _should_cleanup_pending_requests(self) -> bool: """Determine if we should perform cleanup based on counters and time.""" - return (self._pending_requests_added_since_cleanup >= self._cleanup_threshold or - (time.time() - self._last_cleanup_time) >= self._cleanup_interval) + return ( + self._pending_requests_added_since_cleanup >= self._cleanup_threshold + or (time.time() - self._last_cleanup_time) >= self._cleanup_interval + ) def _cleanup_completed_pending_requests(self): """Clean up completed requests from model ID tracking efficiently.""" # Only cleanup if we've accumulated enough requests or enough time has passed if not self._should_cleanup_pending_requests(): return - + cleanup_start = time.time() - total_requests_before = sum(len(requests) for requests in self._pending_requests_by_model_id.values()) - + total_requests_before = sum( + len(requests) for requests in self._pending_requests_by_model_id.values() + ) + for model_id in list(self._pending_requests_by_model_id.keys()): self._pending_requests_by_model_id[model_id] = [ - pr for pr in self._pending_requests_by_model_id[model_id] + pr + for pr in self._pending_requests_by_model_id[model_id] if not pr.future.done() ] if not self._pending_requests_by_model_id[model_id]: del self._pending_requests_by_model_id[model_id] - - total_requests_after = sum(len(requests) for requests in self._pending_requests_by_model_id.values()) + + total_requests_after = sum( + len(requests) for requests in self._pending_requests_by_model_id.values() + ) cleanup_time = time.time() - 
cleanup_start - + # Reset counters self._pending_requests_added_since_cleanup = 0 self._last_cleanup_time = time.time() - + if total_requests_before != total_requests_after: - logger.debug(f"Cleaned up {total_requests_before - total_requests_after} completed requests " - f"in {cleanup_time:.3f}s, {total_requests_after} active requests remaining") + logger.debug( + f"Cleaned up {total_requests_before - total_requests_after} " + f"completed requests in {cleanup_time:.3f}s, " + f"{total_requests_after} active requests remaining" + ) + + def _schedule_async_cleanup(self): + """Schedule cleanup to run asynchronously without blocking routing.""" + # Only schedule if cleanup isn't already running + if ( + not hasattr(self, "_cleanup_task") + or self._cleanup_task is None + or self._cleanup_task.done() + ): + import asyncio + + try: + # Get the current event loop + loop = asyncio.get_event_loop() + self._cleanup_task = loop.create_task(self._async_cleanup()) + except RuntimeError: + # If no event loop is running, fall back to synchronous cleanup + # This should rarely happen in the Ray Serve context + self._cleanup_completed_pending_requests() + + async def _async_cleanup(self): + """Perform cleanup asynchronously.""" + try: + # Small delay to avoid blocking the current operation + await asyncio.sleep(0.001) + self._cleanup_completed_pending_requests() + except Exception as e: + logger.warning(f"Async cleanup failed: {e}") + finally: + self._cleanup_task = None def _update_multiplexed_model_ids_with_replicas( self, replicas: List[RunningReplica] @@ -354,8 +417,6 @@ def apply_multiplex_routing( # Track this request for batching-aware routing self._track_pending_request_by_model_id(pending_request) - # Clean up completed requests periodically - self._cleanup_completed_pending_requests() if not pending_request.routing_context.multiplexed_start_matching_time: pending_request.routing_context.multiplexed_start_matching_time = ( @@ -366,7 +427,7 @@ def apply_multiplex_routing( pending_request.routing_context.multiplexed_start_matching_time ) multiplexed_model_id = pending_request.metadata.multiplexed_model_id - + if ( time.time() - multiplexed_start_matching_time < self._multiplexed_matching_timeout @@ -374,38 +435,55 @@ def apply_multiplex_routing( candidate_replica_ids = self._multiplexed_model_id_to_replica_ids.get( multiplexed_model_id, None ) - + # Batching-aware enhancement: prioritize replicas with pending requests # for the same model ID to improve batching efficiency if candidate_replica_ids and multiplexed_model_id: - pending_for_model = self._get_pending_requests_for_model(multiplexed_model_id) + pending_for_model = self._get_pending_requests_for_model( + multiplexed_model_id + ) if len(pending_for_model) > 1: # Multiple requests for same model # Find replicas that already have pending requests for this model batching_friendly_replicas = set() - + for pending_req in pending_for_model: # Check if this request has been assigned to a replica - if (pending_req.future.done() and - not pending_req.future.cancelled() and - not pending_req.future.exception()): + if ( + pending_req.future.done() + and not pending_req.future.cancelled() + and not pending_req.future.exception() + ): try: assigned_replica = pending_req.future.result() - if (hasattr(assigned_replica, 'replica_id') and - assigned_replica.replica_id in candidate_replica_ids): - batching_friendly_replicas.add(assigned_replica.replica_id) + if ( + hasattr(assigned_replica, "replica_id") + and assigned_replica.replica_id + in 
candidate_replica_ids + ): + batching_friendly_replicas.add( + assigned_replica.replica_id + ) except Exception: # Future might not have replica result, skip pass - - # If we found replicas with pending requests for this model, prioritize them + + # If we found replicas with pending requests for this model, + # prioritize them if batching_friendly_replicas: candidate_replica_ids = batching_friendly_replicas - logger.debug(f"Found {len(pending_for_model)} pending requests for model {multiplexed_model_id}, " - f"prioritizing {len(batching_friendly_replicas)} batching-friendly replicas") + logger.debug( + f"Found {len(pending_for_model)} pending requests for " + f"model {multiplexed_model_id}, prioritizing " + f"{len(batching_friendly_replicas)} batching-friendly " + f"replicas" + ) else: - logger.debug(f"Found {len(pending_for_model)} pending requests for model {multiplexed_model_id}, " - f"but no batching-friendly replicas found in candidates") - + logger.debug( + f"Found {len(pending_for_model)} pending requests for " + f"model {multiplexed_model_id}, but no batching-friendly " + f"replicas found in candidates" + ) + if ( not candidate_replica_ids and multiplexed_model_id @@ -596,7 +674,8 @@ def __init__( # We keep two separate queues of pending requests: # - self._pending_requests_to_fulfill is a queue that will be used to fulfill - # requests (potentially out of order) by routing tasks once they've acquired a replica. + # requests (potentially out of order) by routing tasks once they've + # acquired a replica. # - self.routing is a queue that is used for tasks to # best-effort grab the metadata of requests waiting to be fulfilled. This is # currently used for routing tasks to know which multiplexed model IDs they @@ -637,8 +716,8 @@ def __init__( def initialize_state(self, **kwargs): """ - Initialize the state of the request router. Called by the Ray Serve framework with the - contents of `RequestRouter.request_router_kwargs`. + Initialize the state of the request router. Called by the Ray Serve + framework with the contents of `RequestRouter.request_router_kwargs`. """ pass diff --git a/python/ray/serve/multiplex.py b/python/ray/serve/multiplex.py index d6ea58baa078..7d3e791c2de3 100644 --- a/python/ray/serve/multiplex.py +++ b/python/ray/serve/multiplex.py @@ -3,8 +3,10 @@ import logging import time from collections import OrderedDict -from typing import Any, Callable, List, Set, Optional +from typing import Any, Callable, List, Optional, Set +import ray.serve.context as context +from ray._common.signature import DUMMY_TYPE from ray.serve import metrics from ray.serve._private.common import ReplicaID, RequestRoutingInfo from ray.serve._private.constants import ( @@ -14,9 +16,8 @@ ) from ray.serve._private.metrics_utils import MetricsPusher from ray.serve._private.usage import ServeUsageTag -from ray.serve.context import _get_global_client, _get_internal_replica_context from ray.serve.batching import _LazyBatchQueueWrapper, _SingleRequest -from ray._common.signature import DUMMY_TYPE +from ray.serve.context import _get_global_client, _get_internal_replica_context logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -51,15 +52,17 @@ def __init__( """Initialize the model multiplexer. Args: model_load_func: the model load async function. - self_arg: self argument when model_load_func is class method. Default is None - for standalone functions. + self_arg: self argument when model_load_func is class method. + Default is None for standalone functions. 
max_num_models_per_replica: the maximum number of models to be loaded on the current replica. If it is -1, there is no limit for the number of models per replica. Default is 3. enable_batching: whether to enable batching for model inference calls. Default is False. - max_batch_size: maximum batch size for batched inference calls. Default is 10. - batch_wait_timeout_s: timeout for batching inference calls. Default is 0.01s. + max_batch_size: maximum batch size for batched inference calls. + Default is 10. + batch_wait_timeout_s: timeout for batching inference calls. + Default is 0.01s. max_concurrent_batches: maximum number of concurrent batches. Default is 1. """ @@ -69,13 +72,13 @@ def __init__( self._func: Callable = model_load_func self.self_arg: Any = self_arg self.max_num_models_per_replica: int = max_num_models_per_replica - + # Batching configuration self.enable_batching = enable_batching self.max_batch_size = max_batch_size self.batch_wait_timeout_s = batch_wait_timeout_s self.max_concurrent_batches = max_concurrent_batches - + # Model-specific batch queues for inference batching self._model_batch_queues: dict[str, _LazyBatchQueueWrapper] = {} @@ -144,72 +147,85 @@ def __init__( ) self.metrics_pusher.start() - def _get_or_create_batch_queue(self, model_id: str) -> Optional[_LazyBatchQueueWrapper]: + def _get_or_create_batch_queue( + self, model_id: str + ) -> Optional[_LazyBatchQueueWrapper]: """Get or create a batch queue for a specific model.""" if not self.enable_batching: return None - + if model_id not in self._model_batch_queues: # Create a batch handler for this specific model async def model_batch_handler(batch_requests: List[Any]) -> List[Any]: """Handle batched inference for a specific model. - + Args: batch_requests: List of input data items to process as a batch. - + Returns: List of results corresponding to each input. 
""" - # Re-check model availability at processing time to handle race conditions + # Re-check model availability at processing time to handle + # race conditions model = self.models.get(model_id) if model is None: - # Model was evicted, raise an exception that will cancel pending requests - raise RuntimeError(f"Model {model_id} was evicted during batch processing") - + # Model was evicted, raise an exception that will cancel + # pending requests + raise RuntimeError( + f"Model {model_id} was evicted during batch processing" + ) + # Try to use batch_predict method if available - if hasattr(model, 'batch_predict'): + if hasattr(model, "batch_predict"): results = await model.batch_predict(batch_requests) else: # Fallback to individual prediction calls results = [] for request_data in batch_requests: - if hasattr(model, 'predict'): + if hasattr(model, "predict"): result = await model.predict(request_data) elif callable(model): result = await model(request_data) else: raise RuntimeError( - f"Model {model_id} is not callable and has no predict method" + f"Model {model_id} is not callable and has no " + f"predict method" ) results.append(result) - + return results - + self._model_batch_queues[model_id] = _LazyBatchQueueWrapper( max_batch_size=self.max_batch_size, batch_wait_timeout_s=self.batch_wait_timeout_s, max_concurrent_batches=self.max_concurrent_batches, handle_batch_func=model_batch_handler, ) - + return self._model_batch_queues[model_id] - async def _shutdown_batch_queue(self, batch_queue_wrapper: _LazyBatchQueueWrapper, model_id: str): - """Gracefully shutdown a batch queue by canceling pending requests and background tasks.""" + async def _shutdown_batch_queue( + self, batch_queue_wrapper: _LazyBatchQueueWrapper, model_id: str + ): + """Gracefully shutdown a batch queue by canceling pending requests + and background tasks.""" if batch_queue_wrapper._queue is None: # Queue was never initialized, nothing to clean up return - + batch_queue = batch_queue_wrapper._queue - + # Cancel the background processing task if it exists - if hasattr(batch_queue, '_handle_batch_task') and batch_queue._handle_batch_task: + if ( + hasattr(batch_queue, "_handle_batch_task") + and batch_queue._handle_batch_task + ): batch_queue._handle_batch_task.cancel() try: await batch_queue._handle_batch_task except asyncio.CancelledError: pass # Expected when cancelling - + # Cancel all pending requests in the queue pending_requests = [] try: @@ -221,11 +237,11 @@ async def _shutdown_batch_queue(self, batch_queue_wrapper: _LazyBatchQueueWrappe break except Exception: pass # Queue might be closed or corrupted - + # Handle pending requests gracefully - try to reassign rather than fail reassigned_count = 0 failed_count = 0 - + for request in pending_requests: if not request.future.done(): try: @@ -235,79 +251,95 @@ async def _shutdown_batch_queue(self, batch_queue_wrapper: _LazyBatchQueueWrappe else: # If reassignment fails, set a descriptive error request.future.set_exception( - RuntimeError(f"Model {model_id} was evicted and could not be reassigned") + RuntimeError( + f"Model {model_id} was evicted and could not be " + f"reassigned" + ) ) failed_count += 1 except Exception: # Future might already be done or other error, count as failed failed_count += 1 - - logger.info(f"Shutdown batch queue for model {model_id}: reassigned {reassigned_count}, failed {failed_count} pending requests") - async def _try_reassign_request(self, request: _SingleRequest, model_id: str) -> bool: + logger.info( + f"Shutdown batch queue for 
model {model_id}: " + f"reassigned {reassigned_count}, failed {failed_count} pending requests" + ) + + async def _try_reassign_request( + self, request: _SingleRequest, model_id: str + ) -> bool: """Try to reassign a pending request back to the routing system. - + Args: request: The pending request to reassign model_id: The model ID that was evicted - + Returns: True if request was successfully reassigned, False otherwise """ try: # Extract the original input from the flattened args - if len(request.flattened_args) >= 2 and request.flattened_args[0] == DUMMY_TYPE: + if ( + len(request.flattened_args) >= 2 + and request.flattened_args[0] == DUMMY_TYPE + ): original_input = request.flattened_args[1] else: # Fallback if format is unexpected return False - + # Check if we have retry attempts left (prevent infinite loops) - retry_count = getattr(request, '_retry_count', 0) + retry_count = getattr(request, "_retry_count", 0) if retry_count >= 2: # Max 2 retries return False - + # Create a new async task to retry the request with backoff async def retry_request(): try: # Add retry count to track attempts - setattr(request, '_retry_count', retry_count + 1) - + request._retry_count = retry_count + 1 + # Exponential backoff: wait longer for each retry - backoff_time = 0.01 * (2 ** retry_count) + backoff_time = 0.01 * (2**retry_count) await asyncio.sleep(backoff_time) - + # Try to process the request again - this will go through the full # model loading process, potentially reloading on this replica - # Note: We call predict directly rather than batched_inference to avoid - # potential batching complications during retry + # Note: We call predict directly rather than batched_inference to + # avoid potential batching complications during retry if self.enable_batching: # For batching case, try individual prediction as fallback model = await self.load_model(model_id) - if hasattr(model, 'predict'): + if hasattr(model, "predict"): result = await model.predict(original_input) elif callable(model): result = await model(original_input) else: - raise RuntimeError(f"Model {model_id} is not callable and has no predict method") + raise RuntimeError( + f"Model {model_id} is not callable and has no " + f"predict method" + ) else: result = await self.predict(original_input, model_id) - + # Set the result on the original future if not request.future.done(): request.future.set_result(result) - + except Exception as e: # If retry fails, set the exception on the original future if not request.future.done(): request.future.set_exception( - RuntimeError(f"Model {model_id} evicted, retry failed: {str(e)}") + RuntimeError( + f"Model {model_id} evicted, retry failed: {str(e)}" + ) ) - + # Start the retry task in the background asyncio.create_task(retry_request()) return True - + except Exception as e: logger.debug(f"Failed to reassign request for model {model_id}: {e}") return False @@ -316,40 +348,39 @@ async def batched_inference(self, model_id: str, request: Any) -> Any: """Perform batched inference on a specific model.""" if not self.enable_batching: raise RuntimeError("Batching is not enabled for this multiplexer") - + # Ensure model is loaded first await self.load_model(model_id) - + # Get the batch queue for this model batch_queue = self._get_or_create_batch_queue(model_id) if batch_queue is None: raise RuntimeError("Failed to create batch queue") - + # Submit request to the batch queue using _SingleRequest format - import ray.serve.context as context future = asyncio.get_event_loop().create_future() request_context 
= context._get_serve_request_context() - + # Create _SingleRequest with flattened args using DUMMY_TYPE for positional args # Format: [DUMMY_TYPE, arg1, DUMMY_TYPE, arg2, ...] for positional args single_request = _SingleRequest( self_arg=None, flattened_args=[DUMMY_TYPE, request], future=future, - request_context=request_context + request_context=request_context, ) - + batch_queue.queue.put(single_request) - + return await future async def predict(self, input_data: Any, model_id: str) -> Any: """Convenience method for model prediction with optional batching. - + Args: input_data: The input data to predict on. model_id: The model ID to use for prediction. - + Returns: The prediction result. """ @@ -359,9 +390,9 @@ async def predict(self, input_data: Any, model_id: str) -> Any: else: # Load model and call directly model = await self.load_model(model_id) - + # Try different prediction methods - if hasattr(model, 'predict'): + if hasattr(model, "predict"): result = await model.predict(input_data) elif callable(model): result = await model(input_data) @@ -369,7 +400,7 @@ async def predict(self, input_data: Any, model_id: str) -> Any: raise RuntimeError( f"Model {model_id} is not callable and has no predict method" ) - + return result def _get_loading_and_loaded_model_ids(self) -> List[str]: @@ -412,13 +443,15 @@ async def shutdown(self): logger.exception( f"Failed to unload model. Error: {e}", ) - + # Clean up any remaining batch queues for model_id, batch_queue_wrapper in list(self._model_batch_queues.items()): try: await self._shutdown_batch_queue(batch_queue_wrapper, model_id) except Exception as e: - logger.exception(f"Failed to shutdown batch queue for model {model_id}. Error: {e}") + logger.exception( + f"Failed to shutdown batch queue for model {model_id}. 
Error: {e}" + ) self._model_batch_queues.clear() async def load_model(self, model_id: str) -> Any: diff --git a/python/ray/serve/tests/test_multiplex_batching.py b/python/ray/serve/tests/test_multiplex_batching.py index 2580482971a4..f9a23ba8060c 100644 --- a/python/ray/serve/tests/test_multiplex_batching.py +++ b/python/ray/serve/tests/test_multiplex_batching.py @@ -7,42 +7,33 @@ import asyncio import time -import math -from concurrent.futures import ThreadPoolExecutor -from typing import List, Dict, Any, Optional -from unittest.mock import AsyncMock, patch +from typing import List import pytest -import httpx import ray from ray import serve -from ray._common.test_utils import SignalActor, wait_for_condition from ray.serve._private.common import DeploymentID, ReplicaID from ray.serve._private.config import DeploymentConfig -from ray.serve._private.constants import SERVE_MULTIPLEXED_MODEL_ID -from ray.serve._private.request_router import RequestRouter -from ray.serve.context import _get_internal_replica_context -from ray.serve.handle import DeploymentHandle from ray.serve.multiplex import _ModelMultiplexWrapper class MockModel: """Mock model for testing multiplexing and batching.""" - + def __init__(self, model_id: str, processing_time: float = 0.1): self.model_id = model_id self.processing_time = processing_time self.call_count = 0 self.batch_call_count = 0 self.last_batch_size = 0 - + async def predict(self, input_data): """Individual prediction method.""" await asyncio.sleep(self.processing_time) self.call_count += 1 return f"result_{self.model_id}_{input_data}" - + async def batch_predict(self, input_batch: List): """Batch prediction method.""" await asyncio.sleep(self.processing_time * 0.6) # Batch efficiency @@ -76,12 +67,13 @@ def start_serve_with_context(): @pytest.mark.asyncio class TestMultiplexBatchingIntegration: """Test the integration of multiplexing with batching.""" - + async def test_basic_batching_integration(self, start_serve_with_context): """Test that multiplexing works with batching enabled.""" - + async def mock_model_loader(model_id: str): return MockModel(model_id, processing_time=0.05) + print("creating multiplex wrapper") # Create wrapper with batching enabled wrapper = _ModelMultiplexWrapper( @@ -89,12 +81,12 @@ async def mock_model_loader(model_id: str): max_num_models_per_replica=3, enable_batching=True, max_batch_size=4, - batch_wait_timeout_s=0.1 + batch_wait_timeout_s=0.1, ) - print('create wrapper') - + print("create wrapper") + # Test concurrent requests to same model - should be batched - print('starting tasks') + print("starting tasks") start_time = time.time() tasks = [] for i in range(6): @@ -104,275 +96,266 @@ async def mock_model_loader(model_id: str): results = await asyncio.gather(*tasks) total_time = time.time() - start_time print("completed gather") - + # Verify results assert len(results) == 6 assert all("batch_result_model_a" in result for result in results) - + # Should have been processed in batches assert total_time < 0.5 # Much faster than 6 individual calls - - async def test_multiplex_with_batching_different_models(self, start_serve_with_context): + + async def test_multiplex_with_batching_different_models( + self, start_serve_with_context + ): """Test multiplexing across different models with batching.""" - + models = {} - + async def mock_model_loader(model_id: str): if model_id not in models: models[model_id] = MockModel(model_id, processing_time=0.03) return models[model_id] - + wrapper = _ModelMultiplexWrapper( 
model_load_func=mock_model_loader, max_num_models_per_replica=3, enable_batching=True, max_batch_size=3, - batch_wait_timeout_s=0.05 + batch_wait_timeout_s=0.05, ) - + # Send requests to different models concurrently tasks = [] for model_id in ["model_a", "model_b", "model_c"]: for i in range(3): task = wrapper.predict(f"input_{i}", model_id) tasks.append(task) - + results = await asyncio.gather(*tasks) - + # Verify all models were used assert len(models) == 3 assert all(model.batch_call_count > 0 for model in models.values()) - + # Verify results from all models model_a_results = [r for r in results if "model_a" in r] model_b_results = [r for r in results if "model_b" in r] model_c_results = [r for r in results if "model_c" in r] - + assert len(model_a_results) == 3 assert len(model_b_results) == 3 assert len(model_c_results) == 3 - + async def test_batching_timeout_behavior(self, start_serve_with_context): """Test batch timeout behavior with multiplexing.""" - + async def mock_model_loader(model_id: str): return MockModel(model_id, processing_time=0.01) - + wrapper = _ModelMultiplexWrapper( model_load_func=mock_model_loader, max_num_models_per_replica=2, enable_batching=True, max_batch_size=5, - batch_wait_timeout_s=0.1 # 100ms timeout + batch_wait_timeout_s=0.1, # 100ms timeout ) - + # Send single request and measure time start_time = time.time() result = await wrapper.predict("single_input", "model_timeout") elapsed_time = time.time() - start_time - + # Should process after timeout even with single request assert "batch_result_model_timeout" in result assert elapsed_time >= 0.1 # At least the timeout duration - + async def test_max_batch_size_enforcement(self, start_serve_with_context): """Test that max batch size is enforced properly.""" - + model_instance = MockModel("model_batch_size", processing_time=0.02) - + async def mock_model_loader(model_id: str): return model_instance - + wrapper = _ModelMultiplexWrapper( model_load_func=mock_model_loader, max_num_models_per_replica=1, enable_batching=True, max_batch_size=3, # Small batch size - batch_wait_timeout_s=0.05 + batch_wait_timeout_s=0.05, ) - + # Send more requests than max batch size tasks = [] for i in range(7): # More than max_batch_size task = wrapper.predict(f"input_{i}", "model_batch_size") tasks.append(task) - + results = await asyncio.gather(*tasks) - + # All requests should complete assert len(results) == 7 - + # Should have made multiple batch calls due to max_batch_size limit assert model_instance.batch_call_count >= 3 # At least 3 batches for 7 items - + async def test_model_eviction_with_batching(self, start_serve_with_context): """Test LRU model eviction works with batching.""" - + models = {} - + async def mock_model_loader(model_id: str): if model_id not in models: models[model_id] = MockModel(model_id) return models[model_id] - + wrapper = _ModelMultiplexWrapper( model_load_func=mock_model_loader, max_num_models_per_replica=2, # Small cache enable_batching=True, max_batch_size=3, - batch_wait_timeout_s=0.05 + batch_wait_timeout_s=0.05, ) - + # Load models sequentially to trigger eviction await wrapper.predict("input1", "model_1") await wrapper.predict("input2", "model_2") await wrapper.predict("input3", "model_3") # Should evict model_1 - + # Verify model_1 was evicted by checking cache size # This is implementation dependent but we can test behavior await wrapper.predict("input4", "model_1") # Should reload model_1 - + # All models should have been created assert len(models) == 3 - + async def 
test_batching_disabled_fallback(self, start_serve_with_context): """Test that individual prediction works when batching is disabled.""" - + model_instance = MockModel("model_no_batch", processing_time=0.01) - + async def mock_model_loader(model_id: str): return model_instance - + wrapper = _ModelMultiplexWrapper( model_load_func=mock_model_loader, max_num_models_per_replica=2, enable_batching=False, # Batching disabled max_batch_size=5, - batch_wait_timeout_s=0.1 + batch_wait_timeout_s=0.1, ) - + # Send multiple requests tasks = [] for i in range(3): task = wrapper.predict(f"input_{i}", "model_no_batch") tasks.append(task) - + results = await asyncio.gather(*tasks) - + # Should use individual prediction, not batching assert len(results) == 3 assert model_instance.call_count == 3 # Individual calls assert model_instance.batch_call_count == 0 # No batch calls - + async def test_concurrent_models_with_batching(self, start_serve_with_context): """Test concurrent access to different models with batching.""" - + models = {} - + async def mock_model_loader(model_id: str): if model_id not in models: models[model_id] = MockModel(model_id, processing_time=0.02) return models[model_id] - + wrapper = _ModelMultiplexWrapper( model_load_func=mock_model_loader, max_num_models_per_replica=4, enable_batching=True, max_batch_size=2, - batch_wait_timeout_s=0.03 + batch_wait_timeout_s=0.03, ) - + # Create concurrent requests to multiple models start_time = time.time() tasks = [] - + # 2 requests to each of 3 models for model_id in ["fast_model", "medium_model", "slow_model"]: for i in range(2): task = wrapper.predict(f"data_{i}", model_id) tasks.append(task) - + results = await asyncio.gather(*tasks) total_time = time.time() - start_time - + # All requests should complete assert len(results) == 6 - + # Should process efficiently due to batching assert total_time < 0.3 # Much faster than sequential - + # Each model should have been called once in batch mode for model_id in ["fast_model", "medium_model", "slow_model"]: assert models[model_id].batch_call_count >= 1 assert models[model_id].last_batch_size == 2 -@pytest.mark.asyncio +@pytest.mark.asyncio class TestMultiplexBatchingAPI: """Test the API integration for multiplexed batching.""" - - async def test_serve_multiplexed_decorator_with_batching(self, start_serve_with_context): + + async def test_serve_multiplexed_decorator_with_batching( + self, start_serve_with_context + ): """Test the @serve.multiplexed decorator with batching parameters.""" - + # Mock the decorator functionality from ray.serve.api import multiplexed - from ray.serve.multiplex import _ModelMultiplexWrapper - + # Create a model class class TestModel: def __init__(self, model_id: str): self.model_id = model_id - + async def predict(self, data): return f"result_{self.model_id}_{data}" - + async def batch_predict(self, data_list): return [f"batch_{self.model_id}_{item}" for item in data_list] - + # Create a model loading function (this is what gets decorated) async def load_model(model_id: str): return TestModel(model_id) - + # Apply the decorator to the loading function decorated_load_model = multiplexed( max_num_models_per_replica=3, enable_batching=True, max_batch_size=4, - batch_wait_timeout_s=0.1 + batch_wait_timeout_s=0.1, )(load_model) - + # Verify the decorator returns a callable assert callable(decorated_load_model) - + # Test that calling the decorated function returns the model instance # The decorator internally creates a _ModelMultiplexWrapper and caches it, # then calls 
load_model() on it which returns the actual model model = await decorated_load_model("test_model") assert isinstance(model, TestModel) assert model.model_id == "test_model" - + # Test that subsequent calls to the same model use the cached instance model2 = await decorated_load_model("test_model") assert model2.model_id == "test_model" - + # Test loading a different model model3 = await decorated_load_model("another_model") assert model3.model_id == "another_model" + if __name__ == "__main__": - # Run specific test methods for development - pytest.main([ - # __file__ + "::TestMultiplexBatchingIntegration::test_basic_batching_integration", - # __file__ + "::TestMultiplexBatchingIntegration::test_multiplex_with_batching_different_models", - # __file__ + "::TestMultiplexBatchingIntegration::test_batching_timeout_behavior", - # __file__ + "::TestMultiplexBatchingIntegration::test_max_batch_size_enforcement", - # __file__ + "::TestMultiplexBatchingIntegration::test_model_eviction_with_batching", - # __file__ + "::TestMultiplexBatchingIntegration::test_batching_disabled_fallback", - __file__ + "::TestMultiplexBatchingIntegration::test_concurrent_models_with_batching", - __file__ + "::TestMultiplexBatchingAPI::test_serve_multiplexed_decorator_with_batching", - __file__ + "::TestEndToEndMultiplexBatching::test_multiplexed_deployment_with_batching", - "-v" - ]) - -# import sys -# sys.exit(pytest.main(["-v", "-s", __file__])) + import sys + + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_multiplex_batching_router.py b/python/ray/serve/tests/test_multiplex_batching_router.py index cab67b7d2bcf..08bab57c942a 100644 --- a/python/ray/serve/tests/test_multiplex_batching_router.py +++ b/python/ray/serve/tests/test_multiplex_batching_router.py @@ -7,7 +7,7 @@ import asyncio import time -from typing import List, Dict, Any +from typing import Any, Dict, List import pytest @@ -42,39 +42,39 @@ def start_serve_with_context(): class TrackableModel: """Model that tracks its lifecycle and usage for testing.""" - + _instances = {} # Track all created instances - _deleted = [] # Track deleted models - + _deleted = [] # Track deleted models + def __init__(self, model_id: str): self.model_id = model_id self.created_at = time.time() self.predict_count = 0 self.batch_predict_count = 0 TrackableModel._instances[model_id] = self - + async def predict(self, data): """Individual prediction.""" await asyncio.sleep(0.01) self.predict_count += 1 return f"result_{self.model_id}_{data}" - + async def batch_predict(self, data_list: List): """Batch prediction.""" await asyncio.sleep(0.005 * len(data_list)) self.batch_predict_count += 1 return [f"batch_{self.model_id}_{item}" for item in data_list] - + def __del__(self): """Track when models are evicted.""" TrackableModel._deleted.append(self.model_id) - + @classmethod def reset_tracking(cls): """Reset all tracking for new test.""" cls._instances = {} cls._deleted = [] - + @classmethod def get_stats(cls) -> Dict[str, Any]: """Get current statistics.""" @@ -82,124 +82,131 @@ def get_stats(cls) -> Dict[str, Any]: "active_models": len(cls._instances), "deleted_models": len(cls._deleted), "model_ids": list(cls._instances.keys()), - "deleted_ids": cls._deleted.copy() + "deleted_ids": cls._deleted.copy(), } @pytest.mark.asyncio class TestMultiplexBatchingEnd2End: """End-to-end tests for the complete routing and batching pipeline.""" - + async def test_model_caching_and_lru_eviction(self, start_serve_with_context): """Test that models are cached and 
evicted using LRU policy.""" TrackableModel.reset_tracking() - + async def load_model(model_id: str): return TrackableModel(model_id) - + # Create wrapper with max 3 models wrapper = _ModelMultiplexWrapper( model_load_func=load_model, max_num_models_per_replica=3, - enable_batching=False + enable_batching=False, ) - + # Load 3 models - all should be cached await wrapper.load_model("model_a") await wrapper.load_model("model_b") await wrapper.load_model("model_c") - + stats = TrackableModel.get_stats() - assert stats["active_models"] == 3, f"Expected 3 active models, got {stats['active_models']}" + assert ( + stats["active_models"] == 3 + ), "Expected 3 active models, got {stats['active_models']}" + assert stats["deleted_models"] == 0, "No models should be deleted yet" - + # Load 4th model - should evict least recently used (model_a) await wrapper.load_model("model_d") - + # Give some time for garbage collection await asyncio.sleep(0.1) - + # Check the wrapper's cache size (should be at most 3) cache_size = len(wrapper.models) - assert cache_size <= 3, f"Should have at most 3 models in cache, got {cache_size}" - + assert ( + cache_size <= 3 + ), f"Should have at most 3 models in cache, got {cache_size}" + # Access model_b and model_c to keep them recent await wrapper.load_model("model_b") await wrapper.load_model("model_c") - + # Load another model - should evict model_d (least recently used) await wrapper.load_model("model_e") await asyncio.sleep(0.1) - + final_cache_size = len(wrapper.models) - assert final_cache_size <= 3, f"Should maintain max 3 models, got {final_cache_size}" - + assert ( + final_cache_size <= 3 + ), f"Should maintain max 3 models, got {final_cache_size}" + print(f"Final cache size: {final_cache_size}") print(f"Models in cache: {list(wrapper.models.keys())}") - + async def test_model_reuse_vs_reload(self, start_serve_with_context): """Test that cached models are reused without reloading.""" TrackableModel.reset_tracking() - + load_count = {"count": 0} - + async def load_model(model_id: str): load_count["count"] += 1 return TrackableModel(model_id) - + wrapper = _ModelMultiplexWrapper( model_load_func=load_model, max_num_models_per_replica=2, - enable_batching=False + enable_batching=False, ) - + # Load model_a for the first time await wrapper.load_model("model_a") assert load_count["count"] == 1, "Model should be loaded once" - + # Use model_a again - should reuse cached version await wrapper.load_model("model_a") assert load_count["count"] == 1, "Model should not be reloaded" - + # Load model_b await wrapper.load_model("model_b") assert load_count["count"] == 2, "Second model should be loaded" - + # Use both models again - no reloads await wrapper.load_model("model_a") await wrapper.load_model("model_b") assert load_count["count"] == 2, "No additional loads needed" - + # Load model_c - should evict one model await wrapper.load_model("model_c") assert load_count["count"] == 3, "Third model loaded" - + # Use model_a again - should reload if it was evicted await wrapper.load_model("model_a") # Could be 3 or 4 depending on which was evicted assert load_count["count"] >= 3, "Model may need reload if evicted" - + print(f"Total model loads: {load_count['count']}") - + async def test_batching_efficiency_metrics(self, start_serve_with_context): """Test that batching improves throughput and tracks metrics.""" TrackableModel.reset_tracking() - + async def load_model(model_id: str): return TrackableModel(model_id) - + # Test with batching enabled wrapper_batched = 
_ModelMultiplexWrapper( model_load_func=load_model, max_num_models_per_replica=2, enable_batching=True, max_batch_size=5, - batch_wait_timeout_s=0.05 + batch_wait_timeout_s=0.05, ) - + # Load model first model = await wrapper_batched.load_model("batched_model") - + # Send concurrent requests to the wrapper to test batching mechanism start_time = time.time() tasks = [] @@ -207,64 +214,68 @@ async def load_model(model_id: str): # Use wrapper.predict() to test the actual batching mechanism task = wrapper_batched.predict(f"data_{i}", "batched_model") tasks.append(task) - + results = await asyncio.gather(*tasks) batched_time = time.time() - start_time - + # Check that batch predict was called (indicating batching worked) assert model.batch_predict_count > 0, "Batch predict should be called" assert len(results) == 10, "All requests should complete" - + # Verify results are correct format - should be from batch_predict - assert all("batch_batched_model" in result for result in results), f"Expected batch results, got: {results[:3]}" - + assert all( + "batch_batched_model" in result for result in results + ), f"Expected batch results, got: {results[:3]}" + # Test without batching for comparison TrackableModel.reset_tracking() - + wrapper_no_batch = _ModelMultiplexWrapper( model_load_func=load_model, max_num_models_per_replica=2, - enable_batching=False + enable_batching=False, ) - + model_no_batch = await wrapper_no_batch.load_model("no_batch_model") - + start_time = time.time() tasks = [] for i in range(10): task = model_no_batch.predict(f"data_{i}") tasks.append(task) - + results = await asyncio.gather(*tasks) no_batch_time = time.time() - start_time - + assert model_no_batch.predict_count > 0, "Individual predict should be called" - assert model_no_batch.batch_predict_count == 0, "Batch predict should not be called" - + assert ( + model_no_batch.batch_predict_count == 0 + ), "Batch predict should not be called" + print(f"Batched time: {batched_time:.3f}s, No-batch time: {no_batch_time:.3f}s") print(f"Batch predict calls: {model.batch_predict_count}") print(f"Individual predict calls: {model_no_batch.predict_count}") - + async def test_concurrent_model_access_patterns(self, start_serve_with_context): """Test concurrent access to multiple models.""" TrackableModel.reset_tracking() - + async def load_model(model_id: str): return TrackableModel(model_id) - + wrapper = _ModelMultiplexWrapper( model_load_func=load_model, max_num_models_per_replica=4, - enable_batching=False + enable_batching=False, ) - + # Load multiple models concurrently start_time = time.time() - + hot_model = await wrapper.load_model("hot_model") warm_model = await wrapper.load_model("warm_model") cold_model = await wrapper.load_model("cold_model") - + # Simulate workload with varying access patterns tasks = [] for i in range(6): @@ -273,133 +284,135 @@ async def load_model(model_id: str): tasks.append(warm_model.predict(f"data_{i}")) for i in range(1): tasks.append(cold_model.predict(f"data_{i}")) - + results = await asyncio.gather(*tasks) total_time = time.time() - start_time - + stats = TrackableModel.get_stats() - + assert len(results) == 10, "All requests should complete" - assert stats["active_models"] == 3, f"Should have 3 models loaded, got {stats['active_models']}" - + assert ( + stats["active_models"] == 3 + ), f"Should have 3 models loaded, got {stats['active_models']}" + # Check access counts assert hot_model.predict_count == 6, "Hot model should have 6 accesses" assert warm_model.predict_count == 3, "Warm model should 
have 3 accesses" assert cold_model.predict_count == 1, "Cold model should have 1 access" - + print(f"Total time for 10 requests across 3 models: {total_time:.3f}s") print(f"Active models: {stats['model_ids']}") - + async def test_model_affinity_for_batching(self, start_serve_with_context): """Test model caching behavior.""" TrackableModel.reset_tracking() - + async def load_model(model_id: str): return TrackableModel(model_id) - + wrapper = _ModelMultiplexWrapper( model_load_func=load_model, max_num_models_per_replica=3, - enable_batching=False + enable_batching=False, ) - + # Load and access same model multiple times model1 = await wrapper.load_model("affinity_model") model2 = await wrapper.load_model("affinity_model") # Should be same instance - + assert model1 is model2, "Should return cached model instance" - + # Access the model results = [] for i in range(4): result = await model1.predict(f"request_{i}") results.append(result) - + assert len(results) == 4 assert all("result_affinity_model" in r for r in results) assert model1.predict_count == 4, "Should track all predictions" - + print(f"Predict count: {model1.predict_count}") @pytest.mark.asyncio class TestMultiplexCachingMetrics: """Tests focused on caching metrics and behavior.""" - + async def test_cache_hit_rate_tracking(self, start_serve_with_context): """Test tracking of cache hits vs misses.""" TrackableModel.reset_tracking() - + load_attempts = [] - + async def load_model(model_id: str): load_attempts.append(model_id) return TrackableModel(model_id) - + wrapper = _ModelMultiplexWrapper( model_load_func=load_model, max_num_models_per_replica=2, - enable_batching=False + enable_batching=False, ) - + # First access - cache miss await wrapper.load_model("model_x") cache_misses = len(load_attempts) assert cache_misses == 1 - + # Second access - cache hit await wrapper.load_model("model_x") assert len(load_attempts) == 1, "Should not reload cached model" - + # Third model exceeds cache - eviction await wrapper.load_model("model_y") await wrapper.load_model("model_z") - + # Accessing evicted model - cache miss await wrapper.load_model("model_x") - + total_loads = len(load_attempts) print(f"Total model loads: {total_loads}") print(f"Load sequence: {load_attempts}") - + async def test_eviction_order_lru(self, start_serve_with_context): """Test that LRU eviction policy is followed.""" TrackableModel.reset_tracking() - + access_log = [] - + async def load_model(model_id: str): access_log.append(("load", model_id)) return TrackableModel(model_id) - + wrapper = _ModelMultiplexWrapper( model_load_func=load_model, max_num_models_per_replica=2, - enable_batching=False + enable_batching=False, ) - + # Load model A and B await wrapper.load_model("A") await wrapper.load_model("B") - + # Access A again (making B least recently used) await wrapper.load_model("A") - + # Load C - should evict B (LRU) await wrapper.load_model("C") - + # If we access B again, it should reload await wrapper.load_model("B") - + # Count loads per model loads = {} for action, model_id in access_log: if action == "load": loads[model_id] = loads.get(model_id, 0) + 1 - + print(f"Access log: {access_log}") print(f"Load counts: {loads}") - + # A should be loaded once, B twice (initial + after eviction), C once assert loads.get("A") == 1, "A loaded once" assert loads.get("C") == 1, "C loaded once" @@ -409,25 +422,25 @@ async def load_model(model_id: str): @pytest.mark.asyncio class TestMultiplexBatchingIntegration: """Integration tests combining multiplexing and batching.""" 
- + async def test_multiple_models_with_batching(self, start_serve_with_context): """Test loading multiple models and basic tracking.""" TrackableModel.reset_tracking() - + async def load_model(model_id: str): return TrackableModel(model_id) - + wrapper = _ModelMultiplexWrapper( model_load_func=load_model, max_num_models_per_replica=3, - enable_batching=False + enable_batching=False, ) - + # Load multiple models model_1 = await wrapper.load_model("model_1") model_2 = await wrapper.load_model("model_2") model_3 = await wrapper.load_model("model_3") - + # Use each model tasks = [] for i in range(3): @@ -436,25 +449,31 @@ async def load_model(model_id: str): tasks.append(model_2.predict(f"m2_data_{i}")) for i in range(2): tasks.append(model_3.predict(f"m3_data_{i}")) - + results = await asyncio.gather(*tasks) stats = TrackableModel.get_stats() - + assert len(results) == 8, "All requests should complete" assert stats["active_models"] == 3, "Should have 3 models" - + # Verify models were used for model_id in ["model_1", "model_2", "model_3"]: model = TrackableModel._instances.get(model_id) assert model is not None assert model.predict_count > 0, f"{model_id} should be used" - + print(f"Stats: {stats}") - print(f"Model 1 predictions: {TrackableModel._instances['model_1'].predict_count}") - print(f"Model 2 predictions: {TrackableModel._instances['model_2'].predict_count}") - print(f"Model 3 predictions: {TrackableModel._instances['model_3'].predict_count}") + print( + f"Model 1 predictions: {TrackableModel._instances['model_1'].predict_count}" + ) + print( + f"Model 2 predictions: {TrackableModel._instances['model_2'].predict_count}" + ) + print( + f"Model 3 predictions: {TrackableModel._instances['model_3'].predict_count}" + ) if __name__ == "__main__": # Run tests - pytest.main([__file__, "-v", "-s"]) \ No newline at end of file + pytest.main([__file__, "-v", "-s"]) diff --git a/python/ray/serve/tests/test_multiplex_batching_utils.py b/python/ray/serve/tests/test_multiplex_batching_utils.py index 594de5a8d774..63f7b7a769c8 100644 --- a/python/ray/serve/tests/test_multiplex_batching_utils.py +++ b/python/ray/serve/tests/test_multiplex_batching_utils.py @@ -7,19 +7,14 @@ import asyncio import time -from typing import List, Dict, Any, Optional -from unittest.mock import AsyncMock +from typing import Any, Dict, List import pytest -import ray -from ray import serve -from ray.serve._private.common import DeploymentID, ReplicaID -from ray.serve._private.config import DeploymentConfig class MockEmbeddingModel: """Mock embedding model for sentence transformer-like testing.""" - + def __init__(self, model_id: str, embedding_dim: int = 384): self.model_id = model_id self.embedding_dim = embedding_dim @@ -27,52 +22,54 @@ def __init__(self, model_id: str, embedding_dim: int = 384): self.predict_calls = 0 self.batch_predict_calls = 0 self.total_items_processed = 0 - + async def predict(self, text: str) -> List[float]: """Individual text encoding.""" await asyncio.sleep(0.02) # Simulate encoding time self.predict_calls += 1 self.total_items_processed += 1 - + # Generate deterministic embedding based on text and model import hashlib + hash_input = f"{text}_{self.model_id}".encode() hash_obj = hashlib.md5(hash_input) - + # Create embedding vector embedding = [] for i in range(self.embedding_dim): byte_val = hash_obj.digest()[i % 16] embedding.append((byte_val / 255.0) - 0.5) - + return embedding - + async def batch_predict(self, texts: List[str]) -> List[List[float]]: """Batch text encoding - more efficient.""" 
batch_size = len(texts) # Batch processing is more efficient per item await asyncio.sleep(0.01 * batch_size) - + self.batch_predict_calls += 1 self.total_items_processed += batch_size - + # Process all texts embeddings = [] for text in texts: # Same logic as predict but in batch import hashlib + hash_input = f"{text}_{self.model_id}".encode() hash_obj = hashlib.md5(hash_input) - + embedding = [] for i in range(self.embedding_dim): byte_val = hash_obj.digest()[i % 16] embedding.append((byte_val / 255.0) - 0.5) - + embeddings.append(embedding) - + return embeddings - + def get_stats(self) -> Dict[str, Any]: """Get model usage statistics.""" return { @@ -81,148 +78,149 @@ def get_stats(self) -> Dict[str, Any]: "predict_calls": self.predict_calls, "batch_predict_calls": self.batch_predict_calls, "total_items_processed": self.total_items_processed, - "uptime": time.time() - self.load_time + "uptime": time.time() - self.load_time, } class MockClassificationModel: """Mock classification model for testing.""" - + def __init__(self, model_id: str, num_classes: int = 3): self.model_id = model_id self.num_classes = num_classes self.predict_calls = 0 self.batch_predict_calls = 0 - + async def predict(self, text: str) -> Dict[str, float]: """Individual text classification.""" await asyncio.sleep(0.03) self.predict_calls += 1 - + # Generate deterministic probabilities import hashlib + hash_val = int(hashlib.md5(f"{text}_{self.model_id}".encode()).hexdigest(), 16) - + probs = [] for i in range(self.num_classes): prob = ((hash_val + i) % 100) / 100.0 probs.append(prob) - + # Normalize to sum to 1 total = sum(probs) probs = [p / total for p in probs] - - return { - f"class_{i}": probs[i] - for i in range(self.num_classes) - } - + + return {f"class_{i}": probs[i] for i in range(self.num_classes)} + async def batch_predict(self, texts: List[str]) -> List[Dict[str, float]]: """Batch text classification.""" batch_size = len(texts) await asyncio.sleep(0.02 * batch_size) # Batch efficiency self.batch_predict_calls += 1 - + results = [] for text in texts: # Same logic as predict import hashlib - hash_val = int(hashlib.md5(f"{text}_{self.model_id}".encode()).hexdigest(), 16) - + + hash_val = int( + hashlib.md5(f"{text}_{self.model_id}".encode()).hexdigest(), 16 + ) + probs = [] for i in range(self.num_classes): prob = ((hash_val + i) % 100) / 100.0 probs.append(prob) - + total = sum(probs) probs = [p / total for p in probs] - + result = {f"class_{i}": probs[i] for i in range(self.num_classes)} results.append(result) - + return results class MockTranslationModel: """Mock translation model for testing.""" - + def __init__(self, model_id: str, source_lang: str = "en", target_lang: str = "es"): self.model_id = model_id self.source_lang = source_lang self.target_lang = target_lang self.translate_calls = 0 self.batch_translate_calls = 0 - + async def translate(self, text: str) -> str: """Individual translation.""" await asyncio.sleep(0.05) # Translation takes longer self.translate_calls += 1 - + # Mock translation by reversing and adding prefix translated = f"[{self.target_lang}] {text[::-1]}" return translated - + async def batch_translate(self, texts: List[str]) -> List[str]: """Batch translation.""" batch_size = len(texts) await asyncio.sleep(0.03 * batch_size) # Batch efficiency self.batch_translate_calls += 1 - + translations = [] for text in texts: translated = f"[{self.target_lang}] {text[::-1]}" translations.append(translated) - + return translations class BatchingTestHelper: """Helper class for testing batching 
behavior.""" - + @staticmethod async def send_concurrent_requests(wrapper, inputs: List[str], model_id: str): """Send concurrent requests and measure timing.""" start_time = time.time() - + tasks = [] for input_data in inputs: task = wrapper.predict(input_data, model_id) tasks.append(task) - + results = await asyncio.gather(*tasks) total_time = time.time() - start_time - + return results, total_time - + @staticmethod def verify_batching_efficiency( - individual_time: float, - batch_time: float, + individual_time: float, + batch_time: float, num_requests: int, - min_speedup: float = 1.5 + min_speedup: float = 1.5, ): """Verify that batching provides expected efficiency gains.""" speedup = individual_time / batch_time - + assert speedup >= min_speedup, ( f"Expected speedup of at least {min_speedup}x, " f"got {speedup:.2f}x ({individual_time:.3f}s vs {batch_time:.3f}s)" ) - + return speedup - + @staticmethod def analyze_batch_patterns(model_instances: List): """Analyze batching patterns across model instances.""" stats = {} - + for model in model_instances: stats[model.model_id] = { - "individual_calls": getattr(model, 'predict_calls', 0), - "batch_calls": getattr(model, 'batch_predict_calls', 0), - "total_processed": getattr(model, 'total_items_processed', 0) + "individual_calls": getattr(model, "predict_calls", 0), + "batch_calls": getattr(model, "batch_predict_calls", 0), + "total_processed": getattr(model, "total_items_processed", 0), } - + return stats @@ -230,20 +228,16 @@ def analyze_batch_patterns(model_instances: List): def embedding_model_loader(): """Fixture for embedding model loader.""" models = {} - + async def loader(model_id: str) -> MockEmbeddingModel: if model_id not in models: # Different embedding dimensions for different models - dims = { - "mini": 384, - "base": 768, - "large": 1024 - } + dims = {"mini": 384, "base": 768, "large": 1024} dim = dims.get(model_id, 384) models[model_id] = MockEmbeddingModel(model_id, dim) - + return models[model_id] - + return loader, models @@ -251,20 +245,20 @@ async def loader(model_id: str) -> MockEmbeddingModel: def classification_model_loader(): """Fixture for classification model loader.""" models = {} - + async def loader(model_id: str) -> MockClassificationModel: if model_id not in models: # Different number of classes for different models classes = { "sentiment": 3, # positive, negative, neutral - "topic": 5, # 5 topic categories - "intent": 10 # 10 intent categories + "topic": 5, # 5 topic categories + "intent": 10, # 10 intent categories } num_classes = classes.get(model_id, 3) models[model_id] = MockClassificationModel(model_id, num_classes) - + return models[model_id] - + return loader, models @@ -272,20 +266,20 @@ async def loader(model_id: str) -> MockClassificationModel: def translation_model_loader(): """Fixture for translation model loader.""" models = {} - + async def loader(model_id: str) -> MockTranslationModel: if model_id not in models: # Different language pairs lang_pairs = { "en_es": ("en", "es"), "en_fr": ("en", "fr"), - "en_de": ("en", "de") + "en_de": ("en", "de"), } source_lang, target_lang = lang_pairs.get(model_id, ("en", "es")) models[model_id] = MockTranslationModel(model_id, source_lang, target_lang) - + return models[model_id] - + return loader, models @@ -302,7 +296,7 @@ def sample_texts(): "Deep learning models require careful optimization.", "Distributed systems handle large-scale ML workloads.", "Vector databases enable efficient similarity search.", - "Transformer architectures revolutionized NLP 
applications." + "Transformer architectures revolutionized NLP applications.", ] @@ -317,63 +311,60 @@ def performance_test_config(): "timeout_medium": 0.1, "timeout_long": 0.2, "min_speedup": 1.5, - "max_models": 4 + "max_models": 4, } class MultiModelTestScenario: """Test scenario with multiple models and request patterns.""" - + def __init__(self, models: List[str], request_patterns: Dict[str, List[str]]): self.models = models self.request_patterns = request_patterns - + async def execute_scenario(self, wrapper): """Execute the test scenario.""" all_tasks = [] - + for model_id, requests in self.request_patterns.items(): for request_data in requests: task = wrapper.predict(request_data, model_id) all_tasks.append((model_id, task)) - + # Execute all requests concurrently start_time = time.time() results = [] - + for model_id, task in all_tasks: result = await task - results.append({ - "model_id": model_id, - "result": result, - "timestamp": time.time() - }) - + results.append( + {"model_id": model_id, "result": result, "timestamp": time.time()} + ) + total_time = time.time() - start_time - + return results, total_time - + def analyze_results(self, results: List[Dict], total_time: float): """Analyze scenario execution results.""" model_results = {} - + for result in results: model_id = result["model_id"] if model_id not in model_results: model_results[model_id] = [] model_results[model_id].append(result) - + analysis = { "total_requests": len(results), "total_time": total_time, "models_used": len(model_results), "requests_per_model": { - model_id: len(model_results[model_id]) - for model_id in model_results + model_id: len(model_results[model_id]) for model_id in model_results }, - "avg_time_per_request": total_time / len(results) if results else 0 + "avg_time_per_request": total_time / len(results) if results else 0, } - + return analysis @@ -384,27 +375,36 @@ def analyze_results(self, results: List[Dict], total_time: float): request_patterns={ "mini": ["Quick text", "Short phrase", "Brief sentence"], "base": ["Medium length text for processing", "Another moderate sentence"], - "large": ["This is a longer text that requires more sophisticated embedding processing"] - } + "large": [ + "This is a longer text that requires more sophisticated " + "embedding processing" + ], + }, ), - "classification_workload": MultiModelTestScenario( models=["sentiment", "topic", "intent"], request_patterns={ - "sentiment": ["I love this product!", "This is terrible", "It's okay I guess"], + "sentiment": [ + "I love this product!", + "This is terrible", + "It's okay I guess", + ], "topic": ["Technology news update", "Sports match results"], - "intent": ["Book a flight", "Cancel my subscription", "Get weather forecast"] - } + "intent": [ + "Book a flight", + "Cancel my subscription", + "Get weather forecast", + ], + }, ), - "translation_workload": MultiModelTestScenario( models=["en_es", "en_fr", "en_de"], request_patterns={ "en_es": ["Hello world", "How are you?"], "en_fr": ["Good morning", "Thank you"], - "en_de": ["Welcome", "Goodbye"] - } - ) + "en_de": ["Welcome", "Goodbye"], + }, + ), } @@ -413,14 +413,14 @@ def analyze_results(self, results: List[Dict], total_time: float): async def test_mock_embedding_model(): """Test MockEmbeddingModel functionality.""" model = MockEmbeddingModel("test_model", embedding_dim=128) - + # Test individual prediction result = await model.predict("hello world") assert len(result) == 128 assert isinstance(result[0], float) assert model.predict_calls == 1 assert 
model.total_items_processed == 1 - + # Test batch prediction texts = ["hello", "world", "test"] batch_result = await model.batch_predict(texts) @@ -428,7 +428,7 @@ async def test_mock_embedding_model(): assert len(batch_result[0]) == 128 assert model.batch_predict_calls == 1 assert model.total_items_processed == 4 # 1 + 3 - + # Test stats stats = model.get_stats() assert stats["model_id"] == "test_model" @@ -441,14 +441,14 @@ async def test_mock_embedding_model(): async def test_mock_classification_model(): """Test MockClassificationModel functionality.""" model = MockClassificationModel("sentiment", num_classes=3) - + # Test individual prediction result = await model.predict("I love this!") assert len(result) == 3 assert all(f"class_{i}" in result for i in range(3)) assert abs(sum(result.values()) - 1.0) < 1e-6 # Should sum to 1 assert model.predict_calls == 1 - + # Test batch prediction texts = ["good", "bad", "neutral"] batch_result = await model.batch_predict(texts) @@ -461,13 +461,13 @@ async def test_mock_classification_model(): async def test_mock_translation_model(): """Test MockTranslationModel functionality.""" model = MockTranslationModel("en_es", source_lang="en", target_lang="es") - + # Test individual translation result = await model.translate("hello") assert result.startswith("[es]") assert "olleh" in result # reversed text assert model.translate_calls == 1 - + # Test batch translation texts = ["hello", "world"] batch_result = await model.batch_translate(texts) @@ -479,30 +479,29 @@ async def test_mock_translation_model(): @pytest.mark.asyncio async def test_batching_test_helper(): """Test BatchingTestHelper functionality.""" - from ray.serve.multiplex import _ModelMultiplexWrapper - + # Mock a simple wrapper class MockWrapper: def __init__(self): self.call_count = 0 - + async def predict(self, input_data, model_id): self.call_count += 1 await asyncio.sleep(0.01) return f"result_{model_id}_{input_data}" - + wrapper = MockWrapper() inputs = ["test1", "test2", "test3"] - + # Test concurrent requests results, total_time = await BatchingTestHelper.send_concurrent_requests( wrapper, inputs, "test_model" ) - + assert len(results) == 3 assert wrapper.call_count == 3 assert total_time > 0 - + # Test efficiency verification individual_time = 0.1 batch_time = 0.05 @@ -515,19 +514,19 @@ async def predict(self, input_data, model_id): def test_multi_model_test_scenario(): """Test MultiModelTestScenario functionality.""" scenario = TEST_SCENARIOS["embedding_workload"] - + assert len(scenario.models) == 3 assert "mini" in scenario.models assert "base" in scenario.models assert "large" in scenario.models - + # Test analyze_results mock_results = [ {"model_id": "mini", "result": "result1", "timestamp": 1.0}, {"model_id": "mini", "result": "result2", "timestamp": 1.1}, {"model_id": "base", "result": "result3", "timestamp": 1.2}, ] - + analysis = scenario.analyze_results(mock_results, 0.5) assert analysis["total_requests"] == 3 assert analysis["total_time"] == 0.5 @@ -536,29 +535,34 @@ def test_multi_model_test_scenario(): assert analysis["requests_per_model"]["base"] == 1 -def test_fixtures_return_correct_types(embedding_model_loader, classification_model_loader, - translation_model_loader, sample_texts, performance_test_config): +def test_fixtures_return_correct_types( + embedding_model_loader, + classification_model_loader, + translation_model_loader, + sample_texts, + performance_test_config, +): """Test that all fixtures return the expected types.""" # Test embedding model loader loader, 
models = embedding_model_loader assert callable(loader) assert isinstance(models, dict) - + # Test classification model loader loader, models = classification_model_loader assert callable(loader) assert isinstance(models, dict) - + # Test translation model loader loader, models = translation_model_loader assert callable(loader) assert isinstance(models, dict) - + # Test sample texts assert isinstance(sample_texts, list) assert len(sample_texts) > 0 assert all(isinstance(text, str) for text in sample_texts) - + # Test performance config assert isinstance(performance_test_config, dict) assert "small_batch" in performance_test_config @@ -569,18 +573,18 @@ def test_fixtures_return_correct_types(embedding_model_loader, classification_mo async def test_embedding_model_loader_fixture(embedding_model_loader): """Test the embedding model loader fixture.""" loader, models = embedding_model_loader - + # Load a model model = await loader("mini") assert isinstance(model, MockEmbeddingModel) assert model.model_id == "mini" assert model.embedding_dim == 384 - + # Check it's cached model2 = await loader("mini") assert model is model2 assert len(models) == 1 - + # Load different model model3 = await loader("base") assert model3.embedding_dim == 768 @@ -590,4 +594,5 @@ async def test_embedding_model_loader_fixture(embedding_model_loader): if __name__ == "__main__": # Run all tests in this module import sys - pytest.main(["-v", "-s", __file__] + sys.argv[1:]) \ No newline at end of file + + pytest.main(["-v", "-s", __file__] + sys.argv[1:])
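
# --- Illustrative usage sketch (not part of the patch) ---
# A minimal driver for the batching-enabled wrapper exercised by the tests
# above. The constructor arguments and the wrapper.load_model()/predict()
# calls mirror the tests in this patch; EchoModel and main() are hypothetical
# stand-ins, the wrapper is assumed to dispatch grouped requests to the
# model's batch_predict() (as the tests assert), and a Serve replica context
# (like the start_serve_with_context fixture provides) may be required.
import asyncio
from typing import List

from ray.serve.multiplex import _ModelMultiplexWrapper


class EchoModel:
    """Toy model with both per-item and batched entry points."""

    def __init__(self, model_id: str):
        self.model_id = model_id

    async def predict(self, data: str) -> str:
        return f"{self.model_id}:{data}"

    async def batch_predict(self, batch: List[str]) -> List[str]:
        # Invoked by the wrapper when enable_batching=True, per the tests above.
        return [f"{self.model_id}:batch:{item}" for item in batch]


async def main() -> None:
    async def load_model(model_id: str) -> EchoModel:
        return EchoModel(model_id)

    wrapper = _ModelMultiplexWrapper(
        model_load_func=load_model,
        max_num_models_per_replica=2,
        enable_batching=True,
        max_batch_size=5,
        batch_wait_timeout_s=0.05,
    )

    await wrapper.load_model("echo")
    # Concurrent requests for the same model ID should be grouped into batches.
    results = await asyncio.gather(
        *(wrapper.predict(f"data_{i}", "echo") for i in range(10))
    )
    print(results)


if __name__ == "__main__":
    asyncio.run(main())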