Omit the `model` and `tools` request parameters when a prompt is supplied.

* get_model in openai_provider.py records whether the caller passed a model
  name and forwards that to OpenAIResponsesModel as `model_is_explicit`.
* OpenAIResponsesModel._fetch_response passes `omit` for `model` when a prompt
  is given and the model was not explicit, and passes `omit` for `tools` when
  a prompt is given and the converted tool payload is empty.
* New tests assert that the kwargs received by the stubbed
  `responses.create` contain `omit` for `model` and `tools` in those cases.

diff --git a/src/agents/models/openai_provider.py b/src/agents/models/openai_provider.py
index 91f2366bc..91eeaccc8 100644
--- a/src/agents/models/openai_provider.py
+++ b/src/agents/models/openai_provider.py
@@ -81,13 +81,17 @@ def _get_client(self) -> AsyncOpenAI:
         return self._client
 
     def get_model(self, model_name: str | None) -> Model:
-        if model_name is None:
-            model_name = get_default_model()
+        model_is_explicit = model_name is not None
+        resolved_model_name = model_name if model_name is not None else get_default_model()
 
         client = self._get_client()
 
         return (
-            OpenAIResponsesModel(model=model_name, openai_client=client)
+            OpenAIResponsesModel(
+                model=resolved_model_name,
+                openai_client=client,
+                model_is_explicit=model_is_explicit,
+            )
             if self._use_responses
-            else OpenAIChatCompletionsModel(model=model_name, openai_client=client)
+            else OpenAIChatCompletionsModel(model=resolved_model_name, openai_client=client)
         )
diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py
index 4588937cb..a8695c89c 100644
--- a/src/agents/models/openai_responses.py
+++ b/src/agents/models/openai_responses.py
@@ -67,8 +67,11 @@ def __init__(
         self,
         model: str | ChatModel,
         openai_client: AsyncOpenAI,
+        *,
+        model_is_explicit: bool = True,
     ) -> None:
         self.model = model
+        self._model_is_explicit = model_is_explicit
         self._client = openai_client
 
     def _non_null_or_omit(self, value: Any) -> Any:
