@@ -196,7 +196,7 @@ class MultiplexMixin:
196196 It adds necessary attributes and methods to keep track of multiplexed
197197 model IDs and offer the helpers to apply multiplex routing and rank
198198 replicas based on multiplexed model IDs.
199-
199+
200200 Now supports batching-aware routing to group requests by model ID
201201 for optimal batching performance.
202202 """
@@ -214,14 +214,15 @@ def __init__(self, *args, **kwargs):
214214 self ._multiplexed_model_id_fallback_match : Set [str ] = set ()
215215 self ._replica_id_set : Set [ReplicaID ] = set ()
216216 self ._replicas : Dict [ReplicaID , RunningReplica ] = {}
217-
217+
218218 # Batching-aware routing: track pending requests by model ID for better batching
219219 self ._pending_requests_by_model_id : DefaultDict [str , List ] = defaultdict (list )
220220 # Counters for efficient cleanup
221221 self ._pending_requests_added_since_cleanup = 0
222222 self ._last_cleanup_time = time .time ()
223223 self ._cleanup_threshold = 50 # Cleanup after 50 new requests
224224 self ._cleanup_interval = 10.0 # Cleanup every 10 seconds
225+ self ._cleanup_task = None # Track async cleanup task
225226
226227 def _get_pending_request_matching_multiplexed_model_id (
227228 self ,
@@ -249,42 +250,104 @@ def _track_pending_request_by_model_id(self, pending_request: PendingRequest):
def _get_pending_requests_for_model(self, model_id: str) -> List[PendingRequest]:
    """Return the tracked requests for ``model_id`` whose future is not done.

    As a side effect, requests whose future is done are dropped from the
    tracking map, and the map entry is removed once nothing remains for the
    model. If anything was dropped and the periodic-cleanup condition holds,
    a full cleanup pass is scheduled via ``_schedule_async_cleanup``.
    """
    tracked_by_model = self._pending_requests_by_model_id
    # Membership test first so that looking up an unknown model ID never
    # inserts an entry into the tracking map.
    if model_id not in tracked_by_model:
        return []

    tracked = tracked_by_model[model_id]
    still_pending = [request for request in tracked if not request.future.done()]

    # Nothing finished since the last look: leave the stored list untouched.
    if len(still_pending) == len(tracked):
        return still_pending

    # Store only the unfinished requests so finished ones do not accumulate.
    if still_pending:
        tracked_by_model[model_id] = still_pending
    else:
        del tracked_by_model[model_id]

    # Finished requests were observed; run the wider cleanup off the
    # routing path when it is due.
    if self._should_cleanup_pending_requests():
        self._schedule_async_cleanup()

    return still_pending
255278
256279 def _should_cleanup_pending_requests (self ) -> bool :
257280 """Determine if we should perform cleanup based on counters and time."""
258- return (self ._pending_requests_added_since_cleanup >= self ._cleanup_threshold or
259- (time .time () - self ._last_cleanup_time ) >= self ._cleanup_interval )
281+ return (
282+ self ._pending_requests_added_since_cleanup >= self ._cleanup_threshold
283+ or (time .time () - self ._last_cleanup_time ) >= self ._cleanup_interval
284+ )
260285
261286 def _cleanup_completed_pending_requests (self ):
262287 """Clean up completed requests from model ID tracking efficiently."""
263288 # Only cleanup if we've accumulated enough requests or enough time has passed
264289 if not self ._should_cleanup_pending_requests ():
265290 return
266-
291+
267292 cleanup_start = time .time ()
268- total_requests_before = sum (len (requests ) for requests in self ._pending_requests_by_model_id .values ())
269-
293+ total_requests_before = sum (
294+ len (requests ) for requests in self ._pending_requests_by_model_id .values ()
295+ )
296+
270297 for model_id in list (self ._pending_requests_by_model_id .keys ()):
271298 self ._pending_requests_by_model_id [model_id ] = [
272- pr for pr in self ._pending_requests_by_model_id [model_id ]
299+ pr
300+ for pr in self ._pending_requests_by_model_id [model_id ]
273301 if not pr .future .done ()
274302 ]
275303 if not self ._pending_requests_by_model_id [model_id ]:
276304 del self ._pending_requests_by_model_id [model_id ]
277-
278- total_requests_after = sum (len (requests ) for requests in self ._pending_requests_by_model_id .values ())
305+
306+ total_requests_after = sum (
307+ len (requests ) for requests in self ._pending_requests_by_model_id .values ()
308+ )
279309 cleanup_time = time .time () - cleanup_start
280-
310+
281311 # Reset counters
282312 self ._pending_requests_added_since_cleanup = 0
283313 self ._last_cleanup_time = time .time ()
284-
314+
285315 if total_requests_before != total_requests_after :
286- logger .debug (f"Cleaned up { total_requests_before - total_requests_after } completed requests "
287- f"in { cleanup_time :.3f} s, { total_requests_after } active requests remaining" )
316+ logger .debug (
317+ f"Cleaned up { total_requests_before - total_requests_after } "
318+ f"completed requests in { cleanup_time :.3f} s, "
319+ f"{ total_requests_after } active requests remaining"
320+ )
321+
322+ def _schedule_async_cleanup (self ):
323+ """Schedule cleanup to run asynchronously without blocking routing."""
324+ # Only schedule if cleanup isn't already running
325+ if (
326+ not hasattr (self , "_cleanup_task" )
327+ or self ._cleanup_task is None
328+ or self ._cleanup_task .done ()
329+ ):
330+ import asyncio
331+
332+ try :
333+ # Get the current event loop
334+ loop = asyncio .get_event_loop ()
335+ self ._cleanup_task = loop .create_task (self ._async_cleanup ())
336+ except RuntimeError :
337+ # If no event loop is running, fall back to synchronous cleanup
338+ # This should rarely happen in the Ray Serve context
339+ self ._cleanup_completed_pending_requests ()
340+
341+ async def _async_cleanup (self ):
342+ """Perform cleanup asynchronously."""
343+ try :
344+ # Small delay to avoid blocking the current operation
345+ await asyncio .sleep (0.001 )
346+ self ._cleanup_completed_pending_requests ()
347+ except Exception as e :
348+ logger .warning (f"Async cleanup failed: { e } " )
349+ finally :
350+ self ._cleanup_task = None
288351
289352 def _update_multiplexed_model_ids_with_replicas (
290353 self , replicas : List [RunningReplica ]
@@ -354,8 +417,6 @@ def apply_multiplex_routing(
354417
355418 # Track this request for batching-aware routing
356419 self ._track_pending_request_by_model_id (pending_request )
357- # Clean up completed requests periodically
358- self ._cleanup_completed_pending_requests ()
359420
360421 if not pending_request .routing_context .multiplexed_start_matching_time :
361422 pending_request .routing_context .multiplexed_start_matching_time = (
@@ -366,46 +427,63 @@ def apply_multiplex_routing(
366427 pending_request .routing_context .multiplexed_start_matching_time
367428 )
368429 multiplexed_model_id = pending_request .metadata .multiplexed_model_id
369-
430+
370431 if (
371432 time .time () - multiplexed_start_matching_time
372433 < self ._multiplexed_matching_timeout
373434 ):
374435 candidate_replica_ids = self ._multiplexed_model_id_to_replica_ids .get (
375436 multiplexed_model_id , None
376437 )
377-
438+
378439 # Batching-aware enhancement: prioritize replicas with pending requests
379440 # for the same model ID to improve batching efficiency
380441 if candidate_replica_ids and multiplexed_model_id :
381- pending_for_model = self ._get_pending_requests_for_model (multiplexed_model_id )
442+ pending_for_model = self ._get_pending_requests_for_model (
443+ multiplexed_model_id
444+ )
382445 if len (pending_for_model ) > 1 : # Multiple requests for same model
383446 # Find replicas that already have pending requests for this model
384447 batching_friendly_replicas = set ()
385-
448+
386449 for pending_req in pending_for_model :
387450 # Check if this request has been assigned to a replica
388- if (pending_req .future .done () and
389- not pending_req .future .cancelled () and
390- not pending_req .future .exception ()):
451+ if (
452+ pending_req .future .done ()
453+ and not pending_req .future .cancelled ()
454+ and not pending_req .future .exception ()
455+ ):
391456 try :
392457 assigned_replica = pending_req .future .result ()
393- if (hasattr (assigned_replica , 'replica_id' ) and
394- assigned_replica .replica_id in candidate_replica_ids ):
395- batching_friendly_replicas .add (assigned_replica .replica_id )
458+ if (
459+ hasattr (assigned_replica , "replica_id" )
460+ and assigned_replica .replica_id
461+ in candidate_replica_ids
462+ ):
463+ batching_friendly_replicas .add (
464+ assigned_replica .replica_id
465+ )
396466 except Exception :
397467 # Future might not have replica result, skip
398468 pass
399-
400- # If we found replicas with pending requests for this model, prioritize them
469+
470+ # If we found replicas with pending requests for this model,
471+ # prioritize them
401472 if batching_friendly_replicas :
402473 candidate_replica_ids = batching_friendly_replicas
403- logger .debug (f"Found { len (pending_for_model )} pending requests for model { multiplexed_model_id } , "
404- f"prioritizing { len (batching_friendly_replicas )} batching-friendly replicas" )
474+ logger .debug (
475+ f"Found { len (pending_for_model )} pending requests for "
476+ f"model { multiplexed_model_id } , prioritizing "
477+ f"{ len (batching_friendly_replicas )} batching-friendly "
478+ f"replicas"
479+ )
405480 else :
406- logger .debug (f"Found { len (pending_for_model )} pending requests for model { multiplexed_model_id } , "
407- f"but no batching-friendly replicas found in candidates" )
408-
481+ logger .debug (
482+ f"Found { len (pending_for_model )} pending requests for "
483+ f"model { multiplexed_model_id } , but no batching-friendly "
484+ f"replicas found in candidates"
485+ )
486+
409487 if (
410488 not candidate_replica_ids
411489 and multiplexed_model_id
@@ -596,7 +674,8 @@ def __init__(
596674
597675 # We keep two separate queues of pending requests:
598676 # - self._pending_requests_to_fulfill is a queue that will be used to fulfill
599- # requests (potentially out of order) by routing tasks once they've acquired a replica.
677+ # requests (potentially out of order) by routing tasks once they've
678+ # acquired a replica.
600679 # - self.routing is a queue that is used for tasks to
601680 # best-effort grab the metadata of requests waiting to be fulfilled. This is
602681 # currently used for routing tasks to know which multiplexed model IDs they
@@ -637,8 +716,8 @@ def __init__(
637716
def initialize_state(self, **kwargs):
    """Set up the request router's state from keyword arguments.

    Called by the Ray Serve framework with the contents of
    `RequestRouter.request_router_kwargs`. This implementation ignores the
    arguments and does nothing.
    """
    return None
644723
0 commit comments