Skip to content

Commit 0a9ec7f

Browse files
feat: add stateful jupyter notebook option and make responses API more reliable (#199)
1. Adds the option to use a local Jupyter kernel as the source for the Python tool.
2. Increases browser reliability by handling the case where the model accidentally outputs a call to functions.browser.search instead of browser.search, as long as no conflicting function is defined.
3. Fixes some additional Responses API bugs, such as correctly assigning message IDs and no longer sending Python calls as reasoning_text.
1 parent 683982d commit 0a9ec7f

File tree

4 files changed

+345
-40
lines changed

4 files changed

+345
-40
lines changed

gpt_oss/responses_api/api_server.py

Lines changed: 121 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def generate_response(
131131
dict[str, list[CodeInterpreterOutputLogs | CodeInterpreterOutputImage]]
132132
] = None,
133133
reasoning_ids: Optional[list[str]] = None,
134+
message_ids: Optional[list[str]] = None,
134135
treat_functions_python_as_builtin: bool = False,
135136
) -> ResponseObject:
136137
output = []
@@ -157,6 +158,7 @@ def generate_response(
157158
browser_tool_index = 0
158159
python_tool_index = 0
159160
reasoning_ids_iter = iter(reasoning_ids or [])
161+
message_ids_iter = iter(message_ids or [])
160162

161163
for entry in entries:
162164
entry_dict = entry.to_dict()
@@ -296,15 +298,22 @@ def generate_response(
296298
)
297299
)
298300

