Skip to content

Commit 9ca00a8

Browse files
committed
fix: addressing edge cases when resuming (continued)
1 parent f4cd704 commit 9ca00a8

File tree

3 files changed

+227
-34
lines changed

3 files changed

+227
-34
lines changed

src/agents/run.py

Lines changed: 102 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,13 @@ def prepare_input(
161161

162162
# On first call (when there are no generated items yet), include the original input
163163
if not generated_items:
164-
input_items.extend(ItemHelpers.input_to_new_input_list(original_input))
164+
# Normalize original_input items to ensure field names are in snake_case
165+
# (items from RunState deserialization may have camelCase)
166+
raw_input_list = ItemHelpers.input_to_new_input_list(original_input)
167+
# Filter out function_call items that don't have corresponding function_call_output
168+
# (API requires every function_call to have a function_call_output)
169+
filtered_input_list = AgentRunner._filter_incomplete_function_calls(raw_input_list)
170+
input_items.extend(AgentRunner._normalize_input_items(filtered_input_list))
165171

166172
# First, collect call_ids from tool_call_output_item items
167173
# (completed tool calls with outputs) and build a map of
@@ -737,8 +743,8 @@ async def run(
737743
original_user_input = run_state._original_input
738744
# Normalize items to remove top-level providerData (API doesn't accept it there)
739745
if isinstance(original_user_input, list):
740-
prepared_input: str | list[TResponseInputItem] = (
741-
AgentRunner._normalize_input_items(original_user_input)
746+
prepared_input: str | list[TResponseInputItem] = AgentRunner._normalize_input_items(
747+
original_user_input
742748
)
743749
else:
744750
prepared_input = original_user_input
@@ -833,8 +839,7 @@ async def run(
833839
if session is not None and generated_items:
834840
# Save tool_call_output_item items (the outputs)
835841
tool_output_items: list[RunItem] = [
836-
item for item in generated_items
837-
if item.type == "tool_call_output_item"
842+
item for item in generated_items if item.type == "tool_call_output_item"
838843
]
839844
# Also find and save the corresponding function_call items
840845
# (they might not be in session if the run was interrupted before saving)
@@ -995,7 +1000,7 @@ async def run(
9951000
)
9961001
if call_id in output_call_ids and item not in items_to_save:
9971002
items_to_save.append(item)
998-
1003+
9991004
# Don't save original_user_input again - it was already saved at the start
10001005
await self._save_result_to_session(session, [], items_to_save)
10011006

@@ -1369,9 +1374,12 @@ async def _start_streaming(
13691374
# state's input, causing duplicate items.
13701375
if run_state is not None:
13711376
# Resuming from state - normalize items to remove top-level providerData
1377+
# and filter incomplete function_call pairs
13721378
if isinstance(starting_input, list):
1379+
# Filter incomplete function_call pairs before normalizing
1380+
filtered = AgentRunner._filter_incomplete_function_calls(starting_input)
13731381
prepared_input: str | list[TResponseInputItem] = (
1374-
AgentRunner._normalize_input_items(starting_input)
1382+
AgentRunner._normalize_input_items(filtered)
13751383
)
13761384
else:
13771385
prepared_input = starting_input
@@ -2345,20 +2353,82 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
23452353

23462354
return run_config.model_provider.get_model(agent.model)
23472355

2356+
@staticmethod
2357+
def _filter_incomplete_function_calls(
2358+
items: list[TResponseInputItem],
2359+
) -> list[TResponseInputItem]:
2360+
"""Filter out function_call items that don't have corresponding function_call_output.
2361+
2362+
The OpenAI API requires every function_call in an assistant message to have a
2363+
corresponding function_call_output (tool message). This function ensures only
2364+
complete pairs are included to prevent API errors.
2365+
2366+
IMPORTANT: This only filters incomplete function_call items. All other items
2367+
(messages, complete function_call pairs, etc.) are preserved to maintain
2368+
conversation history integrity.
2369+
2370+
Args:
2371+
items: List of input items to filter
2372+
2373+
Returns:
2374+
Filtered list with only complete function_call pairs. All non-function_call
2375+
items and complete function_call pairs are preserved.
2376+
"""
2377+
# First pass: collect call_ids from function_call_output/function_call_result items
2378+
completed_call_ids: set[str] = set()
2379+
for item in items:
2380+
if isinstance(item, dict):
2381+
item_type = item.get("type")
2382+
# Handle both API format (function_call_output) and
2383+
# protocol format (function_call_result)
2384+
if item_type in ("function_call_output", "function_call_result"):
2385+
call_id = item.get("call_id") or item.get("callId")
2386+
if call_id and isinstance(call_id, str):
2387+
completed_call_ids.add(call_id)
2388+
2389+
# Second pass: only include function_call items that have corresponding outputs
2390+
filtered: list[TResponseInputItem] = []
2391+
for item in items:
2392+
if isinstance(item, dict):
2393+
item_type = item.get("type")
2394+
if item_type == "function_call":
2395+
call_id = item.get("call_id") or item.get("callId")
2396+
# Only include if there's a corresponding
2397+
# function_call_output/function_call_result
2398+
if call_id and call_id in completed_call_ids:
2399+
filtered.append(item)
2400+
else:
2401+
# Include all non-function_call items
2402+
filtered.append(item)
2403+
else:
2404+
# Include non-dict items as-is
2405+
filtered.append(item)
2406+
2407+
return filtered
2408+
23482409
@staticmethod
23492410
def _normalize_input_items(items: list[TResponseInputItem]) -> list[TResponseInputItem]:
2350-
"""Normalize input items by removing top-level providerData/provider_data.
2351-
2411+
"""Normalize input items by removing top-level providerData/provider_data
2412+
and normalizing field names (callId -> call_id).
2413+
23522414
The OpenAI API doesn't accept providerData at the top level of input items.
23532415
providerData should only be in content where it belongs. This function removes
23542416
top-level providerData while preserving it in content.
2355-
2417+
2418+
Also normalizes field names from camelCase (callId) to snake_case (call_id)
2419+
to match API expectations.
2420+
2421+
Normalizes item types: converts 'function_call_result' to 'function_call_output'
2422+
to match API expectations.
2423+
23562424
Args:
23572425
items: List of input items to normalize
2358-
2426+
23592427
Returns:
23602428
Normalized list of input items
23612429
"""
2430+
from .run_state import _normalize_field_names
2431+
23622432
normalized: list[TResponseInputItem] = []
23632433
for item in items:
23642434
if isinstance(item, dict):
@@ -2368,6 +2438,18 @@ def _normalize_input_items(items: list[TResponseInputItem]) -> list[TResponseInp
23682438
# The API doesn't accept providerData at the top level of input items
23692439
normalized_item.pop("providerData", None)
23702440
normalized_item.pop("provider_data", None)
2441+
# Normalize item type: API expects 'function_call_output',
2442+
# not 'function_call_result'
2443+
item_type = normalized_item.get("type")
2444+
if item_type == "function_call_result":
2445+
normalized_item["type"] = "function_call_output"
2446+
item_type = "function_call_output"
2447+
# Remove invalid fields based on item type
2448+
# function_call_output items should not have 'name' field
2449+
if item_type == "function_call_output":
2450+
normalized_item.pop("name", None)
2451+
# Normalize field names (callId -> call_id, responseId -> response_id)
2452+
normalized_item = _normalize_field_names(normalized_item)
23712453
normalized.append(cast(TResponseInputItem, normalized_item))
23722454
else:
23732455
# For non-dict items, keep as-is (they should already be in correct format)
@@ -2414,10 +2496,14 @@ async def _prepare_input_with_session(
24142496
f"Invalid `session_input_callback` value: {session_input_callback}. "
24152497
"Choose between `None` or a custom callable function."
24162498
)
2417-
2499+
2500+
# Filter incomplete function_call pairs before normalizing
2501+
# (API requires every function_call to have a function_call_output)
2502+
filtered = cls._filter_incomplete_function_calls(merged)
2503+
24182504
# Normalize items to remove top-level providerData and deduplicate by ID
2419-
normalized = cls._normalize_input_items(merged)
2420-
2505+
normalized = cls._normalize_input_items(filtered)
2506+
24212507
# Deduplicate items by ID to prevent sending duplicate items to the API
24222508
# This can happen when resuming from state and items are already in the session
24232509
seen_ids: set[str] = set()
@@ -2429,13 +2515,13 @@ async def _prepare_input_with_session(
24292515
item_id = cast(str | None, item.get("id"))
24302516
elif hasattr(item, "id"):
24312517
item_id = cast(str | None, getattr(item, "id", None))
2432-
2518+
24332519
# Only add items we haven't seen before (or items without IDs)
24342520
if item_id is None or item_id not in seen_ids:
24352521
deduplicated.append(item)
24362522
if item_id:
24372523
seen_ids.add(item_id)
2438-
2524+
24392525
return deduplicated
24402526

24412527
@classmethod

src/agents/run_state.py

Lines changed: 81 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,6 @@ class RunState(Generic[TContext, TAgent]):
4848
_current_turn: int = 0
4949
"""Current turn number in the conversation."""
5050

51-
_current_turn_persisted_item_count: int = 0
52-
"""Tracks how many generated run items from this turn were already persisted to session.
53-
54-
When saving to session, we slice off only new entries. When a turn is interrupted
55-
(e.g., awaiting tool approval) and later resumed, we rewind this counter before
56-
continuing so pending tool outputs still get stored.
57-
"""
58-
5951
_current_agent: TAgent | None = None
6052
"""The agent currently handling the conversation."""
6153

@@ -250,13 +242,63 @@ def to_json(self) -> dict[str, Any]:
250242
}
251243
model_responses.append(response_dict)
252244

245+
# Normalize and camelize originalInput if it's a list of items
246+
# Convert API format to protocol format to match TypeScript schema
247+
# Protocol expects function_call_result (not function_call_output)
248+
original_input_serialized = self._original_input
249+
if isinstance(original_input_serialized, list):
250+
# First pass: build a map of call_id -> function_call name
251+
# to help convert function_call_output to function_call_result
252+
call_id_to_name: dict[str, str] = {}
253+
for item in original_input_serialized:
254+
if isinstance(item, dict):
255+
item_type = item.get("type")
256+
call_id = item.get("call_id") or item.get("callId")
257+
name = item.get("name")
258+
if item_type == "function_call" and call_id and name:
259+
call_id_to_name[call_id] = name
260+
261+
normalized_items = []
262+
for item in original_input_serialized:
263+
if isinstance(item, dict):
264+
# Create a copy to avoid modifying the original
265+
normalized_item = dict(item)
266+
# Remove session/conversation metadata fields that shouldn't be in originalInput
267+
# These are not part of the input protocol schema
268+
normalized_item.pop("id", None)
269+
normalized_item.pop("created_at", None)
270+
# Remove top-level providerData/provider_data (protocol allows it but
271+
# we remove it for cleaner serialization)
272+
normalized_item.pop("providerData", None)
273+
normalized_item.pop("provider_data", None)
274+
# Convert API format to protocol format
275+
# API uses function_call_output, protocol uses function_call_result
276+
item_type = normalized_item.get("type")
277+
call_id = normalized_item.get("call_id") or normalized_item.get("callId")
278+
if item_type == "function_call_output":
279+
# Convert to protocol format: function_call_result
280+
normalized_item["type"] = "function_call_result"
281+
# Protocol format requires status field (default to 'completed')
282+
if "status" not in normalized_item:
283+
normalized_item["status"] = "completed"
284+
# Protocol format requires name field
285+
# Look it up from the corresponding function_call if missing
286+
if "name" not in normalized_item and call_id:
287+
normalized_item["name"] = call_id_to_name.get(call_id, "")
288+
# Normalize field names to camelCase for JSON (call_id -> callId)
289+
normalized_item = self._camelize_field_names(normalized_item)
290+
normalized_items.append(normalized_item)
291+
else:
292+
normalized_items.append(item)
293+
original_input_serialized = normalized_items
294+
253295
result = {
254296
"$schemaVersion": CURRENT_SCHEMA_VERSION,
255297
"currentTurn": self._current_turn,
256298
"currentAgent": {
257299
"name": self._current_agent.name,
258300
},
259-
"originalInput": self._original_input,
301+
"originalInput": original_input_serialized,
260302
"modelResponses": model_responses,
261303
"context": {
262304
"usage": {
@@ -345,7 +387,6 @@ def to_json(self) -> dict[str, Any]:
345387
if self._last_processed_response
346388
else None
347389
)
348-
result["currentTurnPersistedItemCount"] = self._current_turn_persisted_item_count
349390
result["trace"] = None
350391

351392
return result
@@ -571,18 +612,29 @@ async def from_string(
571612
context.usage = usage
572613
context._rebuild_approvals(context_data.get("approvals", {}))
573614

615+
# Normalize originalInput to remove providerData fields that may have been
616+
# included by TypeScript serialization. These fields are metadata and should
617+
# not be sent to the API.
618+
original_input_raw = state_json["originalInput"]
619+
if isinstance(original_input_raw, list):
620+
# Normalize each item in the list to remove providerData fields
621+
normalized_original_input = [
622+
_normalize_field_names(item) if isinstance(item, dict) else item
623+
for item in original_input_raw
624+
]
625+
else:
626+
# If it's a string, use it as-is
627+
normalized_original_input = original_input_raw
628+
574629
# Create the RunState instance
575630
state = RunState(
576631
context=context,
577-
original_input=state_json["originalInput"],
632+
original_input=normalized_original_input,
578633
starting_agent=current_agent,
579634
max_turns=state_json["maxTurns"],
580635
)
581636

582637
state._current_turn = state_json["currentTurn"]
583-
state._current_turn_persisted_item_count = state_json.get(
584-
"currentTurnPersistedItemCount", 0
585-
)
586638

587639
# Reconstruct model responses
588640
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))
@@ -679,18 +731,29 @@ async def from_json(
679731
context.usage = usage
680732
context._rebuild_approvals(context_data.get("approvals", {}))
681733

734+
# Normalize originalInput to remove providerData fields that may have been
735+
# included by TypeScript serialization. These fields are metadata and should
736+
# not be sent to the API.
737+
original_input_raw = state_json["originalInput"]
738+
if isinstance(original_input_raw, list):
739+
# Normalize each item in the list to remove providerData fields
740+
normalized_original_input = [
741+
_normalize_field_names(item) if isinstance(item, dict) else item
742+
for item in original_input_raw
743+
]
744+
else:
745+
# If it's a string, use it as-is
746+
normalized_original_input = original_input_raw
747+
682748
# Create the RunState instance
683749
state = RunState(
684750
context=context,
685-
original_input=state_json["originalInput"],
751+
original_input=normalized_original_input,
686752
starting_agent=current_agent,
687753
max_turns=state_json["maxTurns"],
688754
)
689755

690756
state._current_turn = state_json["currentTurn"]
691-
state._current_turn_persisted_item_count = state_json.get(
692-
"currentTurnPersistedItemCount", 0
693-
)
694757

695758
# Reconstruct model responses
696759
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))

0 commit comments

Comments
 (0)