Skip to content

Commit 1c7efac

Browse files
committed
add batch routing logic to service + test case
1 parent d351935 commit 1c7efac

File tree

5 files changed

+123
-320
lines changed

5 files changed

+123
-320
lines changed

src/forge/controller/service/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .interface import ServiceInterface, Session, SessionContext
88
from .metrics import ServiceMetrics
99
from .replica import Replica, ReplicaMetrics, ReplicaState
10-
from .router import BatchRouter, LeastLoadedRouter, RoundRobinRouter, SessionRouter
10+
from .router import LeastLoadedRouter, RoundRobinRouter, SessionRouter
1111
from .service import Service, ServiceActor, ServiceConfig
1212

1313
__all__ = [
@@ -24,5 +24,4 @@
2424
"LeastLoadedRouter",
2525
"RoundRobinRouter",
2626
"SessionRouter",
27-
"BatchRouter",
2827
]

src/forge/controller/service/router.py

Lines changed: 2 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import asyncio
7+
88
import logging
9-
from typing import Callable, Dict, List, Optional
9+
from typing import Dict, List
1010

1111
from .interface import Router
1212
from .replica import Replica
@@ -89,134 +89,3 @@ def get_replica(
8989
replica.idx,
9090
)
9191
return replica
92-
93-
94-
class BatchRouter(Router):
95-
"""
96-
Router wrapper that batches routing decisions.
97-
Uses an inner router to pick the replica for each batch.
98-
99-
Args:
100-
inner_router: The underlying Router instance used to make routing decisions
101-
batch_max_size: Maximum number of requests to collect in a single batch (default: 8)
102-
batch_max_wait_s: Maximum time to wait before processing a batch in seconds (default: 0.01)
103-
104-
Example:
105-
rr_router = RoundRobinRouter()
106-
batch_router = BatchRouter(rr_router, batch_max_size=16, batch_max_wait_s=0.02)
107-
108-
replica = await batch_router.get_replica(healthy_replicas, sess_id, session_map)
109-
"""
110-
111-
def __init__(
112-
self,
113-
inner_router: Router,
114-
batch_max_size: int = 8,
115-
batch_max_wait_s: float = 0.01,
116-
get_healthy_replicas: Optional[Callable[[], List["Replica"]]] = None,
117-
session_map: Optional[Dict[str, int]] = None,
118-
):
119-
120-
self.inner_router = inner_router
121-
self.batch_max_size = batch_max_size
122-
self.batch_max_wait_s = batch_max_wait_s
123-
self.get_healthy_replicas = get_healthy_replicas
124-
self.session_map = session_map
125-
126-
# Internal queue for batching routing requests
127-
self._queue: asyncio.Queue = asyncio.Queue()
128-
self._running = True # flag to control loop
129-
# Background task that processes batches continuously
130-
self._batch_task: asyncio.Task = asyncio.create_task(self._batch_loop())
131-
132-
async def _batch_loop(self):
133-
"""Background task that continuously processes batches of routing requests.
134-
135-
This is the core batching logic that runs in a separate asyncio task.
136-
It collects requests from the queue and processes them in batches based
137-
on size and time constraints.
138-
139-
The loop follows these steps:
140-
1. Wait for the first request to start a new batch
141-
2. Collect additional requests until batch_max_size or batch_max_wait_s is reached
142-
3. Make a single routing decision for the entire batch
143-
4. Fulfill all futures with the selected replica
144-
145-
This process repeats indefinitely until the task is cancelled.
146-
"""
147-
while self._running:
148-
batch = []
149-
futs = []
150-
sess_ids = []
151-
152-
# Wait for first request
153-
fut, healthy_replicas, sess_id, session_map = await self._queue.get()
154-
batch.append((healthy_replicas, sess_id, session_map))
155-
futs.append(fut)
156-
sess_ids.append(sess_id)
157-
start_time = time.monotonic()
158-
159-
while True:
160-
try:
161-
timeout = max(
162-
0, self.batch_max_wait_s - (time.monotonic() - start_time)
163-
)
164-
(
165-
fut,
166-
healthy_replicas,
167-
sess_id,
168-
session_map,
169-
) = await asyncio.wait_for(
170-
self._queue.get(), timeout
171-
) # wait for timeout or until self._queue.get() finishes
172-
batch.append((healthy_replicas, sess_id, session_map))
173-
futs.append(fut)
174-
sess_ids.append(sess_id)
175-
176-
if len(batch) >= self.batch_max_size:
177-
break
178-
except asyncio.TimeoutError:
179-
break
180-
181-
if self.session_map is not None:
182-
session_map = self.session_map
183-
else:
184-
session_map = batch[-1][2] # use most recent session map
185-
if self.get_healthy_replicas is not None:
186-
healthy_replicas = self.get_healthy_replicas()
187-
else:
188-
healthy_replicas = batch[-1][0] # use most recent replica state
189-
# Check if any replicas have become unhealthy
190-
healthy_replicas = [r for r in healthy_replicas if r.healthy]
191-
192-
# One routing decision for the whole batch
193-
replica = await self.inner_router.get_replica(
194-
healthy_replicas, None, session_map
195-
)
196-
197-
# Fulfill all futures with the chosen replica
198-
for fut in futs:
199-
fut.set_result(replica)
200-
201-
async def get_replica(
202-
self,
203-
healthy_replicas: List[Replica],
204-
sess_id: Optional[str] = None,
205-
session_map: Optional[Dict[str, int]] = None,
206-
) -> Replica:
207-
"""Enqueue request and wait until batch assigns a replica."""
208-
fut = asyncio.Future()
209-
# Queue the request for batching - this is non-blocking
210-
self._queue.put_nowait((fut, healthy_replicas, sess_id, session_map))
211-
212-
# Wait for the batch processor to resolve our future
213-
return await fut
214-
215-
async def shutdown(self):
216-
"""Stop the batch loop gracefully."""
217-
self._running = False
218-
self._batch_task.cancel()
219-
try:
220-
await self._batch_task
221-
except asyncio.CancelledError:
222-
pass