301+
message_id = next(message_ids_iter, None)
299302
output.append(
300303
Item(
304+
id=message_id,
301305
type="message",
302306
role="assistant",
303307
content=content,
304308
status="completed",
305309
)
306310
)
307311
elif entry_dict["channel"] == "analysis":
312+
if entry_dict.get("recipient"):
313+
continue
314+
author_dict = entry_dict.get("author") or {}
315+
if author_dict.get("role") and author_dict.get("role") != "assistant":
316+
continue
308317
summary = []
309318
content = [
310319
ReasoningTextContentItem(
@@ -374,6 +383,7 @@ def generate_response(
374383
)
375384

376385
class StreamResponsesEvents:
386+
BROWSER_RESERVED_FUNCTIONS = {"browser.search", "browser.open", "browser.find"}
377387
initial_tokens: list[int]
378388
tokens: list[int]
379389
output_tokens: list[int]
@@ -429,7 +439,48 @@ def __init__(
429439
] = {}
430440
self.reasoning_item_ids: list[str] = []
431441
self.current_reasoning_item_id: Optional[str] = None
442+
self.message_item_ids: list[str] = []
443+
self.current_message_item_id: Optional[str] = None
432444
self.functions_python_as_builtin = functions_python_as_builtin
445+
self.user_defined_function_names = {
446+
name
447+
for tool in (request_body.tools or [])
448+
for name in [getattr(tool, "name", None)]
449+
if getattr(tool, "type", None) == "function" and name
450+
}
451+
452+
# Decide whether `recipient` should be handled as a browser tool call.
# Returns a 2-tuple (resolved_recipient, used_fallback):
#   - (None, False) when the recipient is not treated as a browser call;
#   - (recipient, False) when it already carries the "browser." prefix;
#   - (name_without_functions_prefix, True) when a "functions."-prefixed
#     recipient was remapped to a reserved browser function name.
def _resolve_browser_recipient(
453+
self, recipient: Optional[str]
454+
) -> tuple[Optional[str], bool]:
# Nothing to resolve when the browser tool is off or the recipient is
# missing/empty.
455+
if not self.use_browser_tool or not recipient:
456+
return (None, False)
457+
458+
# Any "browser."-prefixed recipient is passed through unchanged; it is not
# checked against BROWSER_RESERVED_FUNCTIONS here.
if recipient.startswith("browser."):
459+
return (recipient, False)
460+
461+
# Fallback: strip the "functions." prefix and accept the remainder only if
# it is a reserved browser name that is not also a user-defined function
# name, so a caller-declared function of the same name is never shadowed.
if recipient.startswith("functions."):
462+
potential = recipient[len("functions.") :]
463+
if (
464+
potential in self.BROWSER_RESERVED_FUNCTIONS
465+
and potential not in self.user_defined_function_names
466+
):
467+
return (potential, True)
468+
469+
# Every other recipient is not a browser call.
return (None, False)
470+
471+
# Return the id of the message item currently being tracked, creating one
# on first use. A new id has the form "item_<uuid4 hex>", is stored as
# self.current_message_item_id, and is appended to self.message_item_ids.
# While current_message_item_id is set, repeated calls return the same id.
def _ensure_message_item_id(self) -> str:
472+
if self.current_message_item_id is None:
473+
message_id = f"item_{uuid.uuid4().hex}"
474+
self.current_message_item_id = message_id
475+
self.message_item_ids.append(message_id)
476+
return self.current_message_item_id
477+
478+
# Return the id of the reasoning item currently being tracked, creating one
# on first use. A new id has the form "rs_<uuid4 hex>", is stored as
# self.current_reasoning_item_id, and is appended to self.reasoning_item_ids.
# While current_reasoning_item_id is set, repeated calls return the same id.
def _ensure_reasoning_item_id(self) -> str:
479+
if self.current_reasoning_item_id is None:
480+
reasoning_id = f"rs_{uuid.uuid4().hex}"
481+
self.current_reasoning_item_id = reasoning_id
482+
self.reasoning_item_ids.append(reasoning_id)
483+
return self.current_reasoning_item_id
433484

434485
def _send_event(self, event: ResponseEvent):
435486
event.sequence_number = self.sequence_number
@@ -455,6 +506,7 @@ async def run(self):
455506
python_call_ids=self.python_call_ids,
456507
python_call_outputs=getattr(self, "python_call_outputs", None),
457508
reasoning_ids=self.reasoning_item_ids,
509+
message_ids=self.message_item_ids,
458510
treat_functions_python_as_builtin=self.functions_python_as_builtin,
459511
)
460512
initial_response.status = "in_progress"
@@ -508,8 +560,11 @@ async def run(self):
508560
previous_item = self.parser.messages[-1]
509561
if previous_item.recipient is not None:
510562
recipient = previous_item.recipient
563+
browser_recipient, _ = self._resolve_browser_recipient(
564+
recipient
565+
)
511566
if (
512-
not recipient.startswith("browser.")
567+
browser_recipient is None
513568
and not (
514569
recipient == "python"
515570
or (
@@ -542,28 +597,34 @@ async def run(self):
542597
),
543598
)
544599
)
545-
if previous_item.channel == "analysis":
546-
reasoning_id = self.current_reasoning_item_id
547-
if reasoning_id is None:
548-
reasoning_id = f"rs_{uuid.uuid4().hex}"
549-
self.reasoning_item_ids.append(reasoning_id)
550-
self.current_reasoning_item_id = reasoning_id
600+
if (
601+
previous_item.channel == "analysis"
602+
and previous_item.recipient is None
603+
):
604+
reasoning_id = (
605+
self.current_reasoning_item_id
606+
if self.current_reasoning_item_id is not None
607+
else self._ensure_reasoning_item_id()
608+
)
609+
reasoning_text = previous_item.content[0].text
551610
yield self._send_event(
552611
ResponseReasoningTextDone(
553612
type="response.reasoning_text.done",
554613
output_index=current_output_index,
555614
content_index=current_content_index,
556-
text=previous_item.content[0].text,
615+
item_id=reasoning_id,
616+
text=reasoning_text,
557617
)
558618
)
559619
yield self._send_event(
560620
ResponseContentPartDone(
561621
type="response.content_part.done",
562622
output_index=current_output_index,
563623
content_index=current_content_index,
624+
item_id=reasoning_id,
564625
part=ReasoningTextContentItem(
565626
type="reasoning_text",
566-
text=previous_item.content[0].text,
627+
text=reasoning_text,
567628
),
568629
)
569630
)
@@ -578,7 +639,7 @@ async def run(self):
578639
content=[
579640
ReasoningTextContentItem(
580641
type="reasoning_text",
581-
text=previous_item.content[0].text,
642+
text=reasoning_text,
582643
)
583644
],
584645
),
@@ -605,11 +666,17 @@ async def run(self):
605666
text=normalized_text,
606667
annotations=annotations,
607668
)
669+
message_id = (
670+
self.current_message_item_id
671+
if self.current_message_item_id is not None
672+
else self._ensure_message_item_id()
673+
)
608674
yield self._send_event(
609675
ResponseOutputTextDone(
610676
type="response.output_text.done",
611677
output_index=current_output_index,
612678
content_index=current_content_index,
679+
item_id=message_id,
613680
text=normalized_text,
614681
)
615682
)
@@ -618,6 +685,7 @@ async def run(self):
618685
type="response.content_part.done",
619686
output_index=current_output_index,
620687
content_index=current_content_index,
688+
item_id=message_id,
621689
part=text_content,
622690
)
623691
)
@@ -626,6 +694,7 @@ async def run(self):
626694
type="response.output_item.done",
627695
output_index=current_output_index,
628696
item=Item(
697+
id=message_id,
629698
type="message",
630699
role="assistant",
631700
content=[text_content],
@@ -634,6 +703,7 @@ async def run(self):
634703
)
635704
current_annotations = []
636705
current_output_text_content = ""
706+
self.current_message_item_id = None
637707

