|
35 | 35 | import asyncio |
36 | 36 | import logging |
37 | 37 | import pprint |
| 38 | +import time |
38 | 39 | import uuid |
39 | 40 | from typing import Dict, List |
40 | 41 |
|
@@ -110,6 +111,13 @@ async def __initialize__(self): |
110 | 111 | self._default_router = RoundRobinRouter() |
111 | 112 | self._session_router = SessionRouter(fallback_router=LeastLoadedRouter()) |
112 | 113 |
|
| 114 | + # Batching |
| 115 | + self._max_batch_size = self._cfg.max_batch_size |
| 116 | + self._batch_max_wait_s = self._cfg.batch_max_wait_s |
| 117 | + self._batch_task: asyncio.Task | None = None |
| 118 | + self._running_batch_loop = False |
| 119 | + self._batch_queue: asyncio.Queue = asyncio.Queue() |
| 120 | + |
113 | 121 | # Initialize all replicas |
114 | 122 | replicas = [] |
115 | 123 | num_replicas = self._cfg.num_replicas |
@@ -138,6 +146,60 @@ async def __initialize__(self): |
138 | 146 | self._health_loop(poll_rate_s=self._cfg.health_poll_rate) |
139 | 147 | ) |
140 | 148 |
|
| 149 | + # Start batch loop if batching enabled |
| 150 | + if self._max_batch_size > 1: |
| 151 | + self._running_batch_loop = True |
| 152 | + self._batch_task = asyncio.create_task(self._batch_loop()) |
| 153 | + |
async def _batch_loop(self):
    """Background task that routes queued stateless requests in batches.

    Each item on ``self._batch_queue`` is a future awaited by a caller that
    wants a replica. The loop repeats the following steps while
    ``self._running_batch_loop`` is true:

    1. Wait for the first future to start a new batch.
    2. Collect further futures until the batch holds ``self._max_batch_size``
       entries or ``self._batch_max_wait_s`` seconds have passed since the
       first one arrived.
    3. Make a single routing decision for the entire batch.
    4. Resolve every still-pending future with the selected replica.

    Failure handling:
        * If looking up healthy replicas or routing raises, the exception is
          delivered to every pending future in the batch and the loop keeps
          running, so callers fail fast instead of waiting forever.
        * Futures that are already done (e.g. cancelled by their caller) are
          skipped, because completing them again would raise
          ``asyncio.InvalidStateError`` and terminate the loop.
        * If this task is cancelled while a batch is being collected, the
          futures already taken off the queue are cancelled before the
          cancellation is re-raised.

    Raises:
        asyncio.CancelledError: When the task is cancelled.
    """
    while self._running_batch_loop:
        # Wait for the first request; it opens the batch and starts the clock.
        batch_futs = [await self._batch_queue.get()]
        deadline = time.monotonic() + self._batch_max_wait_s

        try:
            while len(batch_futs) < self._max_batch_size:
                timeout = max(0, deadline - time.monotonic())
                try:
                    # Wait for the next request or until the batch window closes.
                    fut = await asyncio.wait_for(self._batch_queue.get(), timeout)
                except asyncio.TimeoutError:
                    break
                batch_futs.append(fut)
        except asyncio.CancelledError:
            # These futures are no longer on the queue, so nobody else can
            # complete them; release their waiters before shutting down.
            for fut in batch_futs:
                if not fut.done():
                    fut.cancel()
            raise

        try:
            healthy_replicas = self._get_healthy_replicas()

            # One routing decision for the whole batch
            replica = self._default_router.get_replica(
                healthy_replicas, None, self._session_replica_map
            )
        except Exception as exc:
            # Hand the failure to the callers and keep serving later batches.
            for fut in batch_futs:
                if not fut.done():
                    fut.set_exception(exc)
            continue

        # Fulfill all still-pending futures with the chosen replica
        for fut in batch_futs:
            if not fut.done():
                fut.set_result(replica)
