Commit 5ea4fdc

handling comments
1 parent 02222b8 commit 5ea4fdc

4 files changed, +385 -15 lines changed

python/ray/serve/_private/request_router/request_router.py

Lines changed: 59 additions & 6 deletions
@@ -217,6 +217,11 @@ def __init__(self, *args, **kwargs):
 
         # Batching-aware routing: track pending requests by model ID for better batching
         self._pending_requests_by_model_id: DefaultDict[str, List] = defaultdict(list)
+        # Counters for efficient cleanup
+        self._pending_requests_added_since_cleanup = 0
+        self._last_cleanup_time = time.time()
+        self._cleanup_threshold = 50  # Cleanup after 50 new requests
+        self._cleanup_interval = 10.0  # Cleanup every 10 seconds
 
     def _get_pending_request_matching_multiplexed_model_id(
         self,
@@ -239,21 +244,47 @@ def _track_pending_request_by_model_id(self, pending_request: PendingRequest):
         if pending_request.metadata.multiplexed_model_id:
             model_id = pending_request.metadata.multiplexed_model_id
             self._pending_requests_by_model_id[model_id].append(pending_request)
+            self._pending_requests_added_since_cleanup += 1
 
     def _get_pending_requests_for_model(self, model_id: str) -> List[PendingRequest]:
         """Get all pending requests for a specific model ID."""
-        return [pr for pr in self._pending_requests_by_model_id[model_id]
-                if not pr.future.done()]
+        # Filter out completed requests on-the-fly for immediate use
+        active_requests = [pr for pr in self._pending_requests_by_model_id[model_id]
+                           if not pr.future.done()]
+        return active_requests
+
+    def _should_cleanup_pending_requests(self) -> bool:
+        """Determine if we should perform cleanup based on counters and time."""
+        return (self._pending_requests_added_since_cleanup >= self._cleanup_threshold or
+                (time.time() - self._last_cleanup_time) >= self._cleanup_interval)
 
     def _cleanup_completed_pending_requests(self):
-        """Clean up completed requests from model ID tracking."""
+        """Clean up completed requests from model ID tracking efficiently."""
+        # Only cleanup if we've accumulated enough requests or enough time has passed
+        if not self._should_cleanup_pending_requests():
+            return
+
+        cleanup_start = time.time()
+        total_requests_before = sum(len(requests) for requests in self._pending_requests_by_model_id.values())
+
         for model_id in list(self._pending_requests_by_model_id.keys()):
             self._pending_requests_by_model_id[model_id] = [
                 pr for pr in self._pending_requests_by_model_id[model_id]
                 if not pr.future.done()
             ]
             if not self._pending_requests_by_model_id[model_id]:
                 del self._pending_requests_by_model_id[model_id]
+
+        total_requests_after = sum(len(requests) for requests in self._pending_requests_by_model_id.values())
+        cleanup_time = time.time() - cleanup_start
+
+        # Reset counters
+        self._pending_requests_added_since_cleanup = 0
+        self._last_cleanup_time = time.time()
+
+        if total_requests_before != total_requests_after:
+            logger.debug(f"Cleaned up {total_requests_before - total_requests_after} completed requests "
+                         f"in {cleanup_time:.3f}s, {total_requests_after} active requests remaining")
 
     def _update_multiplexed_model_ids_with_replicas(
         self, replicas: List[RunningReplica]
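
The two hunks above split the bookkeeping into cheap reads and throttled writes: lookups filter completed futures on the fly, while a full sweep of the per-model index only runs once 50 new requests have been tracked or 10 seconds have passed. A minimal standalone sketch of that throttling pattern (ThrottledIndex, is_done, and the parameter names are illustrative only, not Ray Serve APIs):

import time
from collections import defaultdict


class ThrottledIndex:
    """Sketch: per-key lists of pending items whose full cleanup is throttled."""

    def __init__(self, cleanup_threshold=50, cleanup_interval=10.0):
        self._items = defaultdict(list)
        self._added_since_cleanup = 0
        self._last_cleanup_time = time.time()
        self._cleanup_threshold = cleanup_threshold
        self._cleanup_interval = cleanup_interval

    def add(self, key, item):
        self._items[key].append(item)
        self._added_since_cleanup += 1

    def active(self, key, is_done):
        # Reads filter on the fly, so callers never see completed items.
        return [it for it in self._items[key] if not is_done(it)]

    def maybe_cleanup(self, is_done):
        # Full sweeps only run after enough inserts or enough elapsed time.
        due = (self._added_since_cleanup >= self._cleanup_threshold
               or time.time() - self._last_cleanup_time >= self._cleanup_interval)
        if not due:
            return
        for key in list(self._items):
            self._items[key] = [it for it in self._items[key] if not is_done(it)]
            if not self._items[key]:
                del self._items[key]
        self._added_since_cleanup = 0
        self._last_cleanup_time = time.time()

The trade-off is bounded staleness: completed entries can linger for up to one threshold window, but no single request path ever pays for a full sweep on its own.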
@@ -349,9 +380,31 @@ def apply_multiplex_routing(
         if candidate_replica_ids and multiplexed_model_id:
             pending_for_model = self._get_pending_requests_for_model(multiplexed_model_id)
             if len(pending_for_model) > 1:  # Multiple requests for same model
-                # Prefer replicas that are likely processing this model
-                logger.debug(f"Found {len(pending_for_model)} pending requests for model {multiplexed_model_id}, "
-                             f"prioritizing batching-friendly routing")
+                # Find replicas that already have pending requests for this model
+                batching_friendly_replicas = set()
+
+                for pending_req in pending_for_model:
+                    # Check if this request has been assigned to a replica
+                    if (pending_req.future.done() and
+                            not pending_req.future.cancelled() and
+                            not pending_req.future.exception()):
+                        try:
+                            assigned_replica = pending_req.future.result()
+                            if (hasattr(assigned_replica, 'replica_id') and
+                                    assigned_replica.replica_id in candidate_replica_ids):
+                                batching_friendly_replicas.add(assigned_replica.replica_id)
+                        except Exception:
+                            # Future might not have replica result, skip
+                            pass
+
+                # If we found replicas with pending requests for this model, prioritize them
+                if batching_friendly_replicas:
+                    candidate_replica_ids = batching_friendly_replicas
+                    logger.debug(f"Found {len(pending_for_model)} pending requests for model {multiplexed_model_id}, "
+                                 f"prioritizing {len(batching_friendly_replicas)} batching-friendly replicas")
+                else:
+                    logger.debug(f"Found {len(pending_for_model)} pending requests for model {multiplexed_model_id}, "
+                                 f"but no batching-friendly replicas found in candidates")
 
         if (
             not candidate_replica_ids
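
Stripped of the future bookkeeping, the routing change above is a set intersection with a fallback: keep only the candidate replicas that already hold assigned requests for the same model, and if none do, leave the candidate set untouched. A hypothetical helper (not part of the router) makes the decision explicit:

def prefer_colocated_replicas(candidate_replica_ids: set, assigned_replica_ids: set) -> set:
    """Keep candidates that already have work for this model; otherwise keep them all."""
    colocated = candidate_replica_ids & assigned_replica_ids
    return colocated if colocated else candidate_replica_ids


# Example: two of three candidates already hold requests for this model ID.
print(prefer_colocated_replicas({"r1", "r2", "r3"}, {"r2", "r3", "r9"}))  # {'r2', 'r3'}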

python/ray/serve/multiplex.py

Lines changed: 132 additions & 2 deletions
@@ -160,9 +160,11 @@ async def model_batch_handler(batch_requests: List[Any]) -> List[Any]:
             Returns:
                 List of results corresponding to each input.
             """
+            # Re-check model availability at processing time to handle race conditions
             model = self.models.get(model_id)
             if model is None:
-                raise RuntimeError(f"Model {model_id} not loaded")
+                # Model was evicted, raise an exception that will cancel pending requests
+                raise RuntimeError(f"Model {model_id} was evicted during batch processing")
 
             # Try to use batch_predict method if available
             if hasattr(model, 'batch_predict'):
@@ -192,6 +194,124 @@ async def model_batch_handler(batch_requests: List[Any]) -> List[Any]:
 
         return self._model_batch_queues[model_id]
 
+    async def _shutdown_batch_queue(self, batch_queue_wrapper: _LazyBatchQueueWrapper, model_id: str):
+        """Gracefully shutdown a batch queue by canceling pending requests and background tasks."""
+        if batch_queue_wrapper._queue is None:
+            # Queue was never initialized, nothing to clean up
+            return
+
+        batch_queue = batch_queue_wrapper._queue
+
+        # Cancel the background processing task if it exists
+        if hasattr(batch_queue, '_handle_batch_task') and batch_queue._handle_batch_task:
+            batch_queue._handle_batch_task.cancel()
+            try:
+                await batch_queue._handle_batch_task
+            except asyncio.CancelledError:
+                pass  # Expected when cancelling
+
+        # Cancel all pending requests in the queue
+        pending_requests = []
+        try:
+            while True:
+                try:
+                    request = batch_queue.queue.get_nowait()
+                    pending_requests.append(request)
+                except asyncio.QueueEmpty:
+                    break
+        except Exception:
+            pass  # Queue might be closed or corrupted
+
+        # Handle pending requests gracefully - try to reassign rather than fail
+        reassigned_count = 0
+        failed_count = 0
+
+        for request in pending_requests:
+            if not request.future.done():
+                try:
+                    # Try to reassign the request back to the routing system
+                    if await self._try_reassign_request(request, model_id):
+                        reassigned_count += 1
+                    else:
+                        # If reassignment fails, set a descriptive error
+                        request.future.set_exception(
+                            RuntimeError(f"Model {model_id} was evicted and could not be reassigned")
+                        )
+                        failed_count += 1
+                except Exception:
+                    # Future might already be done or other error, count as failed
+                    failed_count += 1
+
+        logger.info(f"Shutdown batch queue for model {model_id}: reassigned {reassigned_count}, failed {failed_count} pending requests")
+
+    async def _try_reassign_request(self, request: _SingleRequest, model_id: str) -> bool:
+        """Try to reassign a pending request back to the routing system.
+
+        Args:
+            request: The pending request to reassign
+            model_id: The model ID that was evicted
+
+        Returns:
+            True if request was successfully reassigned, False otherwise
+        """
+        try:
+            # Extract the original input from the flattened args
+            if len(request.flattened_args) >= 2 and request.flattened_args[0] == DUMMY_TYPE:
+                original_input = request.flattened_args[1]
+            else:
+                # Fallback if format is unexpected
+                return False
+
+            # Check if we have retry attempts left (prevent infinite loops)
+            retry_count = getattr(request, '_retry_count', 0)
+            if retry_count >= 2:  # Max 2 retries
+                return False
+
+            # Create a new async task to retry the request with backoff
+            async def retry_request():
+                try:
+                    # Add retry count to track attempts
+                    setattr(request, '_retry_count', retry_count + 1)
+
+                    # Exponential backoff: wait longer for each retry
+                    backoff_time = 0.01 * (2 ** retry_count)
+                    await asyncio.sleep(backoff_time)
+
+                    # Try to process the request again - this will go through the full
+                    # model loading process, potentially reloading on this replica
+                    # Note: We call predict directly rather than batched_inference to avoid
+                    # potential batching complications during retry
+                    if self.enable_batching:
+                        # For batching case, try individual prediction as fallback
+                        model = await self.load_model(model_id)
+                        if hasattr(model, 'predict'):
+                            result = await model.predict(original_input)
+                        elif callable(model):
+                            result = await model(original_input)
+                        else:
+                            raise RuntimeError(f"Model {model_id} is not callable and has no predict method")
+                    else:
+                        result = await self.predict(original_input, model_id)
+
+                    # Set the result on the original future
+                    if not request.future.done():
+                        request.future.set_result(result)
+
+                except Exception as e:
+                    # If retry fails, set the exception on the original future
+                    if not request.future.done():
+                        request.future.set_exception(
+                            RuntimeError(f"Model {model_id} evicted, retry failed: {str(e)}")
+                        )
+
+            # Start the retry task in the background
+            asyncio.create_task(retry_request())
+            return True
+
+        except Exception as e:
+            logger.debug(f"Failed to reassign request for model {model_id}: {e}")
+            return False
+
     async def batched_inference(self, model_id: str, request: Any) -> Any:
         """Perform batched inference on a specific model."""
         if not self.enable_batching:
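
The new _shutdown_batch_queue and _try_reassign_request methods follow a drain-then-retry shape: empty the queue with get_nowait(), then either hand each stranded request to a background retry with exponential backoff or fail its future with a descriptive error. A self-contained sketch of that shape, assuming a queue of (payload, future) pairs and a work coroutine that stands in for re-running inference (drain_and_retry and work are illustrative names, not Ray Serve APIs):

import asyncio


async def drain_and_retry(queue: asyncio.Queue, work, max_attempts: int = 2) -> int:
    """Drain (payload, future) pairs without blocking and retry them in the background."""
    pending = []
    while True:
        try:
            pending.append(queue.get_nowait())  # Non-blocking; stops once the queue is empty.
        except asyncio.QueueEmpty:
            break

    async def retry(payload, future):
        last_error = None
        for attempt in range(max_attempts):
            await asyncio.sleep(0.01 * (2 ** attempt))  # 10 ms, then 20 ms, ...
            try:
                result = await work(payload)
            except Exception as e:
                last_error = e
                continue
            if not future.done():
                future.set_result(result)
            return
        if not future.done():
            future.set_exception(RuntimeError(f"Retries exhausted: {last_error}"))

    handed_off = 0
    for payload, future in pending:
        if not future.done():
            asyncio.create_task(retry(payload, future))  # Fire-and-forget, as in the diff above.
            handed_off += 1
    return handed_off

As in the diff, retries are fire-and-forget tasks, so eviction never blocks on a slow reload; the cost is that a caller only learns about a permanently failed retry when the background task gives up and sets the exception.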
@@ -292,6 +412,14 @@ async def shutdown(self):
                 logger.exception(
                     f"Failed to unload model. Error: {e}",
                 )
+
+        # Clean up any remaining batch queues
+        for model_id, batch_queue_wrapper in list(self._model_batch_queues.items()):
+            try:
+                await self._shutdown_batch_queue(batch_queue_wrapper, model_id)
+            except Exception as e:
+                logger.exception(f"Failed to shutdown batch queue for model {model_id}. Error: {e}")
+        self._model_batch_queues.clear()
 
     async def load_model(self, model_id: str) -> Any:
         """Load the model if it is not loaded yet, and return
@@ -373,8 +501,10 @@ async def unload_model_lru(self) -> None:
         model_id, model = self.models.popitem(last=False)
         logger.info(f"Unloading model '{model_id}'.")
 
-        # Clean up the batch queue for this model if it exists
+        # Gracefully shutdown the batch queue for this model if it exists
         if model_id in self._model_batch_queues:
+            batch_queue_wrapper = self._model_batch_queues[model_id]
+            await self._shutdown_batch_queue(batch_queue_wrapper, model_id)
             del self._model_batch_queues[model_id]
 
         # If the model has __del__ attribute, call it.
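
The unload_model_lru hunk leans on OrderedDict.popitem(last=False), which removes and returns the oldest entry, and now awaits a per-model shutdown hook before dropping the batch queue. A compact sketch of that eviction shape (LRUModelCache and on_evict are illustrative names, not the wrapper's interface):

from collections import OrderedDict


class LRUModelCache:
    """Sketch: an LRU cache that awaits an async cleanup hook on eviction."""

    def __init__(self, max_models: int, on_evict):
        self.models = OrderedDict()   # Oldest entry first; move_to_end() marks recent use.
        self.max_models = max_models
        self.on_evict = on_evict      # async callable(model_id, model)

    async def put(self, model_id, model):
        self.models[model_id] = model
        self.models.move_to_end(model_id)
        while len(self.models) > self.max_models:
            evicted_id, evicted = self.models.popitem(last=False)  # Least recently used.
            await self.on_evict(evicted_id, evicted)               # e.g. shut down its batch queue.

    def get(self, model_id):
        model = self.models[model_id]
        self.models.move_to_end(model_id)  # Mark as most recently used.
        return model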

python/ray/serve/tests/test_multiplex_batching_router.py

Lines changed: 8 additions & 6 deletions
@@ -200,22 +200,24 @@ async def load_model(model_id: str):
     # Load model first
     model = await wrapper_batched.load_model("batched_model")
 
-    # Send concurrent requests to same model using the model directly
+    # Send concurrent requests to the wrapper to test batching mechanism
     start_time = time.time()
     tasks = []
     for i in range(10):
-        task = model.batch_predict([f"data_{i}"])
+        # Use wrapper.predict() to test the actual batching mechanism
+        task = wrapper_batched.predict(f"data_{i}", "batched_model")
         tasks.append(task)
 
-    results_nested = await asyncio.gather(*tasks)
-    # Flatten results since batch_predict returns lists
-    results = [item for sublist in results_nested for item in sublist]
+    results = await asyncio.gather(*tasks)
     batched_time = time.time() - start_time
 
-    # Check the model's batch predict was called
+    # Check that batch predict was called (indicating batching worked)
     assert model.batch_predict_count > 0, "Batch predict should be called"
     assert len(results) == 10, "All requests should complete"
 
+    # Verify results are correct format - should be from batch_predict
+    assert all("batch_batched_model" in result for result in results), f"Expected batch results, got: {results[:3]}"
+
     # Test without batching for comparison
     TrackableModel.reset_tracking()
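
The reworked test drives requests through the wrapper's predict() path and then asserts both that every request completed and that the batch handler actually ran. The general shape of that assertion, shown against a hypothetical self-contained stub (BatchingStub is not the Ray Serve wrapper, just a stand-in for gathering concurrent calls and counting batch executions):

import asyncio


class BatchingStub:
    """Stand-in for a batching wrapper: buffers concurrent calls briefly, serves them in one batch."""

    def __init__(self, window: float = 0.01):
        self.window = window
        self.batch_calls = 0
        self._pending = []   # List of (item, future) pairs.
        self._flusher = None

    async def predict(self, item: str) -> str:
        fut = asyncio.get_running_loop().create_future()
        self._pending.append((item, fut))
        if self._flusher is None:
            self._flusher = asyncio.create_task(self._flush())
        return await fut

    async def _flush(self):
        await asyncio.sleep(self.window)          # Let concurrent callers pile up.
        batch, self._pending, self._flusher = self._pending, [], None
        self.batch_calls += 1
        for item, fut in batch:
            fut.set_result(f"batch_{item}")       # One batch result per input.


async def test_concurrent_calls_are_batched():
    stub = BatchingStub()
    results = await asyncio.gather(*(stub.predict(f"data_{i}") for i in range(10)))
    assert len(results) == 10, "All requests should complete"
    assert stub.batch_calls < 10, "Requests should be served in far fewer batch calls"


asyncio.run(test_concurrent_calls_are_batched())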