src/forge/controller/service/service.py

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import asyncio
3636
import logging
3737
import pprint
38+
import time
3839
import uuid
3940
from typing import Dict, List
4041

@@ -110,6 +111,13 @@ async def __initialize__(self):
110111
self._default_router = RoundRobinRouter()
111112
self._session_router = SessionRouter(fallback_router=LeastLoadedRouter())
112113

114+
# Batching
115+
self._max_batch_size = self._cfg.max_batch_size
116+
self._batch_max_wait_s = self._cfg.batch_max_wait_s
117+
self._batch_task: asyncio.Task | None = None
118+
self._running_batch_loop = False
119+
self._batch_queue: asyncio.Queue = asyncio.Queue()
120+
113121
# Initialize all replicas
114122
replicas = []
115123
num_replicas = self._cfg.num_replicas
@@ -138,6 +146,60 @@ async def __initialize__(self):
138146
self._health_loop(poll_rate_s=self._cfg.health_poll_rate)
139147
)
140148

149+
# Start batch loop if batching enabled
150+
if self._max_batch_size > 1:
151+
self._running_batch_loop = True
152+
self._batch_task = asyncio.create_task(self._batch_loop())
153+
154+
async def _batch_loop(self):
155+
"""Background task that continuously processes batches of routing requests.
156+
157+
This is the core batching logic that runs in a separate asyncio task.
158+
It collects requests from the queue and processes them in batches based
159+
on size and time constraints.
160+
161+
The loop follows these steps:
162+
1. Wait for the first request to start a new batch
163+
2. Collect additional requests until batch_max_size or batch_max_wait_s is reached
164+
3. Make a single routing decision for the entire batch
165+
4. Fulfill all futures with the selected replica
166+
167+
This process repeats indefinitely until the task is cancelled.
168+
"""
169+
while self._running_batch_loop:
170+
batch_futs = []
171+
172+
# Wait for first request
173+
fut = await self._batch_queue.get()
174+
batch_futs.append(fut)
175+
start_time = time.monotonic()
176+
177+
while True:
178+
try:
179+
timeout = max(
180+
0, self._batch_max_wait_s - (time.monotonic() - start_time)
181+
)
182+
fut = await asyncio.wait_for(
183+
self._batch_queue.get(), timeout
184+
) # wait for timeout or until self._queue.get() finishes
185+
batch_futs.append(fut)
186+
187+
if len(batch_futs) >= self._max_batch_size:
188+
break
189+
except asyncio.TimeoutError:
190+
break
191+
192+
healthy_replicas = self._get_healthy_replicas()
193+
194+
# One routing decision for the whole batch
195+
replica = self._default_router.get_replica(
196+
healthy_replicas, None, self._session_replica_map
197+
)
198+
199+
# Fulfill all futures with the chosen replica
200+
for fut in batch_futs:
201+
fut.set_result(replica)
202+
141203
async def _call(self, sess_id: str | None, function: str, *args, **kwargs):
142204
"""
143205
Routes a function call to the appropriate replica with load balancing and fault tolerance.
@@ -211,7 +273,7 @@ async def call_all(self, function: str, *args, **kwargs) -> List:
211273
Raises:
212274
RuntimeError: If no healthy replicas are available
213275
"""
214-
healthy_replicas = [r for r in self._replicas if r.healthy]
276+
healthy_replicas = self._get_healthy_replicas()
215277