638708
if (
639709
self.parser.last_content_delta
@@ -642,18 +712,25 @@ async def run(self):
642712
):
643713
if not sent_output_item_added:
644714
sent_output_item_added = True
715+
message_id = self._ensure_message_item_id()
645716
yield self._send_event(
646717
ResponseOutputItemAdded(
647718
type="response.output_item.added",
648719
output_index=current_output_index,
649-
item=Item(type="message", role="assistant", content=[]),
720+
item=Item(
721+
id=message_id,
722+
type="message",
723+
role="assistant",
724+
content=[],
725+
),
650726
)
651727
)
652728
yield self._send_event(
653729
ResponseContentPartAdded(
654730
type="response.content_part.added",
655731
output_index=current_output_index,
656732
content_index=current_content_index,
733+
item_id=message_id,
657734
part=TextContentItem(type="output_text", text=""),
658735
)
659736
)
@@ -685,11 +762,13 @@ async def run(self):
685762
for a in new_annotations:
686763
current_annotations.append(a)
687764
citation = UrlCitation(**a)
765+
message_id = self._ensure_message_item_id()
688766
yield self._send_event(
689767
ResponseOutputTextAnnotationAdded(
690768
type="response.output_text.annotation.added",
691769
output_index=current_output_index,
692770
content_index=current_content_index,
771+
item_id=message_id,
693772
annotation_index=len(current_annotations),
694773
annotation=citation,
695774
)
@@ -699,11 +778,13 @@ async def run(self):
699778
should_send_output_text_delta = False
700779

701780
if should_send_output_text_delta:
781+
message_id = self._ensure_message_item_id()
702782
yield self._send_event(
703783
ResponseOutputTextDelta(
704784
type="response.output_text.delta",
705785
output_index=current_output_index,
706786
content_index=current_content_index,
787+
item_id=message_id,
707788
delta=output_delta_buffer,
708789
)
709790
)
@@ -717,9 +798,7 @@ async def run(self):
717798
):
718799
if not sent_output_item_added:
719800
sent_output_item_added = True
720-
reasoning_id = f"rs_{uuid.uuid4().hex}"
721-
self.current_reasoning_item_id = reasoning_id
722-
self.reasoning_item_ids.append(reasoning_id)
801+
reasoning_id = self._ensure_reasoning_item_id()
723802
yield self._send_event(
724803
ResponseOutputItemAdded(
725804
type="response.output_item.added",
@@ -737,16 +816,19 @@ async def run(self):
737816
type="response.content_part.added",
738817
output_index=current_output_index,
739818
content_index=current_content_index,
819+
item_id=reasoning_id,
740820
part=ReasoningTextContentItem(
741821
type="reasoning_text", text=""
742822
),
743823
)
744824
)
825+
reasoning_id = self._ensure_reasoning_item_id()
745826
yield self._send_event(
746827
ResponseReasoningTextDelta(
747828
type="response.reasoning_text.delta",
748829
output_index=current_output_index,
749830
content_index=current_content_index,
831+
item_id=reasoning_id,
750832
delta=self.parser.last_content_delta,
751833
)
752834
)
@@ -763,14 +845,20 @@ async def run(self):
763845
if next_tok in encoding.stop_tokens_for_assistant_actions():
764846
if len(self.parser.messages) > 0:
765847
last_message = self.parser.messages[-1]
766-
if (
767-
self.use_browser_tool
768-
and last_message.recipient is not None
769-
and last_message.recipient.startswith("browser.")
770-
):
771-
function_name = last_message.recipient[len("browser.") :]
848+
browser_recipient, is_browser_fallback = (
849+
self._resolve_browser_recipient(last_message.recipient)
850+
)
851+
if browser_recipient is not None and browser_tool is not None:
852+
message_for_browser = (
853+
last_message
854+
if not is_browser_fallback
855+
else last_message.with_recipient(browser_recipient)
856+
)
857+
function_name = browser_recipient[len("browser.") :]
772858
action = None
773-
parsed_args = browser_tool.process_arguments(last_message)
859+
parsed_args = browser_tool.process_arguments(
860+
message_for_browser
861+
)
774862
if function_name == "search":
775863
action = WebSearchActionSearch(
776864
type="search",
@@ -810,25 +898,27 @@ async def run(self):
810898
),
811899
)
812900
)
813-
yield self._send_event(
814-
ResponseWebSearchCallInProgress(
815-
type="response.web_search_call.in_progress",
816-
output_index=current_output_index,
817-
id=web_search_call_id,
818-
)
901+
yield self._send_event(
902+
ResponseWebSearchCallInProgress(
903+
type="response.web_search_call.in_progress",
904+
output_index=current_output_index,
905+
item_id=web_search_call_id,
819906
)
907+
)
820908

821909
async def run_tool():
822910
results = []
823-
async for msg in browser_tool.process(last_message):
911+
async for msg in browser_tool.process(
912+
message_for_browser
913+
):
824914
results.append(msg)
825915
return results
826916

827917
yield self._send_event(
828918
ResponseWebSearchCallSearching(
829919
type="response.web_search_call.searching",
830920
output_index=current_output_index,
831-
id=web_search_call_id,
921+
item_id=web_search_call_id,
832922
)
833923
)
834924
result = await run_tool()
@@ -852,7 +942,7 @@ async def run_tool():
852942
ResponseWebSearchCallCompleted(
853943
type="response.web_search_call.completed",
854944
output_index=current_output_index,
855-
id=web_search_call_id,
945+
item_id=web_search_call_id,
856946
)
857947
)
858948
yield self._send_event(
@@ -1030,6 +1120,7 @@ async def run_python_tool():
10301120
python_call_ids=self.python_call_ids,
10311121
python_call_outputs=self.python_call_outputs,
10321122
reasoning_ids=self.reasoning_item_ids,
1123+
message_ids=self.message_item_ids,
10331124
treat_functions_python_as_builtin=self.functions_python_as_builtin,
10341125
)
10351126
if self.store_callback and self.request_body.store:

gpt_oss/responses_api/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class ReasoningItem(BaseModel):
4343

4444

4545
class Item(BaseModel):
46+
id: Optional[str] = None
4647
type: Optional[Literal["message"]] = "message"
4748
role: Literal["user", "assistant", "system"]
4849
content: Union[list[TextContentItem], str]

0 commit comments

Comments
 (0)