184 changes: 181 additions & 3 deletions python/ray/serve/_private/request_router/request_router.py
@@ -196,6 +196,9 @@ class MultiplexMixin:
It adds necessary attributes and methods to keep track of multiplexed
model IDs and offer the helpers to apply multiplex routing and rank
replicas based on multiplexed model IDs.

It also supports batching-aware routing, which groups pending requests by
model ID so that they can be batched together more effectively.
"""

def __init__(self, *args, **kwargs):
@@ -212,6 +215,15 @@ def __init__(self, *args, **kwargs):
self._replica_id_set: Set[ReplicaID] = set()
self._replicas: Dict[ReplicaID, RunningReplica] = {}

# Batching-aware routing: track pending requests by model ID for better batching
self._pending_requests_by_model_id: DefaultDict[str, List] = defaultdict(list)
# Counters for efficient cleanup
self._pending_requests_added_since_cleanup = 0
self._last_cleanup_time = time.time()
self._cleanup_threshold = 50 # Cleanup after 50 new requests
self._cleanup_interval = 10.0 # Cleanup every 10 seconds
self._cleanup_task = None # Track async cleanup task

def _get_pending_request_matching_multiplexed_model_id(
self,
request_metadata: Optional[RequestMetadata] = None,
@@ -228,6 +240,115 @@ def _get_pending_request_matching_multiplexed_model_id(
):
return pr

def _track_pending_request_by_model_id(self, pending_request: PendingRequest):
"""Track pending requests by model ID for batching-aware routing."""
if pending_request.metadata.multiplexed_model_id:
model_id = pending_request.metadata.multiplexed_model_id
self._pending_requests_by_model_id[model_id].append(pending_request)
self._pending_requests_added_since_cleanup += 1

def _get_pending_requests_for_model(self, model_id: str) -> List[PendingRequest]:
"""Get all pending requests for a specific model ID."""
# Filter out completed requests on-the-fly for immediate use
# and update the list in-place to avoid accumulating completed requests
if model_id not in self._pending_requests_by_model_id:
return []

active_requests = []
completed_count = 0

for pr in self._pending_requests_by_model_id[model_id]:
if not pr.future.done():
active_requests.append(pr)
else:
completed_count += 1

# Update the stored list with only active requests to prevent accumulation
if completed_count > 0:
self._pending_requests_by_model_id[model_id] = active_requests
if not active_requests:
del self._pending_requests_by_model_id[model_id]

# Trigger periodic cleanup if we've seen enough completed requests
if completed_count > 0 and self._should_cleanup_pending_requests():
# Schedule cleanup asynchronously to avoid blocking routing
self._schedule_async_cleanup()

return active_requests

def _should_cleanup_pending_requests(self) -> bool:
"""Determine if we should perform cleanup based on counters and time."""
return (
self._pending_requests_added_since_cleanup >= self._cleanup_threshold
or (time.time() - self._last_cleanup_time) >= self._cleanup_interval
)

def _cleanup_completed_pending_requests(self):
"""Clean up completed requests from model ID tracking efficiently."""
# Only cleanup if we've accumulated enough requests or enough time has passed
if not self._should_cleanup_pending_requests():
return

cleanup_start = time.time()
total_requests_before = sum(
len(requests) for requests in self._pending_requests_by_model_id.values()
)

for model_id in list(self._pending_requests_by_model_id.keys()):
self._pending_requests_by_model_id[model_id] = [
pr
for pr in self._pending_requests_by_model_id[model_id]
if not pr.future.done()
]
if not self._pending_requests_by_model_id[model_id]:
del self._pending_requests_by_model_id[model_id]

total_requests_after = sum(
len(requests) for requests in self._pending_requests_by_model_id.values()
)
cleanup_time = time.time() - cleanup_start

# Reset counters
self._pending_requests_added_since_cleanup = 0
self._last_cleanup_time = time.time()

if total_requests_before != total_requests_after:
logger.debug(
f"Cleaned up {total_requests_before - total_requests_after} "
f"completed requests in {cleanup_time:.3f}s, "
f"{total_requests_after} active requests remaining"
)

def _schedule_async_cleanup(self):
"""Schedule cleanup to run asynchronously without blocking routing."""
# Only schedule if cleanup isn't already running
if (
not hasattr(self, "_cleanup_task")
or self._cleanup_task is None
or self._cleanup_task.done()
):
import asyncio

try:
# Get the running event loop (raises RuntimeError if none is running)
loop = asyncio.get_running_loop()
self._cleanup_task = loop.create_task(self._async_cleanup())
except RuntimeError:
# If no event loop is running, fall back to synchronous cleanup
# This should rarely happen in the Ray Serve context
self._cleanup_completed_pending_requests()

async def _async_cleanup(self):
"""Perform cleanup asynchronously."""
try:
# Small delay to avoid blocking the current operation
await asyncio.sleep(0.001)
self._cleanup_completed_pending_requests()
except Exception as e:
logger.warning(f"Async cleanup failed: {e}")
finally:
self._cleanup_task = None

def _update_multiplexed_model_ids_with_replicas(
self, replicas: List[RunningReplica]
):
@@ -280,6 +401,9 @@ def apply_multiplex_routing(
then the replicas with the fewest multiplexed models, and finally all
replicas.

Batching-aware routing additionally prioritizes replicas that already have
pending requests for the same model ID, to improve batching efficiency.

Args:
pending_request: The pending request to be routed based on
multiplexed model policy.
@@ -291,6 +415,9 @@
if not pending_request:
return self._replica_id_set

# Track this request for batching-aware routing
self._track_pending_request_by_model_id(pending_request)

if not pending_request.routing_context.multiplexed_start_matching_time:
pending_request.routing_context.multiplexed_start_matching_time = (
time.time()
@@ -300,13 +427,63 @@
pending_request.routing_context.multiplexed_start_matching_time
)
multiplexed_model_id = pending_request.metadata.multiplexed_model_id

if (
time.time() - multiplexed_start_matching_time
< self._multiplexed_matching_timeout
):
candidate_replica_ids = self._multiplexed_model_id_to_replica_ids.get(
multiplexed_model_id, None
)

# Batching-aware enhancement: prioritize replicas with pending requests
# for the same model ID to improve batching efficiency
if candidate_replica_ids and multiplexed_model_id:
pending_for_model = self._get_pending_requests_for_model(
multiplexed_model_id
)
if len(pending_for_model) > 1: # Multiple requests for same model
# Find replicas that already have pending requests for this model
batching_friendly_replicas = set()

for pending_req in pending_for_model:
# Check if this request has been assigned to a replica
if (
pending_req.future.done()
and not pending_req.future.cancelled()
and not pending_req.future.exception()
):
try:
assigned_replica = pending_req.future.result()
if (
hasattr(assigned_replica, "replica_id")
and assigned_replica.replica_id
in candidate_replica_ids
):
batching_friendly_replicas.add(
assigned_replica.replica_id
)
except Exception:
# Future might not have replica result, skip
pass

Bug: Batching Optimization Fails due to Incorrect Future Check

The batching-aware routing logic is ineffective. The _get_pending_requests_for_model method filters out completed requests, but the subsequent batching-friendly replica selection logic incorrectly checks for pending_req.future.done(). This condition is always false, preventing the batching optimization from executing.
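
One possible direction for addressing this (a sketch only, with hypothetical names, not the PR's implementation): record the assigned replica in a done-callback when the routing future resolves, so that later requests for the same model consult that record instead of re-inspecting futures the pending-request filter has already discarded. A self-contained illustration of the callback pattern:

```python
import asyncio
from collections import defaultdict
from typing import DefaultDict, Set

# Hypothetical bookkeeping: which replicas recently served each model ID.
recent_replicas_by_model: DefaultDict[str, Set[str]] = defaultdict(set)


def record_assignment(model_id: str, routing_future: "asyncio.Future[str]") -> None:
    """Register a done-callback that records the replica once routing resolves."""

    def _on_done(fut: "asyncio.Future[str]") -> None:
        if fut.cancelled() or fut.exception() is not None:
            return
        recent_replicas_by_model[model_id].add(fut.result())

    routing_future.add_done_callback(_on_done)


async def main() -> None:
    loop = asyncio.get_running_loop()
    fut: "asyncio.Future[str]" = loop.create_future()
    record_assignment("model-a", fut)
    fut.set_result("replica-1")  # a routing task assigns a replica
    await asyncio.sleep(0)  # let the done-callback run
    print(recent_replicas_by_model["model-a"])  # {'replica-1'}


asyncio.run(main())
```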

# If we found replicas with pending requests for this model,
# prioritize them
if batching_friendly_replicas:
candidate_replica_ids = batching_friendly_replicas
logger.debug(
f"Found {len(pending_for_model)} pending requests for "
f"model {multiplexed_model_id}, prioritizing "
f"{len(batching_friendly_replicas)} batching-friendly "
f"replicas"
)
else:
logger.debug(
f"Found {len(pending_for_model)} pending requests for "
f"model {multiplexed_model_id}, but no batching-friendly "
f"replicas found in candidates"
)

if (
not candidate_replica_ids
and multiplexed_model_id
@@ -497,7 +674,8 @@ def __init__(

# We keep two separate queues of pending requests:
# - self._pending_requests_to_fulfill is a queue that will be used to fulfill
# requests (potentially out of order) by routing tasks once they've acquired a replica.
# requests (potentially out of order) by routing tasks once they've
# acquired a replica.
# - self.routing is a queue that is used for tasks to
# best-effort grab the metadata of requests waiting to be fulfilled. This is
# currently used for routing tasks to know which multiplexed model IDs they
@@ -538,8 +716,8 @@ def __init__(

def initialize_state(self, **kwargs):
"""
Initialize the state of the request router. Called by the Ray Serve framework with the
contents of `RequestRouter.request_router_kwargs`.
Initialize the state of the request router. Called by the Ray Serve
framework with the contents of `RequestRouter.request_router_kwargs`.
"""
pass
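
As an aside on the cleanup strategy in MultiplexMixin above: cleanup of the per-model pending-request lists is triggered when either a request-count threshold or a time interval is exceeded, whichever comes first. A minimal standalone sketch of that dual-trigger policy, using illustrative names rather than the mixin's actual attributes:

```python
import time


class CleanupPolicy:
    """Trigger cleanup after N tracked requests or after a fixed interval."""

    def __init__(self, threshold: int = 50, interval_s: float = 10.0):
        self.threshold = threshold
        self.interval_s = interval_s
        self.added_since_cleanup = 0
        self.last_cleanup_time = time.time()

    def record_added(self) -> None:
        self.added_since_cleanup += 1

    def should_cleanup(self) -> bool:
        return (
            self.added_since_cleanup >= self.threshold
            or time.time() - self.last_cleanup_time >= self.interval_s
        )

    def mark_cleaned(self) -> None:
        self.added_since_cleanup = 0
        self.last_cleanup_time = time.time()


# Example: with a threshold of 3, the third tracked request triggers cleanup.
policy = CleanupPolicy(threshold=3, interval_s=60.0)
for _ in range(3):
    policy.record_added()
assert policy.should_cleanup()
policy.mark_cleaned()
```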

20 changes: 18 additions & 2 deletions python/ray/serve/api.py
@@ -738,7 +738,12 @@ def delete(name: str, _blocking: bool = True):

@PublicAPI(stability="beta")
def multiplexed(
func: Optional[Callable[..., Any]] = None, max_num_models_per_replica: int = 3
func: Optional[Callable[..., Any]] = None,
max_num_models_per_replica: int = 3,
enable_batching: bool = False,
max_batch_size: int = 10,
batch_wait_timeout_s: float = 0.01,
max_concurrent_batches: int = 1,
):
"""Wrap a callable or method used to load multiplexed models in a replica.

@@ -798,6 +803,11 @@ async def __call__(self, request):
set it to a larger number if the node has enough memory; conversely,
set it to a smaller number if you want to save memory on the node.
enable_batching: whether to enable batching of model inference calls.
Default is False.
max_batch_size: maximum number of requests per batch when batching is
enabled. Default is 10.
batch_wait_timeout_s: how long to wait to fill a batch before executing
it. Default is 0.01s.
max_concurrent_batches: maximum number of batches that may execute
concurrently. Default is 1.
"""

if func is not None:
@@ -862,7 +872,13 @@ async def _multiplex_wrapper(*args):
# create a model multiplex wrapper and cache it in the multiplex object.
if not hasattr(multiplex_object, multiplex_attr):
model_multiplex_wrapper = _ModelMultiplexWrapper(
func, self, max_num_models_per_replica
func,
self,
max_num_models_per_replica,
enable_batching=enable_batching,
max_batch_size=max_batch_size,
batch_wait_timeout_s=batch_wait_timeout_s,
max_concurrent_batches=max_concurrent_batches,
)
setattr(multiplex_object, multiplex_attr, model_multiplex_wrapper)
else:
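
For reference, a minimal usage sketch of the extended decorator with this change applied. The deployment body and model IDs below are placeholders; `serve.get_multiplexed_model_id()` is the existing Ray Serve API for retrieving the requested model ID, and the batching keyword arguments are the ones added in this diff:

```python
from ray import serve


@serve.deployment
class ModelServer:
    @serve.multiplexed(
        max_num_models_per_replica=3,
        enable_batching=True,  # new in this change
        max_batch_size=10,
        batch_wait_timeout_s=0.01,
        max_concurrent_batches=1,
    )
    async def get_model(self, model_id: str):
        # Placeholder loader; a real deployment would load model weights here.
        return f"loaded-{model_id}"

    async def __call__(self, request):
        model_id = serve.get_multiplexed_model_id()
        model = await self.get_model(model_id)
        return {"model": model}
```

With `enable_batching=True`, `_multiplex_wrapper` passes the batching configuration through to `_ModelMultiplexWrapper`, as shown in the diff above.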