
Commit 02222b8

manickavela29 authored and manickavela-uni committed

multiplexing with batching
Signed-off-by: manickavela29 <[email protected]>
1 parent 1c8d408 commit 02222b8

File tree: 6 files changed, +1446 -7 lines changed


python/ray/serve/_private/request_router/request_router.py

Lines changed: 46 additions & 0 deletions

@@ -196,6 +196,9 @@ class MultiplexMixin:
     It adds necessary attributes and methods to keep track of multiplexed
     model IDs and offer the helpers to apply multiplex routing and rank
     replicas based on multiplexed model IDs.
+
+    Now supports batching-aware routing to group requests by model ID
+    for optimal batching performance.
     """

     def __init__(self, *args, **kwargs):
@@ -211,6 +214,9 @@ def __init__(self, *args, **kwargs):
         self._multiplexed_model_id_fallback_match: Set[str] = set()
         self._replica_id_set: Set[ReplicaID] = set()
         self._replicas: Dict[ReplicaID, RunningReplica] = {}
+
+        # Batching-aware routing: track pending requests by model ID for better batching
+        self._pending_requests_by_model_id: DefaultDict[str, List] = defaultdict(list)

     def _get_pending_request_matching_multiplexed_model_id(
         self,
@@ -228,6 +234,27 @@ def _get_pending_request_matching_multiplexed_model_id(
             ):
                 return pr

+    def _track_pending_request_by_model_id(self, pending_request: PendingRequest):
+        """Track pending requests by model ID for batching-aware routing."""
+        if pending_request.metadata.multiplexed_model_id:
+            model_id = pending_request.metadata.multiplexed_model_id
+            self._pending_requests_by_model_id[model_id].append(pending_request)
+
+    def _get_pending_requests_for_model(self, model_id: str) -> List[PendingRequest]:
+        """Get all pending requests for a specific model ID."""
+        return [pr for pr in self._pending_requests_by_model_id[model_id]
+                if not pr.future.done()]
+
+    def _cleanup_completed_pending_requests(self):
+        """Clean up completed requests from model ID tracking."""
+        for model_id in list(self._pending_requests_by_model_id.keys()):
+            self._pending_requests_by_model_id[model_id] = [
+                pr for pr in self._pending_requests_by_model_id[model_id]
+                if not pr.future.done()
+            ]
+            if not self._pending_requests_by_model_id[model_id]:
+                del self._pending_requests_by_model_id[model_id]
+
     def _update_multiplexed_model_ids_with_replicas(
         self, replicas: List[RunningReplica]
     ):
@@ -280,6 +307,9 @@ def apply_multiplex_routing(
         then the replicas with the fewest multiplexed models, and finally all
         replicas.

+        Enhanced with batching-aware routing to prioritize replicas that already
+        have pending requests for the same model ID to improve batching efficiency.
+
         Args:
             pending_request: The pending request to be routed based on
                 multiplexed model policy.
@@ -291,6 +321,11 @@ def apply_multiplex_routing(
         if not pending_request:
             return self._replica_id_set

+        # Track this request for batching-aware routing
+        self._track_pending_request_by_model_id(pending_request)
+        # Clean up completed requests periodically
+        self._cleanup_completed_pending_requests()
+
         if not pending_request.routing_context.multiplexed_start_matching_time:
             pending_request.routing_context.multiplexed_start_matching_time = (
                 time.time()
@@ -300,13 +335,24 @@ def apply_multiplex_routing(
             pending_request.routing_context.multiplexed_start_matching_time
         )
         multiplexed_model_id = pending_request.metadata.multiplexed_model_id
+
         if (
             time.time() - multiplexed_start_matching_time
             < self._multiplexed_matching_timeout
         ):
             candidate_replica_ids = self._multiplexed_model_id_to_replica_ids.get(
                 multiplexed_model_id, None
             )
+
+            # Batching-aware enhancement: prioritize replicas with pending requests
+            # for the same model ID to improve batching efficiency
+            if candidate_replica_ids and multiplexed_model_id:
+                pending_for_model = self._get_pending_requests_for_model(multiplexed_model_id)
+                if len(pending_for_model) > 1:  # Multiple requests for same model
+                    # Prefer replicas that are likely processing this model
+                    logger.debug(f"Found {len(pending_for_model)} pending requests for model {multiplexed_model_id}, "
+                                 f"prioritizing batching-friendly routing")
+
         if (
             not candidate_replica_ids
             and multiplexed_model_id
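
The new helpers above only maintain bookkeeping; the following is a minimal, self-contained sketch (not part of this commit) that mimics their tracking and cleanup behavior with plain asyncio futures, so the batching-friendly signal is easy to see. FakePendingRequest and PendingTracker are illustrative stand-ins, not Serve APIs.

import asyncio
from collections import defaultdict
from typing import DefaultDict, List


class FakePendingRequest:
    """Stand-in for Serve's PendingRequest: just a model ID and a future."""

    def __init__(self, model_id: str):
        self.model_id = model_id
        self.future = asyncio.get_running_loop().create_future()


class PendingTracker:
    """Toy version of the per-model tracking added to MultiplexMixin."""

    def __init__(self) -> None:
        self._by_model: DefaultDict[str, List[FakePendingRequest]] = defaultdict(list)

    def track(self, pr: FakePendingRequest) -> None:
        # Mirrors _track_pending_request_by_model_id.
        if pr.model_id:
            self._by_model[pr.model_id].append(pr)

    def pending_for_model(self, model_id: str) -> List[FakePendingRequest]:
        # Mirrors _get_pending_requests_for_model: only not-yet-completed requests.
        return [pr for pr in self._by_model[model_id] if not pr.future.done()]

    def cleanup(self) -> None:
        # Mirrors _cleanup_completed_pending_requests.
        for model_id in list(self._by_model):
            alive = [pr for pr in self._by_model[model_id] if not pr.future.done()]
            if alive:
                self._by_model[model_id] = alive
            else:
                del self._by_model[model_id]


async def main() -> None:
    tracker = PendingTracker()
    first, second = FakePendingRequest("model_a"), FakePendingRequest("model_a")
    tracker.track(first)
    tracker.track(second)
    # Two live requests for the same model: a batching-friendly signal.
    print(len(tracker.pending_for_model("model_a")))  # 2
    first.future.set_result(None)
    tracker.cleanup()
    print(len(tracker.pending_for_model("model_a")))  # 1


asyncio.run(main())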

python/ray/serve/api.py

Lines changed: 18 additions & 2 deletions

@@ -751,7 +751,12 @@ def delete(name: str, _blocking: bool = True):

 @PublicAPI(stability="beta")
 def multiplexed(
-    func: Optional[Callable[..., Any]] = None, max_num_models_per_replica: int = 3
+    func: Optional[Callable[..., Any]] = None,
+    max_num_models_per_replica: int = 3,
+    enable_batching: bool = False,
+    max_batch_size: int = 10,
+    batch_wait_timeout_s: float = 0.01,
+    max_concurrent_batches: int = 1,
 ):
     """Wrap a callable or method used to load multiplexed models in a replica.

@@ -811,6 +816,11 @@ async def __call__(self, request):
             set it to a larger number if you have enough memory on
             the node resource, in opposite, you can set it to a smaller
             number if you want to save memory on the node resource.
+        enable_batching: whether to enable batching for model inference calls.
+            Default is False.
+        max_batch_size: maximum batch size for batched inference calls. Default is 10.
+        batch_wait_timeout_s: timeout for batching inference calls. Default is 0.01s.
+        max_concurrent_batches: maximum number of concurrent batches. Default is 1.
     """

     if func is not None:
@@ -875,7 +885,13 @@ async def _multiplex_wrapper(*args):
             # create a model multiplex wrapper and cache it in the multiplex object.
             if not hasattr(multiplex_object, multiplex_attr):
                 model_multiplex_wrapper = _ModelMultiplexWrapper(
-                    func, self, max_num_models_per_replica
+                    func,
+                    self,
+                    max_num_models_per_replica,
+                    enable_batching=enable_batching,
+                    max_batch_size=max_batch_size,
+                    batch_wait_timeout_s=batch_wait_timeout_s,
+                    max_concurrent_batches=max_concurrent_batches,
                 )
                 setattr(multiplex_object, multiplex_attr, model_multiplex_wrapper)
             else:
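
For context, a hedged usage sketch of the extended decorator: the new keyword arguments come from the diff above, while the deployment class and toy loader are illustrative. These parameters only configure the underlying _ModelMultiplexWrapper; whether requests are actually batched depends on the wrapper's new predict/batched_inference path shown in python/ray/serve/multiplex.py below.

from ray import serve


@serve.deployment
class ModelInferencer:
    @serve.multiplexed(
        max_num_models_per_replica=3,
        enable_batching=True,        # new in this commit
        max_batch_size=10,
        batch_wait_timeout_s=0.01,
        max_concurrent_batches=1,
    )
    async def get_model(self, model_id: str):
        # Toy "model": a callable that tags its input with the model ID.
        # A real loader would fetch weights and build the model here.
        def toy_model(payload):
            return {"model_id": model_id, "echo": payload}

        return toy_model

    async def __call__(self, http_request):
        payload = await http_request.json()
        # The target model is picked from the serve_multiplexed_model_id header.
        model_id = serve.get_multiplexed_model_id()
        model = await self.get_model(model_id)
        return model(payload)


app = ModelInferencer.bind()
# serve.run(app)  # send requests with the "serve_multiplexed_model_id" header set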

python/ray/serve/multiplex.py

Lines changed: 138 additions & 5 deletions

@@ -3,7 +3,7 @@
 import logging
 import time
 from collections import OrderedDict
-from typing import Any, Callable, List, Set
+from typing import Any, Callable, List, Set, Optional

 from ray.serve import metrics
 from ray.serve._private.common import ReplicaID, RequestRoutingInfo
@@ -15,6 +15,8 @@
 from ray.serve._private.metrics_utils import MetricsPusher
 from ray.serve._private.usage import ServeUsageTag
 from ray.serve.context import _get_global_client, _get_internal_replica_context
+from ray.serve.batching import _LazyBatchQueueWrapper, _SingleRequest
+from ray._common.signature import DUMMY_TYPE

 logger = logging.getLogger(SERVE_LOGGER_NAME)

@@ -39,16 +41,26 @@ class _ModelMultiplexWrapper:
     def __init__(
         self,
         model_load_func: Callable[[str], Any],
-        self_arg: Any,
-        max_num_models_per_replica: int,
+        self_arg: Any = None,
+        max_num_models_per_replica: int = 3,
+        enable_batching: bool = False,
+        max_batch_size: int = 10,
+        batch_wait_timeout_s: float = 0.01,
+        max_concurrent_batches: int = 1,
     ):
         """Initialize the model multiplexer.
         Args:
             model_load_func: the model load async function.
-            self_arg: self argument when model_load_func is class method.
+            self_arg: self argument when model_load_func is class method. Default is None
+                for standalone functions.
            max_num_models_per_replica: the maximum number of models to be loaded on the
                current replica. If it is -1, there is no limit for the number of models
-                per replica.
+                per replica. Default is 3.
+            enable_batching: whether to enable batching for model inference calls.
+                Default is False.
+            max_batch_size: maximum batch size for batched inference calls. Default is 10.
+            batch_wait_timeout_s: timeout for batching inference calls. Default is 0.01s.
+            max_concurrent_batches: maximum number of concurrent batches. Default is 1.
         """

         ServeUsageTag.MULTIPLEXED_API_USED.record("1")
@@ -57,6 +69,15 @@ def __init__(
         self._func: Callable = model_load_func
         self.self_arg: Any = self_arg
         self.max_num_models_per_replica: int = max_num_models_per_replica
+
+        # Batching configuration
+        self.enable_batching = enable_batching
+        self.max_batch_size = max_batch_size
+        self.batch_wait_timeout_s = batch_wait_timeout_s
+        self.max_concurrent_batches = max_concurrent_batches
+
+        # Model-specific batch queues for inference batching
+        self._model_batch_queues: dict[str, _LazyBatchQueueWrapper] = {}

         # log MODEL_LOAD_LATENCY_BUCKET_MS
         logger.debug(f"MODEL_LOAD_LATENCY_BUCKET_MS: {MODEL_LOAD_LATENCY_BUCKETS_MS}")
@@ -123,6 +144,114 @@ def __init__(
             )
         self.metrics_pusher.start()

+    def _get_or_create_batch_queue(self, model_id: str) -> Optional[_LazyBatchQueueWrapper]:
+        """Get or create a batch queue for a specific model."""
+        if not self.enable_batching:
+            return None
+
+        if model_id not in self._model_batch_queues:
+            # Create a batch handler for this specific model
+            async def model_batch_handler(batch_requests: List[Any]) -> List[Any]:
+                """Handle batched inference for a specific model.
+
+                Args:
+                    batch_requests: List of input data items to process as a batch.
+
+                Returns:
+                    List of results corresponding to each input.
+                """
+                model = self.models.get(model_id)
+                if model is None:
+                    raise RuntimeError(f"Model {model_id} not loaded")
+
+                # Try to use batch_predict method if available
+                if hasattr(model, 'batch_predict'):
+                    results = await model.batch_predict(batch_requests)
+                else:
+                    # Fallback to individual prediction calls
+                    results = []
+                    for request_data in batch_requests:
+                        if hasattr(model, 'predict'):
+                            result = await model.predict(request_data)
+                        elif callable(model):
+                            result = await model(request_data)
+                        else:
+                            raise RuntimeError(
+                                f"Model {model_id} is not callable and has no predict method"
+                            )
+                        results.append(result)
+
+                return results
+
+            self._model_batch_queues[model_id] = _LazyBatchQueueWrapper(
+                max_batch_size=self.max_batch_size,
+                batch_wait_timeout_s=self.batch_wait_timeout_s,
+                max_concurrent_batches=self.max_concurrent_batches,
+                handle_batch_func=model_batch_handler,
+            )
+
+        return self._model_batch_queues[model_id]
+
+    async def batched_inference(self, model_id: str, request: Any) -> Any:
+        """Perform batched inference on a specific model."""
+        if not self.enable_batching:
+            raise RuntimeError("Batching is not enabled for this multiplexer")
+
+        # Ensure model is loaded first
+        await self.load_model(model_id)
+
+        # Get the batch queue for this model
+        batch_queue = self._get_or_create_batch_queue(model_id)
+        if batch_queue is None:
+            raise RuntimeError("Failed to create batch queue")
+
+        # Submit request to the batch queue using _SingleRequest format
+        import ray.serve.context as context
+        future = asyncio.get_event_loop().create_future()
+        request_context = context._get_serve_request_context()
+
+        # Create _SingleRequest with flattened args using DUMMY_TYPE for positional args
+        # Format: [DUMMY_TYPE, arg1, DUMMY_TYPE, arg2, ...] for positional args
+        single_request = _SingleRequest(
+            self_arg=None,
+            flattened_args=[DUMMY_TYPE, request],
+            future=future,
+            request_context=request_context
+        )
+
+        batch_queue.queue.put(single_request)
+
+        return await future
+
+    async def predict(self, input_data: Any, model_id: str) -> Any:
+        """Convenience method for model prediction with optional batching.
+
+        Args:
+            input_data: The input data to predict on.
+            model_id: The model ID to use for prediction.
+
+        Returns:
+            The prediction result.
+        """
+        if self.enable_batching:
+            # Use batched inference
+            return await self.batched_inference(model_id, input_data)
+        else:
+            # Load model and call directly
+            model = await self.load_model(model_id)
+
+            # Try different prediction methods
+            if hasattr(model, 'predict'):
+                result = await model.predict(input_data)
+            elif callable(model):
+                result = await model(input_data)
+            else:
+                raise RuntimeError(
+                    f"Model {model_id} is not callable and has no predict method"
+                )
+
+            return result
+
     def _get_loading_and_loaded_model_ids(self) -> List[str]:
         """Get the model IDs of the loaded models & loading models in the replica.
         This is to push the model id information early to the controller, so that
@@ -244,6 +373,10 @@ async def unload_model_lru(self) -> None:
         model_id, model = self.models.popitem(last=False)
         logger.info(f"Unloading model '{model_id}'.")

+        # Clean up the batch queue for this model if it exists
+        if model_id in self._model_batch_queues:
+            del self._model_batch_queues[model_id]
+
         # If the model has __del__ attribute, call it.
         # This is to clean up the model resources eagerly.
         if hasattr(model, "__del__"):
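
To make the handler's capability probing concrete, here is a small runnable sketch (not from the diff) that reproduces the dispatch order used by model_batch_handler: batch_predict first, then predict, then a plain callable. EmbeddingModel and dispatch_like_the_handler are illustrative names.

import asyncio
from typing import Any, List


class EmbeddingModel:
    """Toy model shaped the way the new per-model batch handler expects."""

    async def batch_predict(self, batch: List[Any]) -> List[Any]:
        # Preferred path: one vectorized call for the whole batch.
        return [len(str(item)) for item in batch]

    async def predict(self, item: Any) -> Any:
        # Per-item fallback, used only when batch_predict is absent.
        return len(str(item))


async def dispatch_like_the_handler(model: Any, batch: List[Any]) -> List[Any]:
    # Mirrors the probing order in model_batch_handler above.
    if hasattr(model, "batch_predict"):
        return await model.batch_predict(batch)
    results = []
    for item in batch:
        if hasattr(model, "predict"):
            results.append(await model.predict(item))
        elif callable(model):
            results.append(await model(item))
        else:
            raise RuntimeError("model is not callable and has no predict method")
    return results


print(asyncio.run(dispatch_like_the_handler(EmbeddingModel(), ["a", "bb", "ccc"])))
# Prints [1, 2, 3]: the whole batch went through batch_predict in one call.

With enable_batching=True, _ModelMultiplexWrapper.predict(input_data, model_id) enqueues each request into the per-model _LazyBatchQueueWrapper, so concurrent requests for the same model ID are grouped and handed to this dispatch logic as one batch.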