| 202 | + |
141 | 203 | async def _call(self, sess_id: str | None, function: str, *args, **kwargs): |
142 | 204 | """ |
143 | 205 | Routes a function call to the appropriate replica with load balancing and fault tolerance. |
@@ -211,7 +273,7 @@ async def call_all(self, function: str, *args, **kwargs) -> List: |
211 | 273 | Raises: |
212 | 274 | RuntimeError: If no healthy replicas are available |
213 | 275 | """ |
214 | | - healthy_replicas = [r for r in self._replicas if r.healthy] |
| 276 | + healthy_replicas = self._get_healthy_replicas() |
215 | 277 |
|
216 | 278 | if not healthy_replicas: |
217 | 279 | raise RuntimeError("No healthy replicas available for broadcast call") |
@@ -280,9 +342,7 @@ async def _migrate_remaining_requests(self, failed_replica: Replica): |
280 | 342 | ) |
281 | 343 |
|
282 | 344 | # Find healthy replicas |
283 | | - healthy_replicas = [ |
284 | | - r for r in self._replicas if r.healthy and r != failed_replica |
285 | | - ] |
| 345 | + healthy_replicas = self._get_healthy_replicas() |
286 | 346 |
|
287 | 347 | if not healthy_replicas: |
288 | 348 | # No healthy replicas, fail all requests |
@@ -334,7 +394,7 @@ def _update_service_metrics(self): |
334 | 394 | """Updates service-level metrics.""" |
335 | 395 | self._metrics.total_sessions = len(self._active_sessions) |
336 | 396 | self._metrics.total_replicas = len(self._replicas) |
337 | | - self._metrics.healthy_replicas = sum(1 for r in self._replicas if r.healthy) |
| 397 | + self._metrics.healthy_replicas = len(self._get_healthy_replicas()) |
338 | 398 | # Store direct references to replica metrics for aggregation |
339 | 399 | self._metrics.replica_metrics = {} |
340 | 400 | for replica in self._replicas: |
@@ -446,6 +506,10 @@ async def terminate_session(self, sess_id: str): |
446 | 506 | # Update metrics |
447 | 507 | self._update_service_metrics() |
448 | 508 |
|
def _get_healthy_replicas(self) -> list[Replica]:
    """Collect the replicas currently flagged as healthy.

    Returns:
        A new list holding every entry of ``self._replicas`` whose
        ``healthy`` attribute is truthy, in their original order.
    """
    return list(filter(lambda replica: replica.healthy, self._replicas))
| 512 | + |
449 | 513 | async def _health_loop(self, poll_rate_s: float): |
450 | 514 | """Runs the health loop to monitor and recover replicas. |
451 | 515 |
|
@@ -476,14 +540,24 @@ async def _health_loop(self, poll_rate_s: float): |
476 | 540 |
|
477 | 541 | async def _get_replica(self, sess_id: str | None) -> "Replica": |
478 | 542 | """Get a replica for the given session ID.""" |
479 | | - healthy_replicas = [r for r in self._replicas if r.healthy] |
480 | | - if sess_id is None: |
481 | | - # No session, use the default router |
482 | | - return self._default_router.get_replica(healthy_replicas) |
483 | 543 |
|
484 | | - return self._session_router.get_replica( |
485 | | - healthy_replicas, sess_id, self._session_replica_map |
486 | | - ) |
| 544 | + if sess_id: |
| 545 | + # Stateful routing always uses session router |
| 546 | + healthy_replicas = self._get_healthy_replicas() |
| 547 | + return self._session_router.get_replica( |
| 548 | + healthy_replicas, sess_id, self._session_replica_map |
| 549 | + ) |
| 550 | + |
| 551 | + # Stateless: batching |
| 552 | + if self._max_batch_size > 1: |
| 553 | + fut = asyncio.Future() |
| 554 | + healthy_replicas = self._get_healthy_replicas() |
| 555 | + self._batch_queue.put_nowait(fut) |
| 556 | + return await fut |
| 557 | + else: |
| 558 | + # No batching, pick immediately |
| 559 | + healthy_replicas = self._get_healthy_replicas() |
| 560 | + return self._default_router.get_replica(healthy_replicas) |
487 | 561 |
|
488 | 562 | async def stop(self): |
489 | 563 | logger.debug("Stopping service...") |
@@ -582,7 +656,7 @@ async def _get_internal_state(self) -> dict: |
582 | 656 | # Load balancing state |
583 | 657 | # Service-level state |
584 | 658 | "total_replicas": len(self._replicas), |
585 | | - "healthy_replica_count": sum(1 for r in self._replicas if r.healthy), |
| 659 | + "healthy_replica_count": len(self._get_healthy_replicas()), |
586 | 660 | "shutdown_requested": self._shutdown_requested, |
587 | 661 | # Metrics summary |
588 | 662 | "total_sessions": len(self._active_sessions), |
|
0 commit comments