Skip to content

Commit 34e8e7a

Browse files
committed
Extract chat completions streaming helpers
1 parent 80de53e commit 34e8e7a

File tree

2 files changed

+301
-275
lines changed

2 files changed

+301
-275
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import AsyncIterator
4+
from dataclasses import dataclass, field
5+
6+
from openai import AsyncStream
7+
from openai.types.chat import ChatCompletionChunk
8+
from openai.types.completion_usage import CompletionUsage
9+
from openai.types.responses import (
10+
Response,
11+
ResponseCompletedEvent,
12+
ResponseContentPartAddedEvent,
13+
ResponseContentPartDoneEvent,
14+
ResponseCreatedEvent,
15+
ResponseFunctionCallArgumentsDeltaEvent,
16+
ResponseFunctionToolCall,
17+
ResponseOutputItem,
18+
ResponseOutputItemAddedEvent,
19+
ResponseOutputItemDoneEvent,
20+
ResponseOutputMessage,
21+
ResponseOutputRefusal,
22+
ResponseOutputText,
23+
ResponseRefusalDeltaEvent,
24+
ResponseTextDeltaEvent,
25+
ResponseUsage,
26+
)
27+
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
28+
29+
from ..items import TResponseStreamEvent
30+
from .fake_id import FAKE_RESPONSES_ID
31+
32+
33+
@dataclass
34+
class StreamingState:
35+
started: bool = False
36+
text_content_index_and_output: tuple[int, ResponseOutputText] | None = None
37+
refusal_content_index_and_output: tuple[int, ResponseOutputRefusal] | None = None
38+
function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict)
39+
40+
41+
class ChatCmplStreamHandler:
42+
@classmethod
43+
async def handle_stream(
44+
cls,
45+
response: Response,
46+
stream: AsyncStream[ChatCompletionChunk],
47+
) -> AsyncIterator[TResponseStreamEvent]:
48+
usage: CompletionUsage | None = None
49+
state = StreamingState()
50+
51+
async for chunk in stream:
52+
if not state.started:
53+
state.started = True
54+
yield ResponseCreatedEvent(
55+
response=response,
56+
type="response.created",
57+
)
58+
59+
usage = chunk.usage
60+
61+
if not chunk.choices or not chunk.choices[0].delta:
62+
continue
63+
64+
delta = chunk.choices[0].delta
65+
66+
# Handle text
67+
if delta.content:
68+
if not state.text_content_index_and_output:
69+
# Initialize a content tracker for streaming text
70+
state.text_content_index_and_output = (
71+
0 if not state.refusal_content_index_and_output else 1,
72+
ResponseOutputText(
73+
text="",
74+
type="output_text",
75+
annotations=[],
76+
),
77+
)
78+
# Start a new assistant message stream
79+
assistant_item = ResponseOutputMessage(
80+
id=FAKE_RESPONSES_ID,
81+
content=[],
82+
role="assistant",
83+
type="message",
84+
status="in_progress",
85+
)
86+
# Notify consumers of the start of a new output message + first content part
87+
yield ResponseOutputItemAddedEvent(
88+
item=assistant_item,
89+
output_index=0,
90+
type="response.output_item.added",
91+
)
92+
yield ResponseContentPartAddedEvent(
93+
content_index=state.text_content_index_and_output[0],
94+
item_id=FAKE_RESPONSES_ID,
95+
output_index=0,
96+
part=ResponseOutputText(
97+
text="",
98+
type="output_text",
99+
annotations=[],
100+
),
101+
type="response.content_part.added",
102+
)
103+
# Emit the delta for this segment of content
104+
yield ResponseTextDeltaEvent(
105+
content_index=state.text_content_index_and_output[0],
106+
delta=delta.content,
107+
item_id=FAKE_RESPONSES_ID,
108+
output_index=0,
109+
type="response.output_text.delta",
110+
)
111+
# Accumulate the text into the response part
112+
state.text_content_index_and_output[1].text += delta.content
113+
114+
# Handle refusals (model declines to answer)
115+
if delta.refusal:
116+
if not state.refusal_content_index_and_output:
117+
# Initialize a content tracker for streaming refusal text
118+
state.refusal_content_index_and_output = (
119+
0 if not state.text_content_index_and_output else 1,
120+
ResponseOutputRefusal(refusal="", type="refusal"),
121+
)
122+
# Start a new assistant message if one doesn't exist yet (in-progress)
123+
assistant_item = ResponseOutputMessage(
124+
id=FAKE_RESPONSES_ID,
125+
content=[],
126+
role="assistant",
127+
type="message",
128+
status="in_progress",
129+
)
130+
# Notify downstream that assistant message + first content part are starting
131+
yield ResponseOutputItemAddedEvent(
132+
item=assistant_item,
133+
output_index=0,
134+
type="response.output_item.added",
135+
)
136+
yield ResponseContentPartAddedEvent(
137+
content_index=state.refusal_content_index_and_output[0],
138+
item_id=FAKE_RESPONSES_ID,
139+
output_index=0,
140+
part=ResponseOutputText(
141+
text="",
142+
type="output_text",
143+
annotations=[],
144+
),
145+
type="response.content_part.added",
146+
)
147+
# Emit the delta for this segment of refusal
148+
yield ResponseRefusalDeltaEvent(
149+
content_index=state.refusal_content_index_and_output[0],
150+
delta=delta.refusal,
151+
item_id=FAKE_RESPONSES_ID,
152+
output_index=0,
153+
type="response.refusal.delta",
154+
)
155+
# Accumulate the refusal string in the output part
156+
state.refusal_content_index_and_output[1].refusal += delta.refusal
157+
158+
# Handle tool calls
159+
# Because we don't know the name of the function until the end of the stream, we'll
160+
# save everything and yield events at the end
161+
if delta.tool_calls:
162+
for tc_delta in delta.tool_calls:
163+
if tc_delta.index not in state.function_calls:
164+
state.function_calls[tc_delta.index] = ResponseFunctionToolCall(
165+
id=FAKE_RESPONSES_ID,
166+
arguments="",
167+
name="",
168+
type="function_call",
169+
call_id="",
170+
)
171+
tc_function = tc_delta.function
172+
173+
state.function_calls[tc_delta.index].arguments += (
174+
tc_function.arguments if tc_function else ""
175+
) or ""
176+
state.function_calls[tc_delta.index].name += (
177+
tc_function.name if tc_function else ""
178+
) or ""
179+
state.function_calls[tc_delta.index].call_id += tc_delta.id or ""
180+
181+
function_call_starting_index = 0
182+
if state.text_content_index_and_output:
183+
function_call_starting_index += 1
184+
# Send end event for this content part
185+
yield ResponseContentPartDoneEvent(
186+
content_index=state.text_content_index_and_output[0],
187+
item_id=FAKE_RESPONSES_ID,
188+
output_index=0,
189+
part=state.text_content_index_and_output[1],
190+
type="response.content_part.done",
191+
)
192+
193+
if state.refusal_content_index_and_output:
194+
function_call_starting_index += 1
195+
# Send end event for this content part
196+
yield ResponseContentPartDoneEvent(
197+
content_index=state.refusal_content_index_and_output[0],
198+
item_id=FAKE_RESPONSES_ID,
199+
output_index=0,
200+
part=state.refusal_content_index_and_output[1],
201+
type="response.content_part.done",
202+
)
203+
204+
# Actually send events for the function calls
205+
for function_call in state.function_calls.values():
206+
# First, a ResponseOutputItemAdded for the function call
207+
yield ResponseOutputItemAddedEvent(
208+
item=ResponseFunctionToolCall(
209+
id=FAKE_RESPONSES_ID,
210+
call_id=function_call.call_id,
211+
arguments=function_call.arguments,
212+
name=function_call.name,
213+
type="function_call",
214+
),
215+
output_index=function_call_starting_index,
216+
type="response.output_item.added",
217+
)
218+
# Then, yield the args
219+
yield ResponseFunctionCallArgumentsDeltaEvent(
220+
delta=function_call.arguments,
221+
item_id=FAKE_RESPONSES_ID,
222+
output_index=function_call_starting_index,
223+
type="response.function_call_arguments.delta",
224+
)
225+
# Finally, the ResponseOutputItemDone
226+
yield ResponseOutputItemDoneEvent(
227+
item=ResponseFunctionToolCall(
228+
id=FAKE_RESPONSES_ID,
229+
call_id=function_call.call_id,
230+
arguments=function_call.arguments,
231+
name=function_call.name,
232+
type="function_call",
233+
),
234+
output_index=function_call_starting_index,
235+
type="response.output_item.done",
236+
)
237+
238+
# Finally, send the Response completed event
239+
outputs: list[ResponseOutputItem] = []
240+
if state.text_content_index_and_output or state.refusal_content_index_and_output:
241+
assistant_msg = ResponseOutputMessage(
242+
id=FAKE_RESPONSES_ID,
243+
content=[],
244+
role="assistant",
245+
type="message",
246+
status="completed",
247+
)
248+
if state.text_content_index_and_output:
249+
assistant_msg.content.append(state.text_content_index_and_output[1])
250+
if state.refusal_content_index_and_output:
251+
assistant_msg.content.append(state.refusal_content_index_and_output[1])
252+
outputs.append(assistant_msg)
253+
254+
# send a ResponseOutputItemDone for the assistant message
255+
yield ResponseOutputItemDoneEvent(
256+
item=assistant_msg,
257+
output_index=0,
258+
type="response.output_item.done",
259+
)
260+
261+
for function_call in state.function_calls.values():
262+
outputs.append(function_call)
263+
264+
final_response = response.model_copy()
265+
final_response.output = outputs
266+
final_response.usage = (
267+
ResponseUsage(
268+
input_tokens=usage.prompt_tokens,
269+
output_tokens=usage.completion_tokens,
270+
total_tokens=usage.total_tokens,
271+
output_tokens_details=OutputTokensDetails(
272+
reasoning_tokens=usage.completion_tokens_details.reasoning_tokens
273+
if usage.completion_tokens_details
274+
and usage.completion_tokens_details.reasoning_tokens
275+
else 0
276+
),
277+
input_tokens_details=InputTokensDetails(
278+
cached_tokens=usage.prompt_tokens_details.cached_tokens
279+
if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens
280+
else 0
281+
),
282+
)
283+
if usage
284+
else None
285+
)
286+
287+
yield ResponseCompletedEvent(
288+
response=final_response,
289+
type="response.completed",
290+
)

0 commit comments

Comments
 (0)