@@ -160,9 +160,11 @@ async def model_batch_handler(batch_requests: List[Any]) -> List[Any]:
          Returns:
              List of results corresponding to each input.
          """
+            # Re-check model availability at processing time to handle race conditions
             model = self.models.get(model_id)
             if model is None:
-                raise RuntimeError(f"Model {model_id} not loaded")
+                # The model was evicted; raise an exception that will cancel pending requests
+                raise RuntimeError(f"Model {model_id} was evicted during batch processing")
 
             # Try to use the batch_predict method if available
             if hasattr(model, 'batch_predict'):
@@ -192,6 +194,124 @@ async def model_batch_handler(batch_requests: List[Any]) -> List[Any]:
 
         return self._model_batch_queues[model_id]
 
+    async def _shutdown_batch_queue(self, batch_queue_wrapper: _LazyBatchQueueWrapper, model_id: str):
+        """Gracefully shut down a batch queue by canceling pending requests and background tasks."""
+        if batch_queue_wrapper._queue is None:
+            # The queue was never initialized, so there is nothing to clean up
+            return
+
+        batch_queue = batch_queue_wrapper._queue
+
+        # Cancel the background processing task if it exists
+        if hasattr(batch_queue, '_handle_batch_task') and batch_queue._handle_batch_task:
+            batch_queue._handle_batch_task.cancel()
+            try:
+                await batch_queue._handle_batch_task
+            except asyncio.CancelledError:
+                pass  # Expected when cancelling
+
+        # Drain all pending requests from the queue
+        pending_requests = []
+        try:
+            while True:
+                try:
+                    request = batch_queue.queue.get_nowait()
+                    pending_requests.append(request)
+                except asyncio.QueueEmpty:
+                    break
+        except Exception:
+            pass  # Queue might be closed or corrupted
+
+        # Handle pending requests gracefully: try to reassign them rather than fail them
+        reassigned_count = 0
+        failed_count = 0
+
+        for request in pending_requests:
+            if not request.future.done():
+                try:
+                    # Try to reassign the request back to the routing system
+                    if await self._try_reassign_request(request, model_id):
+                        reassigned_count += 1
+                    else:
+                        # If reassignment fails, set a descriptive error
+                        request.future.set_exception(
+                            RuntimeError(f"Model {model_id} was evicted and could not be reassigned")
+                        )
+                        failed_count += 1
+                except Exception:
+                    # The future might already be done, or another error occurred; count it as failed
+                    failed_count += 1
+
+        logger.info(
+            f"Shut down batch queue for model {model_id}: "
+            f"reassigned {reassigned_count}, failed {failed_count} pending requests"
+        )
+
+    async def _try_reassign_request(self, request: _SingleRequest, model_id: str) -> bool:
+        """Try to reassign a pending request back to the routing system.
+
+        Args:
+            request: The pending request to reassign.
+            model_id: The model ID that was evicted.
+
+        Returns:
+            True if the request was successfully reassigned, False otherwise.
+        """
+        try:
+            # Extract the original input from the flattened args
+            if len(request.flattened_args) >= 2 and request.flattened_args[0] == DUMMY_TYPE:
+                original_input = request.flattened_args[1]
+            else:
+                # Fall back if the format is unexpected
+                return False
+
+            # Check whether any retry attempts are left (prevents infinite loops)
+            retry_count = getattr(request, '_retry_count', 0)
+            if retry_count >= 2:  # Max 2 retries
+                return False
+
+            # Create a new async task to retry the request with backoff
+            async def retry_request():
+                try:
+                    # Record the retry count to track attempts
+                    setattr(request, '_retry_count', retry_count + 1)
+
+                    # Exponential backoff: wait longer for each retry
+                    backoff_time = 0.01 * (2 ** retry_count)
+                    await asyncio.sleep(backoff_time)
+
+                    # Try to process the request again; this goes through the full
+                    # model loading process and may reload the model on this replica.
+                    # Note: we call predict directly rather than batched_inference to avoid
+                    # potential batching complications during the retry.
+                    if self.enable_batching:
+                        # For the batching case, fall back to an individual prediction
+                        model = await self.load_model(model_id)
+                        if hasattr(model, 'predict'):
+                            result = await model.predict(original_input)
+                        elif callable(model):
+                            result = await model(original_input)
+                        else:
+                            raise RuntimeError(f"Model {model_id} is not callable and has no predict method")
+                    else:
+                        result = await self.predict(original_input, model_id)
+
+                    # Set the result on the original future
+                    if not request.future.done():
+                        request.future.set_result(result)
+
+                except Exception as e:
+                    # If the retry fails, set the exception on the original future
+                    if not request.future.done():
+                        request.future.set_exception(
+                            RuntimeError(f"Model {model_id} evicted, retry failed: {str(e)}")
+                        )
+
+            # Start the retry task in the background
+            asyncio.create_task(retry_request())
+            return True
+
+        except Exception as e:
+            logger.debug(f"Failed to reassign request for model {model_id}: {e}")
+            return False
+
     async def batched_inference(self, model_id: str, request: Any) -> Any:
         """Perform batched inference on a specific model."""
         if not self.enable_batching:
@@ -292,6 +412,14 @@ async def shutdown(self):
                 logger.exception(
                     f"Failed to unload model. Error: {e}",
                 )
+
+        # Clean up any remaining batch queues
+        for model_id, batch_queue_wrapper in list(self._model_batch_queues.items()):
+            try:
+                await self._shutdown_batch_queue(batch_queue_wrapper, model_id)
+            except Exception as e:
+                logger.exception(f"Failed to shut down batch queue for model {model_id}. Error: {e}")
+        self._model_batch_queues.clear()
 
     async def load_model(self, model_id: str) -> Any:
         """Load the model if it is not loaded yet, and return
@@ -373,8 +501,10 @@ async def unload_model_lru(self) -> None:
         model_id, model = self.models.popitem(last=False)
         logger.info(f"Unloading model '{model_id}'.")
 
-        # Clean up the batch queue for this model if it exists
+        # Gracefully shut down the batch queue for this model if it exists
         if model_id in self._model_batch_queues:
+            batch_queue_wrapper = self._model_batch_queues[model_id]
+            await self._shutdown_batch_queue(batch_queue_wrapper, model_id)
             del self._model_batch_queues[model_id]
 
         # If the model has __del__ attribute, call it.
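
For context, the new _shutdown_batch_queue and _try_reassign_request methods follow a general pattern: drain pending requests from an asyncio queue, retry each one in a background task with capped exponential backoff, and fail the original future with a descriptive error once retries run out. The minimal, self-contained sketch below illustrates that pattern in isolation; PendingRequest, drain_and_reassign, handler, and max_retries are hypothetical names for illustration and are not part of this change.

import asyncio
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, List


@dataclass
class PendingRequest:
    # Hypothetical stand-in for a queued request: the input plus the caller's future.
    payload: Any
    future: asyncio.Future = field(default_factory=asyncio.Future)
    retry_count: int = 0


async def drain_and_reassign(
    queue: "asyncio.Queue[PendingRequest]",
    handler: Callable[[Any], Awaitable[Any]],
    max_retries: int = 2,
) -> None:
    """Drain every queued request and retry it in the background with backoff."""
    pending: List[PendingRequest] = []
    while True:
        try:
            pending.append(queue.get_nowait())
        except asyncio.QueueEmpty:
            break

    async def retry(req: PendingRequest) -> None:
        # Exponential backoff: 20 ms on the first retry, 40 ms on the second, ...
        await asyncio.sleep(0.01 * (2 ** req.retry_count))
        try:
            result = await handler(req.payload)
            # Guard with done() so the future is completed at most once.
            if not req.future.done():
                req.future.set_result(result)
        except Exception as e:
            if not req.future.done():
                req.future.set_exception(RuntimeError(f"Retry failed: {e}"))

    for req in pending:
        if req.future.done():
            continue
        if req.retry_count >= max_retries:
            # Out of retries: surface a descriptive error to the original caller.
            req.future.set_exception(RuntimeError("Evicted and retries exhausted"))
        else:
            req.retry_count += 1
            asyncio.create_task(retry(req))


async def main() -> None:
    queue: "asyncio.Queue[PendingRequest]" = asyncio.Queue()
    request = PendingRequest(payload=3)
    queue.put_nowait(request)

    async def double(x: int) -> int:
        return 2 * x

    await drain_and_reassign(queue, double)
    print(await request.future)  # prints 6


if __name__ == "__main__":
    asyncio.run(main())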