Skip to content

Commit 51f68ae

Browse files
authored
[Feature] add dealer manager to reuse the connection (#3471)
* [BugFix] fix control signal release failed * [BugFix] fix control signal release failed * update * update * update * [Feature] add dealer manager to reuse the connection * fix * fix * fix * fix * fix * fix * Create test_dealer_connection_manager.py * Delete test/entrypoints/openai directory * Update test_dealer_connection_manager.py * Update test_dealer_connection_manager.py
1 parent 985b126 commit 51f68ae

File tree

7 files changed

+360
-25
lines changed

7 files changed

+360
-25
lines changed

fastdeploy/entrypoints/engine_client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
"""
1616

17+
import os
1718
import time
1819
import traceback
1920
import uuid
@@ -22,6 +23,7 @@
2223

2324
from fastdeploy import envs
2425
from fastdeploy.engine.config import ModelConfig
26+
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
2527
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
2628
from fastdeploy.input.preprocess import InputPreprocessor
2729
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
@@ -91,6 +93,10 @@ def __init__(
9193
suffix=pid,
9294
create=False,
9395
)
96+
self.connection_manager = DealerConnectionManager(
97+
pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
98+
)
99+
self.connection_initialized = False
94100

95101
def create_zmq_client(self, model, mode):
96102
"""

fastdeploy/entrypoints/openai/api_server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ async def lifespan(app: FastAPI):
154154
yield
155155
# close zmq
156156
try:
157+
await engine_client.connection_manager.close()
157158
engine_client.zmq_client.close()
158159
from prometheus_client import multiprocess
159160

fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@
2020
import uuid
2121
from typing import List, Optional
2222

23-
import aiozmq
24-
import msgpack
2523
import numpy as np
26-
from aiozmq import zmq
2724

2825
from fastdeploy.entrypoints.openai.protocol import (
2926
ChatCompletionRequest,
@@ -62,6 +59,12 @@ def __init__(self, engine_client, pid, ips, max_waiting_time, chat_template):
6259
else:
6360
self.master_ip = self.master_ip.split(",")[0]
6461

62+
async def _ensure_connection_manager(self):
63+
"""ensure connection manager initialized"""
64+
if not self.engine_client.connection_initialized:
65+
await self.engine_client.connection_manager.initialize()
66+
self.engine_client.connection_initialized = True
67+
6568
def _check_master(self):
6669
if self.master_ip is None:
6770
return True
@@ -180,14 +183,16 @@ async def chat_completion_stream_generator(
180183
choices=[],
181184
model=model_name,
182185
)
186+
183187
try:
184-
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
188+
await self._ensure_connection_manager()
189+
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
185190
dealer.write([b"", request_id.encode("utf-8")])
186191
choices = []
187192
current_waiting_time = 0
188193
while num_choices > 0:
189194
try:
190-
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
195+
response = await asyncio.wait_for(response_queue.get(), timeout=10)
191196
current_waiting_time = 0
192197
except asyncio.TimeoutError:
193198
current_waiting_time += 10
@@ -202,7 +207,6 @@ async def chat_completion_stream_generator(
202207
current_waiting_time = 0
203208
await asyncio.sleep(0.01)
204209
continue
205-
response = msgpack.unpackb(raw_data[-1])
206210
for res in response:
207211
if res.get("error_code", 200) != 200:
208212
raise ValueError("{}".format(res["error_msg"]))
@@ -353,9 +357,9 @@ async def chat_completion_stream_generator(
353357
)
354358
yield f"data: {error_data}\n\n"
355359
finally:
356-
dealer.close()
360+
await self.engine_client.connection_manager.cleanup_request(request_id)
357361
self.engine_client.semaphore.release()
358-
api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
362+
api_server_logger.info(f"release {request_id} {self.engine_client.semaphore.status()}")
359363
yield "data: [DONE]\n\n"
360364

361365
async def chat_completion_full_generator(
@@ -378,7 +382,8 @@ async def chat_completion_full_generator(
378382
include_stop_str_in_output = request.include_stop_str_in_output
379383

380384
try:
381-
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
385+
await self._ensure_connection_manager()
386+
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
382387
dealer.write([b"", request_id.encode("utf-8")])
383388
final_res = None
384389
previous_num_tokens = 0
@@ -387,7 +392,7 @@ async def chat_completion_full_generator(
387392
completion_token_ids = []
388393
while True:
389394
try:
390-
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
395+
response = await asyncio.wait_for(response_queue.get(), timeout=10)
391396
current_waiting_time = 0
392397
except asyncio.TimeoutError:
393398
current_waiting_time += 10
@@ -400,7 +405,6 @@ async def chat_completion_full_generator(
400405
await asyncio.sleep(0.1)
401406
continue
402407

403-
response = msgpack.unpackb(raw_data[-1])
404408
task_is_finished = False
405409
for data in response:
406410
if data.get("error_code", 200) != 200:
@@ -430,7 +434,7 @@ async def chat_completion_full_generator(
430434
if task_is_finished:
431435
break
432436
finally:
433-
dealer.close()
437+
await self.engine_client.connection_manager.cleanup_request(request_id)
434438
self.engine_client.semaphore.release()
435439
api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
436440

fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@
2020
import uuid
2121
from typing import List, Optional
2222

23-
import aiozmq
24-
import msgpack
2523
import numpy as np
26-
from aiozmq import zmq
2724

2825
from fastdeploy.engine.request import RequestOutput
2926
from fastdeploy.entrypoints.openai.protocol import (
@@ -53,6 +50,12 @@ def __init__(self, engine_client, pid, ips, max_waiting_time):
5350
else:
5451
self.master_ip = self.master_ip.split(",")[0]
5552

53+
async def _ensure_connection_manager(self):
54+
"""ensure connection manager initialized"""
55+
if not self.engine_client.connection_initialized:
56+
await self.engine_client.connection_manager.initialize()
57+
self.engine_client.connection_initialized = True
58+
5659
def _check_master(self):
5760
if self.master_ip is None:
5861
return True
@@ -185,7 +188,10 @@ async def completion_full_generator(
185188
try:
186189
request_ids = [f"{request_id}-{i}" for i in range(num_choices)]
187190
# create dealer
188-
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
191+
await self._ensure_connection_manager()
192+
dealer, response_queue = await self.engine_client.connection_manager.get_connection(
193+
request_id, num_choices
194+
)
189195

190196
for rid in request_ids:
191197
dealer.write([b"", rid.encode("utf-8")])
@@ -198,7 +204,7 @@ async def completion_full_generator(
198204
current_waiting_time = 0
199205
while num_choices > 0:
200206
try:
201-
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
207+
response = await asyncio.wait_for(response_queue.get(), timeout=10)
202208
current_waiting_time = 0
203209
except asyncio.TimeoutError:
204210
current_waiting_time += 10
@@ -210,7 +216,7 @@ async def completion_full_generator(
210216
current_waiting_time = 0
211217
await asyncio.sleep(0.1)
212218
continue
213-
response = msgpack.unpackb(raw_data[-1])
219+
214220
for data in response:
215221
rid = int(data["request_id"].split("-")[-1])
216222
if data.get("error_code", 200) != 200:
@@ -255,7 +261,7 @@ async def completion_full_generator(
255261
finally:
256262
self.engine_client.semaphore.release()
257263
if dealer is not None:
258-
dealer.close()
264+
await self.engine_client.connection_manager.cleanup_request(request_id)
259265

260266
async def _echo_back_prompt(self, request, res, idx):
261267
if res["outputs"].get("send_idx", -1) == 0 and request.echo:
@@ -288,7 +294,10 @@ async def completion_stream_generator(
288294
Process the stream completion request.
289295
"""
290296
try:
291-
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
297+
await self._ensure_connection_manager()
298+
dealer, response_queue = await self.engine_client.connection_manager.get_connection(
299+
request_id, num_choices
300+
)
292301

293302
for i in range(num_choices):
294303
req_id = f"{request_id}-{i}"
@@ -312,7 +321,7 @@ async def completion_stream_generator(
312321
current_waiting_time = 0
313322
while num_choices > 0:
314323
try:
315-
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
324+
response = await asyncio.wait_for(response_queue.get(), timeout=10)
316325
current_waiting_time = 0
317326
except asyncio.TimeoutError:
318327
current_waiting_time += 10
@@ -325,7 +334,6 @@ async def completion_stream_generator(
325334
await asyncio.sleep(0.1)
326335
continue
327336

328-
response = msgpack.unpackb(raw_data[-1])
329337
for res in response:
330338
idx = int(res["request_id"].split("-")[-1])
331339
if res.get("error_code", 200) != 200:
@@ -453,9 +461,9 @@ async def completion_stream_generator(
453461
yield f"data: {ErrorResponse(message=str(e), code=400).model_dump_json(exclude_unset=True)}\n\n"
454462
finally:
455463
del request
456-
self.engine_client.semaphore.release()
457464
if dealer is not None:
458-
dealer.close()
465+
await self.engine_client.connection_manager.cleanup_request(request_id)
466+
self.engine_client.semaphore.release()
459467
yield "data: [DONE]\n\n"
460468

461469
def request_output_to_completion_response(
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""
2+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License"
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
import asyncio
18+
import heapq
19+
import random
20+
21+
import aiozmq
22+
import msgpack
23+
import zmq
24+
25+
from fastdeploy.utils import api_server_logger
26+
27+
28+
class DealerConnectionManager:
29+
"""
30+
Manager for dealer connections, supporting multiplexing and connection reuse
31+
"""
32+
33+
def __init__(self, pid, max_connections=10):
34+
self.pid = pid
35+
self.max_connections = max(max_connections, 10)
36+
self.connections = []
37+
self.connection_load = []
38+
self.connection_heap = []
39+
self.request_map = {} # request_id -> response_queue
40+
self.request_num = {} # request_id -> num_choices
41+
self.lock = asyncio.Lock()
42+
self.connection_tasks = []
43+
self.running = False
44+
45+
async def initialize(self):
46+
"""initialize all connections"""
47+
self.running = True
48+
for index in range(self.max_connections):
49+
await self._add_connection(index)
50+
api_server_logger.info(f"Started {self.max_connections} connections")
51+
52+
async def _add_connection(self, index):
53+
"""create a new connection and start listening task"""
54+
try:
55+
dealer = await aiozmq.create_zmq_stream(
56+
zmq.DEALER,
57+
connect=f"ipc:///dev/shm/router_{self.pid}.ipc",
58+
)
59+
async with self.lock:
60+
self.connections.append(dealer)
61+
self.connection_load.append(0)
62+
heapq.heappush(self.connection_heap, (0, index))
63+
64+
# start listening
65+
task = asyncio.create_task(self._listen_connection(dealer, index))
66+
self.connection_tasks.append(task)
67+
return True
68+
except Exception as e:
69+
api_server_logger.error(f"Failed to create dealer: {str(e)}")
70+
return False
71+
72+
async def _listen_connection(self, dealer, conn_index):
73+
"""
74+
listen for messages from the dealer connection
75+
"""
76+
while self.running:
77+
try:
78+
raw_data = await dealer.read()
79+
response = msgpack.unpackb(raw_data[-1])
80+
request_id = response[-1]["request_id"]
81+
if "cmpl" == request_id[:4]:
82+
request_id = request_id.rsplit("-", 1)[0]
83+
async with self.lock:
84+
if request_id in self.request_map:
85+
await self.request_map[request_id].put(response)
86+
if response[-1]["finished"]:
87+
self.request_num[request_id] -= 1
88+
if self.request_num[request_id] == 0:
89+
self._update_load(conn_index, -1)
90+
except Exception as e:
91+
api_server_logger.error(f"Listener error: {str(e)}")
92+
break
93+
94+
def _update_load(self, conn_index, delta):
95+
"""Update connection load and maintain the heap"""
96+
self.connection_load[conn_index] += delta
97+
heapq.heapify(self.connection_heap)
98+
99+
# For Debugging purposes
100+
if random.random() < 0.01:
101+
min_load = self.connection_heap[0][0] if self.connection_heap else 0
102+
max_load = max(self.connection_load) if self.connection_load else 0
103+
api_server_logger.debug(f"Connection load update: min={min_load}, max={max_load}")
104+
105+
def _get_least_loaded_connection(self):
106+
"""
107+
Get the least loaded connection
108+
"""
109+
if not self.connection_heap:
110+
return None
111+
112+
load, conn_index = self.connection_heap[0]
113+
self._update_load(conn_index, 1)
114+
115+
return self.connections[conn_index]
116+
117+
async def get_connection(self, request_id, num_choices=1):
118+
"""get a connection for the request"""
119+
120+
response_queue = asyncio.Queue()
121+
122+
async with self.lock:
123+
self.request_map[request_id] = response_queue
124+
self.request_num[request_id] = num_choices
125+
dealer = self._get_least_loaded_connection()
126+
if not dealer:
127+
raise RuntimeError("No available connections")
128+
129+
return dealer, response_queue
130+
131+
async def cleanup_request(self, request_id):
132+
"""
133+
clean up the request after it is finished
134+
"""
135+
async with self.lock:
136+
if request_id in self.request_map:
137+
del self.request_map[request_id]
138+
del self.request_num[request_id]
139+
140+
async def close(self):
141+
"""
142+
close all connections and tasks
143+
"""
144+
self.running = False
145+
146+
for task in self.connection_tasks:
147+
task.cancel()
148+
149+
async with self.lock:
150+
for dealer in self.connections:
151+
try:
152+
dealer.close()
153+
except:
154+
pass
155+
self.connections.clear()
156+
self.connection_load.clear()
157+
self.request_map.clear()
158+
159+
api_server_logger.info("All connections and tasks closed")

0 commit comments

Comments
 (0)