@@ -262,6 +265,12 @@ async def _fetch_response(
         converted_tools = Converter.convert_tools(tools, handoffs)
         converted_tools_payload = _to_dump_compatible(converted_tools.tools)
         response_format = Converter.get_response_format(output_schema)
+        should_omit_model = prompt is not None and not self._model_is_explicit
+        model_param: str | ChatModel | Omit = self.model if not should_omit_model else omit
+        should_omit_tools = prompt is not None and len(converted_tools_payload) == 0
+        tools_param: list[ToolParam] | Omit = (
+            converted_tools_payload if not should_omit_tools else omit
+        )
 
         include_set: set[str] = set(converted_tools.includes)
         if model_settings.response_include is not None:
@@ -309,10 +318,10 @@ async def _fetch_response(
             previous_response_id=self._non_null_or_omit(previous_response_id),
             conversation=self._non_null_or_omit(conversation_id),
             instructions=self._non_null_or_omit(system_instructions),
-            model=self.model,
+            model=model_param,
             input=list_input,
             include=include,
-            tools=converted_tools_payload,
+            tools=tools_param,
             prompt=self._non_null_or_omit(prompt),
             temperature=self._non_null_or_omit(model_settings.temperature),
             top_p=self._non_null_or_omit(model_settings.top_p),
diff --git a/tests/test_agent_prompt.py b/tests/test_agent_prompt.py
index 3d5ed5a3f..e3ed40fbe 100644
--- a/tests/test_agent_prompt.py
+++ b/tests/test_agent_prompt.py
@@ -1,8 +1,13 @@
+from __future__ import annotations
+
 import pytest
+from openai import omit
 
-from agents import Agent, Prompt, RunContextWrapper, Runner
+from agents import Agent, Prompt, RunConfig, RunContextWrapper, Runner
+from agents.models.interface import Model, ModelProvider
+from agents.models.openai_responses import OpenAIResponsesModel
 
-from .fake_model import FakeModel
+from .fake_model import FakeModel, get_response_obj
 from .test_responses import get_text_message
 
 
@@ -97,3 +102,43 @@ async def test_prompt_is_passed_to_model():
         "variables": None,
     }
     assert model.last_prompt == expected_prompt
+
+
+class _SingleModelProvider(ModelProvider):
+    def __init__(self, model: Model):
+        self._model = model
+
+    def get_model(self, model_name: str | None) -> Model:
+        return self._model
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_agent_prompt_with_default_model_omits_model_and_tools_parameters():
+    called_kwargs: dict[str, object] = {}
+
+    class DummyResponses:
+        async def create(self, **kwargs):
+            nonlocal called_kwargs
+            called_kwargs = kwargs
+            return get_response_obj([get_text_message("done")])
+
+    class DummyResponsesClient:
+        def __init__(self):
+            self.responses = DummyResponses()
+
+    model = OpenAIResponsesModel(
+        model="gpt-4.1",
+        openai_client=DummyResponsesClient(),  # type: ignore[arg-type]
+        model_is_explicit=False,
+    )
+
+    run_config = RunConfig(model_provider=_SingleModelProvider(model))
+    agent = Agent(name="prompt-agent", prompt={"id": "pmpt_agent"})
+
+    await Runner.run(agent, input="hi", run_config=run_config)
+
+    expected_prompt = {"id": "pmpt_agent", "version": None, "variables": None}
+    assert called_kwargs["prompt"] == expected_prompt
+    assert called_kwargs["model"] is omit
+    assert called_kwargs["tools"] is omit
diff --git a/tests/test_openai_responses.py b/tests/test_openai_responses.py
index 0823d3cac..ecd509ac6 100644
--- a/tests/test_openai_responses.py
+++ b/tests/test_openai_responses.py
@@ -3,6 +3,7 @@
 from typing import Any
 
 import pytest
+from openai import omit
 from openai.types.responses import ResponseCompletedEvent
 
 from agents import ModelSettings, ModelTracing, __version__
@@ -63,3 +64,74 @@ def __init__(self):
 
     assert "extra_headers" in called_kwargs
     assert called_kwargs["extra_headers"]["User-Agent"] == expected_ua
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_prompt_id_omits_model_parameter():
+    called_kwargs: dict[str, Any] = {}
+
+    class DummyResponses:
+        async def create(self, **kwargs):
+            nonlocal called_kwargs
+            called_kwargs = kwargs
+            return get_response_obj([])
+
+    class DummyResponsesClient:
+        def __init__(self):
+            self.responses = DummyResponses()
+
+    model = OpenAIResponsesModel(
+        model="gpt-4",
+        openai_client=DummyResponsesClient(),  # type: ignore[arg-type]
+        model_is_explicit=False,
+    )
+
+    await model.get_response(
+        system_instructions=None,
+        input="hi",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        prompt={"id": "pmpt_123"},
+    )
+
+    assert called_kwargs["prompt"] == {"id": "pmpt_123"}
+    assert called_kwargs["model"] is omit
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_prompt_id_omits_tools_parameter_when_no_tools_configured():
+    called_kwargs: dict[str, Any] = {}
+
+    class DummyResponses:
+        async def create(self, **kwargs):
+            nonlocal called_kwargs
+            called_kwargs = kwargs
+            return get_response_obj([])
+
+    class DummyResponsesClient:
+        def __init__(self):
+            self.responses = DummyResponses()
+
+    model = OpenAIResponsesModel(
+        model="gpt-4",
+        openai_client=DummyResponsesClient(),  # type: ignore[arg-type]
+        model_is_explicit=False,
+    )
+
+    await model.get_response(
+        system_instructions=None,
+        input="hi",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        prompt={"id": "pmpt_123"},
+    )
+
+    assert called_kwargs["tools"] is omit