216278
if not healthy_replicas:
217279
raise RuntimeError("No healthy replicas available for broadcast call")
@@ -280,9 +342,7 @@ async def _migrate_remaining_requests(self, failed_replica: Replica):
280342
)
281343

282344
# Find healthy replicas
283-
healthy_replicas = [
284-
r for r in self._replicas if r.healthy and r != failed_replica
285-
]
345+
healthy_replicas = self._get_healthy_replicas()
286346

287347
if not healthy_replicas:
288348
# No healthy replicas, fail all requests
@@ -334,7 +394,7 @@ def _update_service_metrics(self):
334394
"""Updates service-level metrics."""
335395
self._metrics.total_sessions = len(self._active_sessions)
336396
self._metrics.total_replicas = len(self._replicas)
337-
self._metrics.healthy_replicas = sum(1 for r in self._replicas if r.healthy)
397+
self._metrics.healthy_replicas = len(self._get_healthy_replicas())
338398
# Store direct references to replica metrics for aggregation
339399
self._metrics.replica_metrics = {}
340400
for replica in self._replicas:
@@ -446,6 +506,10 @@ async def terminate_session(self, sess_id: str):
446506
# Update metrics
447507
self._update_service_metrics()
448508

509+
def _get_healthy_replicas(self) -> list[Replica]:
510+
"""Returns a list of healthy replicas."""
511+
return [r for r in self._replicas if r.healthy]
512+
449513
async def _health_loop(self, poll_rate_s: float):
450514
"""Runs the health loop to monitor and recover replicas.
451515
@@ -476,14 +540,24 @@ async def _health_loop(self, poll_rate_s: float):
476540

477541
async def _get_replica(self, sess_id: str | None) -> "Replica":
478542
"""Get a replica for the given session ID."""
479-
healthy_replicas = [r for r in self._replicas if r.healthy]
480-
if sess_id is None:
481-
# No session, use the default router
482-
return self._default_router.get_replica(healthy_replicas)
483543

484-
return self._session_router.get_replica(
485-
healthy_replicas, sess_id, self._session_replica_map
486-
)
544+
if sess_id:
545+
# Stateful routing always uses session router
546+
healthy_replicas = self._get_healthy_replicas()
547+
return self._session_router.get_replica(
548+
healthy_replicas, sess_id, self._session_replica_map
549+
)
550+
551+
# Stateless: batching
552+
if self._max_batch_size > 1:
553+
fut = asyncio.Future()
554+
healthy_replicas = self._get_healthy_replicas()
555+
self._batch_queue.put_nowait(fut)
556+
return await fut
557+
else:
558+
# No batching, pick immediately
559+
healthy_replicas = self._get_healthy_replicas()
560+
return self._default_router.get_replica(healthy_replicas)
487561

488562
async def stop(self):
489563
logger.debug("Stopping service...")
@@ -582,7 +656,7 @@ async def _get_internal_state(self) -> dict:
582656
# Load balancing state
583657
# Service-level state
584658
"total_replicas": len(self._replicas),
585-
"healthy_replica_count": sum(1 for r in self._replicas if r.healthy),
659+
"healthy_replica_count": len(self._get_healthy_replicas()),
586660
"shutdown_requested": self._shutdown_requested,
587661
# Metrics summary
588662
"total_sessions": len(self._active_sessions),

src/forge/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ class ServiceConfig:
118118
health_poll_rate: float = 0.2
119119
replica_max_concurrent_requests: int = 10
120120
return_first_rank_result: bool = True
121+
max_batch_size: int = 1
122+
batch_max_wait_s: float = 0.01
121123

122124
def to_process_config(self) -> ProcessConfig:
123125
"""Extract ProcessConfig from this ServiceConfig.

0 commit comments

Comments
 (0)