Commit 8cdf240

lint fix with router counter fix

Signed-off-by: manickavela29 <[email protected]>
1 parent 3c639b4 commit 8cdf240

File tree

5 files changed: +579 -460 lines changed

python/ray/serve/_private/request_router/request_router.py

Lines changed: 117 additions & 38 deletions
@@ -196,7 +196,7 @@ class MultiplexMixin:
     It adds necessary attributes and methods to keep track of multiplexed
     model IDs and offer the helpers to apply multiplex routing and rank
     replicas based on multiplexed model IDs.
-
+
     Now supports batching-aware routing to group requests by model ID
     for optimal batching performance.
     """
@@ -214,14 +214,15 @@ def __init__(self, *args, **kwargs):
         self._multiplexed_model_id_fallback_match: Set[str] = set()
         self._replica_id_set: Set[ReplicaID] = set()
         self._replicas: Dict[ReplicaID, RunningReplica] = {}
-
+
         # Batching-aware routing: track pending requests by model ID for better batching
         self._pending_requests_by_model_id: DefaultDict[str, List] = defaultdict(list)
         # Counters for efficient cleanup
         self._pending_requests_added_since_cleanup = 0
         self._last_cleanup_time = time.time()
         self._cleanup_threshold = 50  # Cleanup after 50 new requests
         self._cleanup_interval = 10.0  # Cleanup every 10 seconds
+        self._cleanup_task = None  # Track async cleanup task

     def _get_pending_request_matching_multiplexed_model_id(
         self,
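For orientation, a minimal sketch of how the new fields above fit together: `_track_pending_request_by_model_id` (named in the next hunk's context but not shown in this diff) presumably appends to the per-model queue and bumps the cleanup counter. The `SimpleNamespace` request and the method body are illustrative assumptions, not code from this commit.

from collections import defaultdict
from types import SimpleNamespace
import time

class TrackerSketch:
    """Toy stand-in for MultiplexMixin's new tracking state."""

    def __init__(self):
        self._pending_requests_by_model_id = defaultdict(list)
        self._pending_requests_added_since_cleanup = 0
        self._last_cleanup_time = time.time()
        self._cleanup_threshold = 50   # Cleanup after 50 new requests
        self._cleanup_interval = 10.0  # Cleanup every 10 seconds

    def _track_pending_request_by_model_id(self, pending_request):
        # Group the request under its model ID and bump the cleanup counter.
        model_id = pending_request.metadata.multiplexed_model_id
        if model_id:
            self._pending_requests_by_model_id[model_id].append(pending_request)
            self._pending_requests_added_since_cleanup += 1

    def _should_cleanup_pending_requests(self) -> bool:
        # Same count-or-time trigger as the commit.
        return (
            self._pending_requests_added_since_cleanup >= self._cleanup_threshold
            or (time.time() - self._last_cleanup_time) >= self._cleanup_interval
        )

tracker = TrackerSketch()
request = SimpleNamespace(metadata=SimpleNamespace(multiplexed_model_id="model-a"))
tracker._track_pending_request_by_model_id(request)
assert len(tracker._pending_requests_by_model_id["model-a"]) == 1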
@@ -249,42 +250,104 @@ def _track_pending_request_by_model_id(self, pending_request: PendingRequest):
     def _get_pending_requests_for_model(self, model_id: str) -> List[PendingRequest]:
         """Get all pending requests for a specific model ID."""
         # Filter out completed requests on-the-fly for immediate use
-        active_requests = [pr for pr in self._pending_requests_by_model_id[model_id]
-                           if not pr.future.done()]
+        # and update the list in-place to avoid accumulating completed requests
+        if model_id not in self._pending_requests_by_model_id:
+            return []
+
+        active_requests = []
+        completed_count = 0
+
+        for pr in self._pending_requests_by_model_id[model_id]:
+            if not pr.future.done():
+                active_requests.append(pr)
+            else:
+                completed_count += 1
+
+        # Update the stored list with only active requests to prevent accumulation
+        if completed_count > 0:
+            self._pending_requests_by_model_id[model_id] = active_requests
+            if not active_requests:
+                del self._pending_requests_by_model_id[model_id]
+
+        # Trigger periodic cleanup if we've seen enough completed requests
+        if completed_count > 0 and self._should_cleanup_pending_requests():
+            # Schedule cleanup asynchronously to avoid blocking routing
+            self._schedule_async_cleanup()
+
         return active_requests

     def _should_cleanup_pending_requests(self) -> bool:
         """Determine if we should perform cleanup based on counters and time."""
-        return (self._pending_requests_added_since_cleanup >= self._cleanup_threshold or
-                (time.time() - self._last_cleanup_time) >= self._cleanup_interval)
+        return (
+            self._pending_requests_added_since_cleanup >= self._cleanup_threshold
+            or (time.time() - self._last_cleanup_time) >= self._cleanup_interval
+        )

     def _cleanup_completed_pending_requests(self):
         """Clean up completed requests from model ID tracking efficiently."""
         # Only cleanup if we've accumulated enough requests or enough time has passed
         if not self._should_cleanup_pending_requests():
             return
-
+
         cleanup_start = time.time()
-        total_requests_before = sum(len(requests) for requests in self._pending_requests_by_model_id.values())
-
+        total_requests_before = sum(
+            len(requests) for requests in self._pending_requests_by_model_id.values()
+        )
+
         for model_id in list(self._pending_requests_by_model_id.keys()):
             self._pending_requests_by_model_id[model_id] = [
-                pr for pr in self._pending_requests_by_model_id[model_id]
+                pr
+                for pr in self._pending_requests_by_model_id[model_id]
                 if not pr.future.done()
             ]
             if not self._pending_requests_by_model_id[model_id]:
                 del self._pending_requests_by_model_id[model_id]
-
-        total_requests_after = sum(len(requests) for requests in self._pending_requests_by_model_id.values())
+
+        total_requests_after = sum(
+            len(requests) for requests in self._pending_requests_by_model_id.values()
+        )
         cleanup_time = time.time() - cleanup_start
-
+
         # Reset counters
         self._pending_requests_added_since_cleanup = 0
         self._last_cleanup_time = time.time()
-
+
         if total_requests_before != total_requests_after:
-            logger.debug(f"Cleaned up {total_requests_before - total_requests_after} completed requests "
-                         f"in {cleanup_time:.3f}s, {total_requests_after} active requests remaining")
+            logger.debug(
+                f"Cleaned up {total_requests_before - total_requests_after} "
+                f"completed requests in {cleanup_time:.3f}s, "
+                f"{total_requests_after} active requests remaining"
+            )
+
+    def _schedule_async_cleanup(self):
+        """Schedule cleanup to run asynchronously without blocking routing."""
+        # Only schedule if cleanup isn't already running
+        if (
+            not hasattr(self, "_cleanup_task")
+            or self._cleanup_task is None
+            or self._cleanup_task.done()
+        ):
+            import asyncio
+
+            try:
+                # Get the current event loop
+                loop = asyncio.get_event_loop()
+                self._cleanup_task = loop.create_task(self._async_cleanup())
+            except RuntimeError:
+                # If no event loop is running, fall back to synchronous cleanup
+                # This should rarely happen in the Ray Serve context
+                self._cleanup_completed_pending_requests()
+
+    async def _async_cleanup(self):
+        """Perform cleanup asynchronously."""
+        try:
+            # Small delay to avoid blocking the current operation
+            await asyncio.sleep(0.001)
+            self._cleanup_completed_pending_requests()
+        except Exception as e:
+            logger.warning(f"Async cleanup failed: {e}")
+        finally:
+            self._cleanup_task = None

     def _update_multiplexed_model_ids_with_replicas(
         self, replicas: List[RunningReplica]
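A runnable toy version of the scheduling pattern `_schedule_async_cleanup` adopts: create a task on the current event loop, guard against a second concurrent cleanup, and fall back to a synchronous call when no loop is available. The function names here are illustrative; only the pattern mirrors the commit.

import asyncio

_cleanup_task = None

def schedule_cleanup():
    """Run cleanup off the hot path; fall back to sync if no loop exists."""
    global _cleanup_task
    if _cleanup_task is None or _cleanup_task.done():
        try:
            loop = asyncio.get_event_loop()
            _cleanup_task = loop.create_task(_cleanup())
        except RuntimeError:
            # No usable event loop: degrade to a blocking call.
            asyncio.run(_cleanup())

async def _cleanup():
    await asyncio.sleep(0.001)  # yield first, as the commit does
    print("cleanup ran")

async def main():
    schedule_cleanup()
    await asyncio.sleep(0.01)  # give the scheduled task time to finish

asyncio.run(main())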
@@ -354,8 +417,6 @@ def apply_multiplex_routing(

         # Track this request for batching-aware routing
         self._track_pending_request_by_model_id(pending_request)
-        # Clean up completed requests periodically
-        self._cleanup_completed_pending_requests()

         if not pending_request.routing_context.multiplexed_start_matching_time:
             pending_request.routing_context.multiplexed_start_matching_time = (
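The per-request synchronous cleanup call can be dropped here because `_get_pending_requests_for_model` (earlier in this diff) now prunes completed entries whenever a model's queue is read. A stripped-down illustration of that read-side pruning, with `concurrent.futures.Future` standing in for the pending-request future:

from collections import defaultdict
from concurrent.futures import Future

pending = defaultdict(list)

def get_active(model_id):
    # Drop completed futures as a side effect of reading the queue.
    if model_id not in pending:
        return []
    active = [f for f in pending[model_id] if not f.done()]
    if active:
        pending[model_id] = active
    else:
        del pending[model_id]
    return active

f1, f2 = Future(), Future()
pending["model-a"] = [f1, f2]
f1.set_result(None)                   # one request completes
assert get_active("model-a") == [f2]  # completed entry was pruned in place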
@@ -366,46 +427,63 @@ def apply_multiplex_routing(
             pending_request.routing_context.multiplexed_start_matching_time
         )
         multiplexed_model_id = pending_request.metadata.multiplexed_model_id
-
+
         if (
             time.time() - multiplexed_start_matching_time
             < self._multiplexed_matching_timeout
         ):
             candidate_replica_ids = self._multiplexed_model_id_to_replica_ids.get(
                 multiplexed_model_id, None
             )
-
+
             # Batching-aware enhancement: prioritize replicas with pending requests
             # for the same model ID to improve batching efficiency
             if candidate_replica_ids and multiplexed_model_id:
-                pending_for_model = self._get_pending_requests_for_model(multiplexed_model_id)
+                pending_for_model = self._get_pending_requests_for_model(
+                    multiplexed_model_id
+                )
                 if len(pending_for_model) > 1:  # Multiple requests for same model
                     # Find replicas that already have pending requests for this model
                     batching_friendly_replicas = set()
-
+
                     for pending_req in pending_for_model:
                         # Check if this request has been assigned to a replica
-                        if (pending_req.future.done() and
-                            not pending_req.future.cancelled() and
-                            not pending_req.future.exception()):
+                        if (
+                            pending_req.future.done()
+                            and not pending_req.future.cancelled()
+                            and not pending_req.future.exception()
+                        ):
                             try:
                                 assigned_replica = pending_req.future.result()
-                                if (hasattr(assigned_replica, 'replica_id') and
-                                    assigned_replica.replica_id in candidate_replica_ids):
-                                    batching_friendly_replicas.add(assigned_replica.replica_id)
+                                if (
+                                    hasattr(assigned_replica, "replica_id")
+                                    and assigned_replica.replica_id
+                                    in candidate_replica_ids
+                                ):
+                                    batching_friendly_replicas.add(
+                                        assigned_replica.replica_id
+                                    )
                             except Exception:
                                 # Future might not have replica result, skip
                                 pass
-
-                    # If we found replicas with pending requests for this model, prioritize them
+
+                    # If we found replicas with pending requests for this model,
+                    # prioritize them
                     if batching_friendly_replicas:
                         candidate_replica_ids = batching_friendly_replicas
-                        logger.debug(f"Found {len(pending_for_model)} pending requests for model {multiplexed_model_id}, "
-                                     f"prioritizing {len(batching_friendly_replicas)} batching-friendly replicas")
+                        logger.debug(
+                            f"Found {len(pending_for_model)} pending requests for "
+                            f"model {multiplexed_model_id}, prioritizing "
+                            f"{len(batching_friendly_replicas)} batching-friendly "
+                            f"replicas"
+                        )
                     else:
-                        logger.debug(f"Found {len(pending_for_model)} pending requests for model {multiplexed_model_id}, "
-                                     f"but no batching-friendly replicas found in candidates")
-
+                        logger.debug(
+                            f"Found {len(pending_for_model)} pending requests for "
+                            f"model {multiplexed_model_id}, but no batching-friendly "
+                            f"replicas found in candidates"
+                        )
+
         if (
             not candidate_replica_ids
             and multiplexed_model_id
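Reduced to its core, the prioritization above intersects the candidate replicas with those already assigned requests for the same model, and keeps the original candidates when the intersection is empty. A minimal sketch with hypothetical replica IDs:

def prioritize_batching_friendly(candidates: set, already_serving: set) -> set:
    # Prefer replicas with in-flight requests for this model; otherwise
    # fall back to the full candidate set.
    friendly = candidates & already_serving
    return friendly if friendly else candidates

candidates = {"replica-a", "replica-b", "replica-c"}
assert prioritize_batching_friendly(candidates, {"replica-b"}) == {"replica-b"}
assert prioritize_batching_friendly(candidates, set()) == candidates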
@@ -596,7 +674,8 @@ def __init__(

         # We keep two separate queues of pending requests:
         # - self._pending_requests_to_fulfill is a queue that will be used to fulfill
-        #   requests (potentially out of order) by routing tasks once they've acquired a replica.
+        #   requests (potentially out of order) by routing tasks once they've
+        #   acquired a replica.
         # - self.routing is a queue that is used for tasks to
         #   best-effort grab the metadata of requests waiting to be fulfilled. This is
         #   currently used for routing tasks to know which multiplexed model IDs they
@@ -637,8 +716,8 @@ def __init__(

     def initialize_state(self, **kwargs):
         """
-        Initialize the state of the request router. Called by the Ray Serve framework with the
-        contents of `RequestRouter.request_router_kwargs`.
+        Initialize the state of the request router. Called by the Ray Serve
+        framework with the contents of `RequestRouter.request_router_kwargs`.
         """
         pass
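Going by the reworded docstring alone, `initialize_state` is the hook through which `RequestRouter.request_router_kwargs` reaches a custom router. A hedged sketch with a stand-in base class, since the real wiring lives in Ray Serve rather than in this diff:

class RequestRouterStandIn:
    """Stand-in for ray.serve's RequestRouter; not the real class."""
    request_router_kwargs = {"cleanup_threshold": 25}

    def initialize_state(self, **kwargs):
        pass

class TunableRouter(RequestRouterStandIn):
    def initialize_state(self, cleanup_threshold: int = 50, **kwargs):
        # Accept tuning knobs supplied via request_router_kwargs.
        self._cleanup_threshold = cleanup_threshold

router = TunableRouter()
router.initialize_state(**TunableRouter.request_router_kwargs)
assert router._cleanup_threshold == 25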