@@ -3,7 +3,7 @@
 import logging
 import time
 from collections import OrderedDict
-from typing import Any, Callable, List, Set
+from typing import Any, Callable, List, Set, Optional
 
 from ray.serve import metrics
 from ray.serve._private.common import ReplicaID, RequestRoutingInfo
@@ -15,6 +15,8 @@
 from ray.serve._private.metrics_utils import MetricsPusher
 from ray.serve._private.usage import ServeUsageTag
 from ray.serve.context import _get_global_client, _get_internal_replica_context
+from ray.serve.batching import _LazyBatchQueueWrapper, _SingleRequest
+from ray._common.signature import DUMMY_TYPE
 
 logger = logging.getLogger(SERVE_LOGGER_NAME)
 
@@ -39,16 +41,26 @@ class _ModelMultiplexWrapper:
     def __init__(
         self,
         model_load_func: Callable[[str], Any],
-        self_arg: Any,
-        max_num_models_per_replica: int,
+        self_arg: Any = None,
+        max_num_models_per_replica: int = 3,
+        enable_batching: bool = False,
+        max_batch_size: int = 10,
+        batch_wait_timeout_s: float = 0.01,
+        max_concurrent_batches: int = 1,
     ):
         """Initialize the model multiplexer.
         Args:
             model_load_func: the model load async function.
-            self_arg: self argument when model_load_func is class method.
+            self_arg: self argument when model_load_func is a class method.
+                Default is None for standalone functions.
             max_num_models_per_replica: the maximum number of models to be loaded on the
                 current replica. If it is -1, there is no limit for the number of models
-                per replica.
+                per replica. Default is 3.
+            enable_batching: whether to enable batching for model inference calls.
+                Default is False.
+            max_batch_size: maximum batch size for batched inference calls. Default is 10.
+            batch_wait_timeout_s: how long to wait to fill a batch. Default is 0.01s.
+            max_concurrent_batches: maximum number of concurrent batches. Default is 1.
         """
 
         ServeUsageTag.MULTIPLEXED_API_USED.record("1")
@@ -57,6 +69,15 @@ def __init__(
         self._func: Callable = model_load_func
         self.self_arg: Any = self_arg
         self.max_num_models_per_replica: int = max_num_models_per_replica
+
+        # Batching configuration
+        self.enable_batching = enable_batching
+        self.max_batch_size = max_batch_size
+        self.batch_wait_timeout_s = batch_wait_timeout_s
+        self.max_concurrent_batches = max_concurrent_batches
+
+        # Model-specific batch queues for inference batching
+        self._model_batch_queues: dict[str, _LazyBatchQueueWrapper] = {}
 
         # log MODEL_LOAD_LATENCY_BUCKET_MS
         logger.debug(f"MODEL_LOAD_LATENCY_BUCKET_MS: {MODEL_LOAD_LATENCY_BUCKETS_MS}")
@@ -123,6 +144,114 @@ def __init__(
         )
         self.metrics_pusher.start()
 
+    def _get_or_create_batch_queue(self, model_id: str) -> Optional[_LazyBatchQueueWrapper]:
+        """Get or create a batch queue for a specific model."""
+        if not self.enable_batching:
+            return None
+
+        if model_id not in self._model_batch_queues:
+            # Create a batch handler for this specific model
+            async def model_batch_handler(batch_requests: List[Any]) -> List[Any]:
+                """Handle batched inference for a specific model.
+
+                Args:
+                    batch_requests: List of input data items to process as a batch.
+
+                Returns:
+                    List of results corresponding to each input.
+                """
+                model = self.models.get(model_id)
+                if model is None:
+                    raise RuntimeError(f"Model {model_id} not loaded")
+
+                # Try to use the batch_predict method if available
+                if hasattr(model, "batch_predict"):
+                    results = await model.batch_predict(batch_requests)
+                else:
+                    # Fall back to individual prediction calls
+                    results = []
+                    for request_data in batch_requests:
+                        if hasattr(model, "predict"):
+                            result = await model.predict(request_data)
+                        elif callable(model):
+                            result = await model(request_data)
+                        else:
+                            raise RuntimeError(
+                                f"Model {model_id} is not callable and has no predict method"
+                            )
+                        results.append(result)
+
+                return results
+
+            self._model_batch_queues[model_id] = _LazyBatchQueueWrapper(
+                max_batch_size=self.max_batch_size,
+                batch_wait_timeout_s=self.batch_wait_timeout_s,
+                max_concurrent_batches=self.max_concurrent_batches,
+                handle_batch_func=model_batch_handler,
+            )
+
+        return self._model_batch_queues[model_id]
+
+    async def batched_inference(self, model_id: str, request: Any) -> Any:
+        """Perform batched inference on a specific model."""
+        if not self.enable_batching:
+            raise RuntimeError("Batching is not enabled for this multiplexer")
+
+        # Ensure the model is loaded first
+        await self.load_model(model_id)
+
+        # Get the batch queue for this model
+        batch_queue = self._get_or_create_batch_queue(model_id)
+        if batch_queue is None:
+            raise RuntimeError("Failed to create batch queue")
+
+        # Submit the request to the batch queue using the _SingleRequest format
+        import ray.serve.context as context
+        future = asyncio.get_running_loop().create_future()
+        request_context = context._get_serve_request_context()
+
+        # Create _SingleRequest with flattened args using DUMMY_TYPE for positional args
+        # Format: [DUMMY_TYPE, arg1, DUMMY_TYPE, arg2, ...] for positional args
+        single_request = _SingleRequest(
+            self_arg=None,
+            flattened_args=[DUMMY_TYPE, request],
+            future=future,
+            request_context=request_context,
+        )
+
+        batch_queue.queue.put(single_request)
+
+        return await future
+
+    async def predict(self, input_data: Any, model_id: str) -> Any:
+        """Convenience method for model prediction with optional batching.
+
+        Args:
+            input_data: The input data to predict on.
+            model_id: The model ID to use for prediction.
+
+        Returns:
+            The prediction result.
+        """
+        if self.enable_batching:
+            # Use batched inference
+            return await self.batched_inference(model_id, input_data)
+        else:
+            # Load the model and call it directly
+            model = await self.load_model(model_id)
+
+            # Try different prediction methods
+            if hasattr(model, "predict"):
+                result = await model.predict(input_data)
+            elif callable(model):
+                result = await model(input_data)
+            else:
+                raise RuntimeError(
+                    f"Model {model_id} is not callable and has no predict method"
+                )
+
+            return result
+
     def _get_loading_and_loaded_model_ids(self) -> List[str]:
         """Get the model IDs of the loaded models & loading models in the replica.
         This is to push the model id information early to the controller, so that
@@ -244,6 +373,10 @@ async def unload_model_lru(self) -> None:
         model_id, model = self.models.popitem(last=False)
         logger.info(f"Unloading model '{model_id}'.")
 
+        # Clean up the batch queue for this model if it exists
+        if model_id in self._model_batch_queues:
+            del self._model_batch_queues[model_id]
+
         # If the model has __del__ attribute, call it.
         # This is to clean up the model resources eagerly.
         if hasattr(model, "__del__"):
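
For context, a rough usage sketch of the batching path added above (illustrative only, not part of this diff). MyModel, BatchedMultiplexer, and batch_predict are hypothetical names; _ModelMultiplexWrapper is an internal class normally created by the @serve.multiplexed decorator rather than instantiated directly, and the exact construction details inside a replica may differ.

# Illustrative sketch only; names below are hypothetical and the internal
# wrapper is used directly only to show the new enable_batching parameters.
from ray import serve
from ray.serve._private.multiplex import _ModelMultiplexWrapper


class MyModel:
    def __init__(self, model_id: str):
        self.model_id = model_id

    async def batch_predict(self, batch):
        # Return one result per input in the batch.
        return [f"{self.model_id}:{item}" for item in batch]


@serve.deployment
class BatchedMultiplexer:
    def __init__(self):
        async def load_model(model_id: str) -> MyModel:
            return MyModel(model_id)

        # enable_batching=True routes predict() through the per-model batch queues.
        self._multiplexer = _ModelMultiplexWrapper(
            load_model,
            enable_batching=True,
            max_batch_size=8,
            batch_wait_timeout_s=0.01,
        )

    async def __call__(self, http_request) -> str:
        # The model id comes from the request's multiplexed-model-id metadata.
        model_id = serve.get_multiplexed_model_id()
        payload = await http_request.json()
        return await self._multiplexer.predict(payload, model_id)


app = BatchedMultiplexer.bind()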