Commit ab02aed

CharlesR-W and claude committed
feat: add --server-port option to connect to external vLLM server
When set, uses the OpenAI-compatible API to talk to a vLLM server instead of
loading the model in-process. Default behavior unchanged.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent b253b96 commit ab02aed

File tree

4 files changed: +94 -13 lines changed
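Before the per-file diffs, a minimal usage sketch of what the commit message describes. This is not part of the change itself; it assumes the client class defined in delphi/clients/offline.py is importable as Offline, that its remaining constructor arguments have defaults, and that model names and the port are purely illustrative.

import asyncio

from delphi.clients.offline import Offline  # class name and import path assumed

async def main():
    # Default behavior (unchanged): vLLM loads the model in this process.
    local_client = Offline("hf-org/explainer-model", num_gpus=1)

    # New behavior: reuse a vLLM server already running on localhost:8000,
    # e.g. one started with `vllm serve hf-org/explainer-model --port 8000`.
    remote_client = Offline("hf-org/explainer-model", server_port=8000)

    response = await remote_client.generate("Explain this latent's activations.")
    print(response.text)

    await remote_client.close()
    await local_client.close()

asyncio.run(main())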

delphi/__main__.py

Lines changed: 1 addition & 0 deletions
@@ -151,6 +151,7 @@ async def process_cache(
             max_model_len=run_cfg.explainer_model_max_len,
             num_gpus=run_cfg.num_gpus,
             statistics=run_cfg.verbose,
+            server_port=run_cfg.server_port,
         )
     elif run_cfg.explainer_provider == "openrouter":
         if (

delphi/clients/offline.py

Lines changed: 86 additions & 12 deletions
@@ -5,6 +5,7 @@
 from pathlib import Path
 from typing import Union

+from openai import AsyncOpenAI
 from transformers import AutoTokenizer
 from vllm import LLM, SamplingParams
 from vllm.distributed.parallel_state import (
@@ -51,27 +52,47 @@ def __init__(
         num_gpus: int = 2,
         enforce_eager: bool = False,
         statistics: bool = False,
+        server_port: int | None = None,
     ):
         """Client for offline generation. Models not already present in the on-disk
         HuggingFace cache will be downloaded. Note that temperature must be increased
         for best-of-n sampling.
+
+        If server_port is provided, connects to an external vLLM server running on
+        localhost at the specified port via the OpenAI-compatible API, instead of
+        loading the model locally.
         """
         super().__init__(model)
         self.model = model
+        self.max_model_len = max_model_len
         self.queue = asyncio.Queue()
         self.task = None
-        self.client = LLM(
-            model=model,
-            gpu_memory_utilization=max_memory,
-            enable_prefix_caching=prefix_caching,
-            tensor_parallel_size=num_gpus,
-            max_model_len=max_model_len,
-            enforce_eager=enforce_eager,
-        )
+        self.server_port = server_port
+
+        if server_port is None:
+            # Local mode: load model in-process
+            self.client = LLM(
+                model=model,
+                gpu_memory_utilization=max_memory,
+                enable_prefix_caching=prefix_caching,
+                tensor_parallel_size=num_gpus,
+                max_model_len=max_model_len,
+                enforce_eager=enforce_eager,
+            )
+            self.openai_client = None
+        else:
+            # Server mode: connect to external vLLM server
+            self.client = None
+            self.openai_client = AsyncOpenAI(
+                base_url=f"http://localhost:{server_port}/v1",
+                api_key="EMPTY",
+            )
+
         self.sampling_params = SamplingParams(max_tokens=number_tokens_to_generate)
         self.tokenizer = AutoTokenizer.from_pretrained(model)
         self.batch_size = batch_size
         self.statistics = statistics
+        self.number_tokens_to_generate = number_tokens_to_generate

         if self.statistics:
             self.statistics_path = Path("statistics")
@@ -168,12 +189,58 @@ async def generate(
         """
         Enqueue a request and wait for the result.
         """
+        if self.server_port is not None:
+            # Server mode: use OpenAI-compatible API directly
+            return await self._generate_server(prompt, **kwargs)
+
+        # Local mode: use batching queue
         future = asyncio.Future()
         if self.task is None:
             self.task = asyncio.create_task(self._process_batches())
         await self.queue.put((prompt, future, kwargs))
         return await future

+    async def _generate_server(
+        self, prompt: Union[str, list[dict[str, str]]], **kwargs
+    ) -> Response:
+        """
+        Generate using external vLLM server via OpenAI-compatible API.
+        """
+        temperature = kwargs.get("temperature", 0.0)
+        max_tokens = kwargs.get("max_tokens", self.number_tokens_to_generate)
+
+        # Handle logprobs if requested
+        logprobs = kwargs.get("logprobs", False)
+        top_logprobs = kwargs.get("top_logprobs", None) if logprobs else None
+
+        messages = prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]
+
+        response = await self.openai_client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
+        )
+
+        text = response.choices[0].message.content or ""
+
+        # Parse logprobs from OpenAI format if present
+        parsed_logprobs = None
+        if logprobs and response.choices[0].logprobs:
+            parsed_logprobs = []
+            for token_logprob in response.choices[0].logprobs.content or []:
+                top_lps = [
+                    Top_Logprob(token=lp.token, logprob=lp.logprob)
+                    for lp in (token_logprob.top_logprobs or [])
+                ]
+                parsed_logprobs.append(
+                    Logprobs(token=token_logprob.token, top_logprobs=top_lps)
+                )
+
+        return Response(text=text, logprobs=parsed_logprobs, prompt_logprobs=None)
+
     def _parse_logprobs(self, response):
         response_tokens = response.outputs[0].token_ids
         logprobs = response.outputs[0].logprobs
@@ -253,10 +320,17 @@ async def close(self):
         """
         Clean up resources when the client is no longer needed.
         """
-        destroy_model_parallel()
-        destroy_distributed_environment()
-        del self.client
-        self.client = None
+        if self.client is not None:
+            # Only destroy local model resources in local mode
+            destroy_model_parallel()
+            destroy_distributed_environment()
+            del self.client
+            self.client = None
+
+        if self.openai_client is not None:
+            await self.openai_client.close()
+            self.openai_client = None
+
         if self.task:
             self.task.cancel()
         try:
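As a usage reference for the new server path above, here is a hedged end-to-end sketch. It assumes the class in this file is importable as Offline, that its remaining constructor arguments have defaults, and that a vLLM server serving the same model is already listening on localhost:8000; the keyword arguments mirror what _generate_server reads out of **kwargs.

import asyncio

from delphi.clients.offline import Offline  # class name and import path assumed

async def main():
    client = Offline("hf-org/explainer-model", server_port=8000)

    # _generate_server reads temperature, max_tokens, logprobs, and top_logprobs
    # from **kwargs, so they can be passed straight to generate().
    response = await client.generate(
        [{"role": "user", "content": "Describe what this latent responds to."}],
        temperature=0.0,
        logprobs=True,
        top_logprobs=5,
    )

    print(response.text)
    if response.logprobs:
        first = response.logprobs[0]
        # Each entry carries the sampled token and its top alternatives.
        print(first.token, [(t.token, t.logprob) for t in first.top_logprobs])

    await client.close()

asyncio.run(main())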

delphi/config.py

Lines changed: 5 additions & 0 deletions
@@ -140,6 +140,11 @@ class RunConfig(Serializable):
     """Provider to use for explanation and scoring. Options are 'offline' for local
     models and 'openrouter' for API calls."""

+    server_port: int | None = field(default=None)
+    """Port for external vLLM server. If set, connects to a vLLM server running on
+    localhost at this port via OpenAI-compatible API instead of loading the model
+    locally. Start a server with: vllm serve <model> --port <port>"""
+
     explainer: str = field(
         choices=["default", "none"],
         default="default",
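Since this field only points delphi at an already-running server, it can help to confirm the OpenAI-compatible endpoint is reachable before starting a run. A minimal sketch, assuming a server was started as the docstring suggests (here on port 8000):

import asyncio

from openai import AsyncOpenAI

async def main():
    # Same endpoint shape the offline client builds from server_port.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    models = await client.models.list()
    print([m.id for m in models.data])  # should list the served model
    await client.close()

asyncio.run(main())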

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -28,7 +28,8 @@ dependencies = [
     "anyio>=4.8.0",
     "faiss-cpu",
     "asyncer>=0.0.8",
-    "beartype"
+    "beartype",
+    "openai>=1.0.0",
 ]

 [project.optional-dependencies]
