Commit 54911df

Add llama.cpp server inference backend for responses_api
This is adapted from the Ollama backend, but uses the llama.cpp server instead. Another difference is that it passes raw token ids to llama.cpp and receives raw token ids back.
1 parent 48db88d commit 54911df
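
To illustrate the raw-token exchange mentioned in the commit message: llama.cpp's /completion endpoint accepts a list of token ids as the prompt and, when `return_tokens` is enabled, streams server-sent events whose `tokens` field carries the generated token ids. A minimal parsing sketch (the token ids and text below are made up for illustration) mirroring what the new backend does:

import json

# One streamed line from llama.cpp's /completion endpoint (illustrative values).
line = 'data: {"content": " Hello", "tokens": [9906, 1917], "stop": false}'

event = json.loads(line[len("data: "):])   # strip the SSE "data: " prefix
token_ids = event["tokens"]                # raw token ids produced by the server
finished = event["stop"]                   # True on the final event of the stream
print(token_ids, finished)                 # -> [9906, 1917] False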

File tree

3 files changed: 194 additions & 0 deletions

README.md

Lines changed: 1 addition & 0 deletions
@@ -382,6 +382,7 @@ You can start this server with the following inference backends:
 - `triton` — uses the triton implementation
 - `metal` — uses the metal implementation on Apple Silicon only
 - `ollama` — uses the Ollama /api/generate API as an inference solution
+- `llamacpp_server` — uses the llama.cpp server /completion API as an inference solution
 - `vllm` — uses your installed vllm version to perform inference
 - `transformers` — uses your installed transformers version to perform local inference

gpt_oss/responses_api/inference/llamacpp_server.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
"""
Inference via the llama.cpp server /completion endpoint, using raw token ids.
"""

import json
import threading
import time
from typing import Callable, Optional

import requests

EOS_TOKEN = 200002  # emitted when the stream reports stop, or on a hard timeout

# Tunables
POLL_INTERVAL_S = 0.01       # 10 ms between buffer checks
CALL_MAX_WAIT_S = 0.250      # max time to block inside a single infer call
NO_TOKEN_TIMEOUT_S = 15.0    # overall inactivity timeout before emitting EOS
FIRST_BYTE_TIMEOUT_S = 30.0  # time to wait for the first token before EOS

# Shared state
_token_buffer: list[int] = []
_buffer_lock = threading.Lock()
_stream_thread: Optional[threading.Thread] = None
_stream_done = threading.Event()
_stream_error: Optional[Exception] = None
_last_progress_ts: float = 0.0  # updated whenever we enqueue or dequeue tokens
_previous_request_tokens: list[int] = []


def lcp(cache: list[int], inp: list[int]) -> list[int]:
    """Return the longest common prefix of two token lists."""
    i = 0
    max_len = min(len(cache), len(inp))
    while i < max_len and cache[i] == inp[i]:
        i += 1
    return cache[:i]


def _now() -> float:
    return time.monotonic()


def _touch_progress() -> None:
    global _last_progress_ts
    _last_progress_ts = _now()


def _reset_stream_state() -> None:
    global _token_buffer, _stream_thread, _stream_error
    with _buffer_lock:
        _token_buffer = []
    _stream_done.clear()
    _stream_thread = None
    _stream_error = None
    _touch_progress()


def setup_model(checkpoint: str) -> Callable[[list[int], float, bool], int]:
    # For llama-server, checkpoint is the base URL (e.g., "http://localhost:8080")
    server_url = checkpoint if checkpoint.startswith("http") else f"http://{checkpoint}"

    def _start_stream(token_ids: list[int], temperature: float):
        def run():
            nonlocal temperature
            global _stream_error
            global _previous_request_tokens

            toks: list[int] = []
            last_len = 0  # number of tokens already emitted

            try:
                url = f"{server_url}/completion"

                payload = {
                    "prompt": token_ids,
                    "stream": True,
                    "temperature": temperature,
                    "return_tokens": True,
                    "cache_prompt": True,  # Re-use KV cache for better performance
                    "n_predict": -1,       # Generate until EOS or stop condition
                }

                with requests.post(url, json=payload, stream=True, timeout=60) as resp:
                    resp.raise_for_status()
                    for line in resp.iter_lines(decode_unicode=True):
                        if not line:
                            continue

                        # llama-server uses the server-sent events format
                        if line.startswith("data: "):
                            line = line[6:]  # Remove the "data: " prefix

                        obj = json.loads(line)
                        chunk_tokens = obj.get("tokens")

                        if chunk_tokens is not None:
                            toks += chunk_tokens
                            if len(toks) > last_len:
                                new_toks = toks[last_len:]
                                with _buffer_lock:
                                    _token_buffer.extend(new_toks)
                                last_len = len(toks)
                                _touch_progress()

                        # Check if generation is complete
                        if obj.get("stop", False):
                            with _buffer_lock:
                                _token_buffer.append(EOS_TOKEN)
                            _touch_progress()
                            break

                _stream_done.set()

            except Exception as e:
                _stream_error = e
                _stream_done.set()

        t = threading.Thread(target=run, name="llama-server-stream", daemon=True)
        t.start()
        return t

    def infer_next_token(
        tokens: list[int], temperature: float = 0.0, new_request: bool = False
    ) -> int:
        """
        - Starts a new llama-server stream on new_request.
        - Forwards tokens as they arrive.
        - Only emits EOS_TOKEN if we exceed an inactivity timeout.
        """
        global _stream_thread

        if new_request:
            _reset_stream_state()
            _stream_thread = _start_stream(token_ids=tokens, temperature=temperature)
            # Wait for the first token within FIRST_BYTE_TIMEOUT_S (without emitting EOS early)
            start = _now()
            while _now() - start < FIRST_BYTE_TIMEOUT_S:
                with _buffer_lock:
                    if _token_buffer:
                        tok = _token_buffer.pop(0)
                        _touch_progress()
                        return tok
                if _stream_error is not None:
                    raise RuntimeError(f"llama-server stream error: {_stream_error!r}")
                # If llama-server finished instantly with no output, keep polling until timeout
                time.sleep(POLL_INTERVAL_S)
            # Hard first-byte timeout -> emit EOS so the server can stop this request
            return EOS_TOKEN

        if _stream_error is not None:
            raise RuntimeError(f"llama-server stream error: {_stream_error!r}")

        # Normal path: wait up to CALL_MAX_WAIT_S for a token to arrive
        wait_start = _now()
        while _now() - wait_start < CALL_MAX_WAIT_S:
            with _buffer_lock:
                if _token_buffer:
                    tok = _token_buffer.pop(0)
                    _touch_progress()
                    return tok
            # No token yet; if we've been idle too long overall, end with EOS
            if _now() - _last_progress_ts > NO_TOKEN_TIMEOUT_S:
                return EOS_TOKEN
            time.sleep(POLL_INTERVAL_S)

        # Still no token in this call slice. Do NOT send EOS unless we've timed out.
        if _now() - _last_progress_ts > NO_TOKEN_TIMEOUT_S:
            return EOS_TOKEN

        # We must return an int, so wait one more short interval for a token
        # to reduce hot-looping before handing control back to the caller.
        time.sleep(POLL_INTERVAL_S)
        with _buffer_lock:
            if _token_buffer:
                tok = _token_buffer.pop(0)
                _touch_progress()
                return tok

        # As a last resort for this call slice, return EOS only on true inactivity timeout.
        if _now() - _last_progress_ts > NO_TOKEN_TIMEOUT_S:
            return EOS_TOKEN

        # Otherwise return a harmless sentinel so the caller polls us again soon.
        return 0  # replace `0` with a PAD/NOOP token your server ignores

    return infer_next_token
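
A minimal usage sketch (not part of the commit), assuming a llama-server instance at a placeholder URL and made-up prompt token ids; the responses API server drives the returned callable in a similar loop:

from gpt_oss.responses_api.inference.llamacpp_server import EOS_TOKEN, setup_model

infer_next_token = setup_model("http://localhost:8080")  # placeholder llama-server URL

prompt = [1, 2, 3]          # hypothetical prompt token ids
generated: list[int] = []

# The first call starts the background stream; later calls drain buffered tokens.
tok = infer_next_token(prompt, temperature=0.7, new_request=True)
while tok != EOS_TOKEN:
    generated.append(tok)   # a real caller would also filter the PAD/NOOP fallback token
    tok = infer_next_token(prompt + generated, temperature=0.7)

print(generated)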

gpt_oss/responses_api/serve.py

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,8 @@
         from .inference.vllm import setup_model
     elif args.inference_backend == "transformers":
         from .inference.transformers import setup_model
+    elif args.inference_backend == "llamacpp_server":
+        from .inference.llamacpp_server import setup_model
     else:
         raise ValueError(f"Invalid inference backend: {args.inference_backend}")
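
With this dispatch in place, the new backend is selected by passing `llamacpp_server` as the inference backend when launching the responses API server, and the llama.cpp server's base URL (for example `http://localhost:8080`) is what ends up as the `checkpoint` argument to `setup_model`; the exact CLI flag names are not shown in this diff.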
