 from pathlib import Path
 from typing import Union
 
+from openai import AsyncOpenAI
 from transformers import AutoTokenizer
 from vllm import LLM, SamplingParams
 from vllm.distributed.parallel_state import (
@@ -51,27 +52,47 @@ def __init__(
         num_gpus: int = 2,
         enforce_eager: bool = False,
         statistics: bool = False,
+        server_port: int | None = None,
     ):
         """Client for offline generation. Models not already present in the on-disk
         HuggingFace cache will be downloaded. Note that temperature must be increased
         for best-of-n sampling.
+
+        If server_port is provided, connects to an external vLLM server running on
+        localhost at the specified port via the OpenAI-compatible API, instead of
+        loading the model locally.
         """
         super().__init__(model)
         self.model = model
+        self.max_model_len = max_model_len
         self.queue = asyncio.Queue()
         self.task = None
-        self.client = LLM(
-            model=model,
-            gpu_memory_utilization=max_memory,
-            enable_prefix_caching=prefix_caching,
-            tensor_parallel_size=num_gpus,
-            max_model_len=max_model_len,
-            enforce_eager=enforce_eager,
-        )
+        self.server_port = server_port
+
+        if server_port is None:
+            # Local mode: load model in-process
+            self.client = LLM(
+                model=model,
+                gpu_memory_utilization=max_memory,
+                enable_prefix_caching=prefix_caching,
+                tensor_parallel_size=num_gpus,
+                max_model_len=max_model_len,
+                enforce_eager=enforce_eager,
+            )
+            self.openai_client = None
+        else:
+            # Server mode: connect to external vLLM server
+            self.client = None
+            self.openai_client = AsyncOpenAI(
+                base_url=f"http://localhost:{server_port}/v1",
+                api_key="EMPTY",
+            )
+
         self.sampling_params = SamplingParams(max_tokens=number_tokens_to_generate)
         self.tokenizer = AutoTokenizer.from_pretrained(model)
         self.batch_size = batch_size
         self.statistics = statistics
+        self.number_tokens_to_generate = number_tokens_to_generate
 
         if self.statistics:
             self.statistics_path = Path("statistics")
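
For context, a minimal construction sketch of the two modes this hunk introduces. The class name `OfflineClient`, its import path, and the example model are assumptions (the diff only shows the `__init__` body); server mode additionally assumes an OpenAI-compatible vLLM server has been started separately, e.g. with `vllm serve <model> --port 8000`.

# Hypothetical client class and module path; only __init__ appears in the diff.
from inference.offline_client import OfflineClient

# Local mode: loads the model in-process via vllm.LLM.
local_client = OfflineClient(model="Qwen/Qwen2.5-7B-Instruct", num_gpus=2)

# Server mode: nothing is loaded locally; requests are sent to
# http://localhost:8000/v1 through the OpenAI-compatible API.
server_client = OfflineClient(model="Qwen/Qwen2.5-7B-Instruct", server_port=8000)
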
@@ -168,12 +189,58 @@ async def generate(
         """
         Enqueue a request and wait for the result.
         """
+        if self.server_port is not None:
+            # Server mode: use the OpenAI-compatible API directly
+            return await self._generate_server(prompt, **kwargs)
+
+        # Local mode: use the batching queue
         future = asyncio.Future()
         if self.task is None:
             self.task = asyncio.create_task(self._process_batches())
         await self.queue.put((prompt, future, kwargs))
         return await future
 
+    async def _generate_server(
+        self, prompt: Union[str, list[dict[str, str]]], **kwargs
+    ) -> Response:
+        """
+        Generate using an external vLLM server via the OpenAI-compatible API.
+        """
+        temperature = kwargs.get("temperature", 0.0)
+        max_tokens = kwargs.get("max_tokens", self.number_tokens_to_generate)
+
+        # Handle logprobs if requested
+        logprobs = kwargs.get("logprobs", False)
+        top_logprobs = kwargs.get("top_logprobs", None) if logprobs else None
+
+        messages = prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]
+
+        response = await self.openai_client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
+        )
+
+        text = response.choices[0].message.content or ""
+
+        # Parse logprobs from the OpenAI format if present
+        parsed_logprobs = None
+        if logprobs and response.choices[0].logprobs:
+            parsed_logprobs = []
+            for token_logprob in response.choices[0].logprobs.content or []:
+                top_lps = [
+                    Top_Logprob(token=lp.token, logprob=lp.logprob)
+                    for lp in (token_logprob.top_logprobs or [])
+                ]
+                parsed_logprobs.append(
+                    Logprobs(token=token_logprob.token, top_logprobs=top_lps)
+                )
+
+        return Response(text=text, logprobs=parsed_logprobs, prompt_logprobs=None)
+
     def _parse_logprobs(self, response):
         response_tokens = response.outputs[0].token_ids
         logprobs = response.outputs[0].logprobs
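
A hedged usage sketch for the new server path: awaiting `generate` with the kwargs `_generate_server` reads (`temperature`, `max_tokens`, `logprobs`, `top_logprobs`). The `client` argument stands in for the hypothetical `server_client` from the sketch above; `Response`, `Logprobs`, and `Top_Logprob` are the module's own result types referenced by the diff.

import asyncio

async def demo(client):
    # A plain string is wrapped into a single user message; a list of
    # {"role": ..., "content": ...} dicts is forwarded as-is.
    response = await client.generate(
        "Name three prime numbers.",
        temperature=0.7,
        max_tokens=64,
        logprobs=True,     # request per-token logprobs from the server
        top_logprobs=5,    # and the top-5 alternatives per position
    )
    print(response.text)
    if response.logprobs:
        first = response.logprobs[0]
        print(first.token, [t.token for t in first.top_logprobs])

# asyncio.run(demo(server_client))
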
@@ -253,10 +320,17 @@ async def close(self):
         """
         Clean up resources when the client is no longer needed.
         """
-        destroy_model_parallel()
-        destroy_distributed_environment()
-        del self.client
-        self.client = None
+        if self.client is not None:
+            # Only destroy local model resources in local mode
+            destroy_model_parallel()
+            destroy_distributed_environment()
+            del self.client
+            self.client = None
+
+        if self.openai_client is not None:
+            await self.openai_client.close()
+            self.openai_client = None
+
         if self.task:
             self.task.cancel()
             try:
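
Since `close()` frees different resources depending on the mode (distributed and GPU state locally, the `AsyncOpenAI` HTTP client in server mode), callers should await it exactly once when finished; a sketch using the same hypothetical client as above:

async def run(client):
    try:
        result = await client.generate("Hello")
        print(result.text)
    finally:
        # Local mode: tears down model-parallel and distributed state.
        # Server mode: closes the AsyncOpenAI HTTP client.
        # Either way, the batching task is cancelled if one was started.
        await client.close()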