diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py
index 3af577d0a42f0..987b3be89b8fe 100644
--- a/libs/langchain_v1/langchain/agents/factory.py
+++ b/libs/langchain_v1/langchain/agents/factory.py
@@ -13,7 +13,6 @@
     get_type_hints,
 )
 
-from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
 from langchain_core.tools import BaseTool
 from langgraph._internal._runnable import RunnableCallable
@@ -42,6 +41,7 @@
     ResponseFormat,
     StructuredOutputValidationError,
     ToolStrategy,
+    _supports_provider_strategy,
 )
 from langchain.chat_models import init_chat_model
 from langchain.tools.tool_node import ToolCallWithContext, _ToolNode
@@ -49,6 +49,7 @@
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable, Sequence
 
+    from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.runnables import Runnable
     from langgraph.cache.base import BaseCache
     from langgraph.graph.state import CompiledStateGraph
@@ -347,29 +348,6 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
     return []
 
 
-def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
-    """Check if a model supports provider-specific structured output.
-
-    Args:
-        model: Model name string or `BaseChatModel` instance.
-
-    Returns:
-        `True` if the model supports provider-specific structured output, `False` otherwise.
-    """
-    model_name: str | None = None
-    if isinstance(model, str):
-        model_name = model
-    elif isinstance(model, BaseChatModel):
-        model_name = getattr(model, "model_name", None)
-
-    return (
-        "grok" in model_name.lower()
-        or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
-        if model_name
-        else False
-    )
-
-
 def _handle_structured_output_error(
     exception: Exception,
     response_format: ResponseFormat,
@@ -932,16 +910,34 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
 
     # Determine effective response format (auto-detect if needed)
     effective_response_format: ResponseFormat | None
+    model_name: str = cast(
+        "str",
+        (
+            request.model
+            if isinstance(request.model, str)
+            else getattr(request.model, "model_name", "")
+        ),
+    )
     if isinstance(request.response_format, AutoStrategy):
         # User provided raw schema via AutoStrategy - auto-detect best strategy based on model
-        if _supports_provider_strategy(request.model):
+        if _supports_provider_strategy(model_name):
             # Model supports provider strategy - use it
             effective_response_format = ProviderStrategy(schema=request.response_format.schema)
         else:
             # Model doesn't support provider strategy - use ToolStrategy
             effective_response_format = ToolStrategy(schema=request.response_format.schema)
+    elif isinstance(request.response_format, ProviderStrategy):
+        if not _supports_provider_strategy(model_name):
+            msg = (
+                f"Cannot use ProviderStrategy with {model_name}. "
+                "Supported models: OpenAI (gpt-5, gpt-4.1, gpt-oss, o3-pro, o3-mini), "
+                "X.AI (Grok). "
+                "Consider using a raw schema (which auto-selects the best strategy) or "
+                "explicitly using `ToolStrategy` for unsupported providers."
+            )
+            raise ValueError(msg)
+        effective_response_format = request.response_format
     else:
-        # User explicitly specified a strategy - preserve it
         effective_response_format = request.response_format
 
     # Build final tools list including structured output tools
@@ -957,12 +953,9 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
     if isinstance(effective_response_format, ProviderStrategy):
         # Use provider-specific structured output
        kwargs = effective_response_format.to_model_kwargs()
-        return (
-            request.model.bind_tools(
-                final_tools, strict=True, **kwargs, **request.model_settings
-            ),
-            effective_response_format,
-        )
+        return request.model.bind_tools(
+            final_tools, **kwargs, **request.model_settings
+        ), effective_response_format
 
     if isinstance(effective_response_format, ToolStrategy):
         # Current implementation requires that tools used for structured output
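With the helper moved into `structured_output.py`, strategy resolution in the factory is purely name-based: `_get_bound_model` derives `model_name` from the request (the string itself, or the instance's `model_name` attribute) and either auto-selects a strategy for a raw schema or validates an explicit `ProviderStrategy`. A minimal sketch of the resulting behavior, assuming the `create_agent` and `ProviderStrategy` APIs shown in this patch; the `Forecast` schema and model identifiers are illustrative:

```python
from pydantic import BaseModel

from langchain.agents import create_agent
from langchain.agents.structured_output import ProviderStrategy


class Forecast(BaseModel):
    temperature: float
    condition: str


# Raw schema: auto-selects ProviderStrategy for a supported name (gpt-4.1)
# and silently falls back to ToolStrategy for anything else.
auto_agent = create_agent(model="openai:gpt-4.1", tools=[], response_format=Forecast)

# Explicit ProviderStrategy on an unsupported model now fails with a
# ValueError when the model is bound, i.e. on the first invoke.
strict_agent = create_agent(
    model="anthropic:claude-3-5-sonnet-latest",
    tools=[],
    response_format=ProviderStrategy(Forecast),
)
```

Because the check runs inside `_get_bound_model`, it fires at bind time rather than at `create_agent` time, which is what keeps the middleware model-swap scenario in the last test below working.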
diff --git a/libs/langchain_v1/langchain/agents/structured_output.py b/libs/langchain_v1/langchain/agents/structured_output.py
index cd6a2fd9aed31..02d5f44a84d8f 100644
--- a/libs/langchain_v1/langchain/agents/structured_output.py
+++ b/libs/langchain_v1/langchain/agents/structured_output.py
@@ -31,6 +31,23 @@
 SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"]
 
 
+def _supports_provider_strategy(model_name: str) -> bool:
+    """Check if a model supports provider-specific structured output.
+
+    Args:
+        model_name: Model name string.
+
+    Returns:
+        `True` if the model supports provider-specific structured output, `False` otherwise.
+    """
+    return (
+        "grok" in model_name.lower()
+        or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
+        if model_name
+        else False
+    )
+
+
 class StructuredOutputError(Exception):
     """Base class for structured output errors."""
 
@@ -238,7 +255,56 @@ def _iter_variants(schema: Any) -> Iterable[Any]:
 
 @dataclass(init=False)
 class ProviderStrategy(Generic[SchemaT]):
-    """Use the model provider's native structured output method."""
+    """Use the model provider's native structured output method.
+
+    `ProviderStrategy` uses provider-specific structured output APIs that enforce
+    JSON schema validation at the model level. This provides stronger guarantees
+    than tool-based approaches but is only supported by certain providers.
+
+    Supported Providers:
+        - **OpenAI**: models matching `gpt-5`, `gpt-4.1`, `gpt-oss`, `o3-pro`, or `o3-mini`
+        - **X.AI (Grok)**: all Grok models (both providers are bound with `strict=True`)
+
+    Important:
+        When using `ProviderStrategy`, the agent will validate at runtime that the
+        model provider is supported. If you're using an unsupported provider, consider:
+
+        - Using a **raw schema** (recommended): Automatically selects the best strategy
+          based on model capabilities
+        - Using **`ToolStrategy`**: Explicitly use tool-based structured output for any
+          provider
+
+    Example:
+        ```python
+        from langchain.agents import create_agent
+        from langchain.agents.structured_output import ProviderStrategy
+        from pydantic import BaseModel
+
+
+        class WeatherResponse(BaseModel):
+            temperature: float
+            condition: str
+
+
+        # Explicitly use provider strategy (only for OpenAI/Grok)
+        agent = create_agent(
+            model="openai:gpt-4.1", tools=[], response_format=ProviderStrategy(WeatherResponse)
+        )
+
+        # Or use raw schema for automatic strategy selection (recommended)
+        # This will auto-select ProviderStrategy for OpenAI/Grok, ToolStrategy for others
+        agent = create_agent(
+            model="openai:gpt-4.1",
+            tools=[],
+            response_format=WeatherResponse,  # Auto-selects best strategy
+        )
+        ```
+
+    Note:
+        `ProviderStrategy` can be used with middleware that changes the model at runtime.
+        Validation occurs after the model is resolved, allowing dynamic model selection
+        while ensuring provider compatibility.
+    """
 
     schema: type[SchemaT]
     """Schema for native mode."""
@@ -255,9 +321,19 @@ def __init__(
         self.schema_spec = _SchemaSpec(schema)
 
     def to_model_kwargs(self) -> dict[str, Any]:
-        """Convert to kwargs to bind to a model to force structured output."""
-        # OpenAI:
-        # - see https://platform.openai.com/docs/guides/structured-outputs
+        """Convert to kwargs to bind to a model to force structured output.
+
+        The returned kwargs are passed to `bind_tools` so the provider receives
+        a native `response_format` payload together with `strict=True`.
+
+        Returns:
+            Model kwargs with `response_format` and `strict=True`.
+        """
+        # Provider-specific structured output:
+        # - OpenAI: https://platform.openai.com/docs/guides/structured-outputs
+        #   - Uses strict=True for schema validation
+        # - X.AI (Grok): https://docs.x.ai/docs/guides/structured-outputs
+        #   - Uses strict=True for schema validation (required)
         response_format = {
             "type": "json_schema",
             "json_schema": {
@@ -265,7 +341,8 @@ def to_model_kwargs(self) -> dict[str, Any]:
                 "schema": self.schema_spec.json_schema,
             },
         }
-        return {"response_format": response_format}
+
+        return {"response_format": response_format, "strict": True}
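Since `_supports_provider_strategy` is a plain substring check, the allow-list semantics deserve a second look: `gpt-4.1-mini` qualifies via the `gpt-4.1` fragment, a bare `gpt-4` matches nothing, and only the Grok check is case-insensitive. A self-contained sketch of the helper and the new `strict` kwarg (the private-helper import is for illustration only):

```python
from pydantic import BaseModel

from langchain.agents.structured_output import (
    ProviderStrategy,
    _supports_provider_strategy,
)

# Substring matching: longer names that contain a listed fragment qualify.
assert _supports_provider_strategy("gpt-4.1-mini")
assert _supports_provider_strategy("grok-beta")  # "grok" is case-insensitive
assert not _supports_provider_strategy("gpt-4")  # not in the allow-list
assert not _supports_provider_strategy("")  # empty name short-circuits to False


class Forecast(BaseModel):
    temperature: float
    condition: str


# to_model_kwargs() now always pairs the response_format payload with strict=True.
kwargs = ProviderStrategy(Forecast).to_model_kwargs()
assert kwargs["strict"] is True
assert kwargs["response_format"]["type"] == "json_schema"
```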
diff --git a/libs/langchain_v1/tests/unit_tests/agents/model.py b/libs/langchain_v1/tests/unit_tests/agents/model.py
index 07ed23995eb26..c56a81bcb38cb 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/model.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/model.py
@@ -11,6 +11,7 @@
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import BaseChatModel, LanguageModelInput
+from langchain_core.language_models.base import LangSmithParams
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
@@ -29,6 +30,7 @@ class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
     structured_response: StructuredResponseT | None = None
     index: int = 0
     tool_style: Literal["openai", "anthropic"] = "openai"
+    model_name: str = "fake-model"
 
     def _generate(
         self,
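The `model_name` field is what lets the fake participate in the name-based check, since the factory reads it with `getattr(request.model, "model_name", "")`. A short sketch of the two paths; the import path is assumed from the test layout:

```python
from tests.unit_tests.agents.model import FakeToolCallingModel  # path assumed

# The default keeps existing tests on the ToolStrategy fallback path...
fake = FakeToolCallingModel(tool_calls=[])
assert fake.model_name == "fake-model"

# ...while structured-output tests opt in to a provider-capable name.
openai_like = FakeToolCallingModel(tool_calls=[], model_name="gpt-4.1")
```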
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
index a7963ced16f57..0f8fa6d0122df 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
@@ -619,7 +619,81 @@ def test_pydantic_model(self) -> None:
         ]
 
         model = FakeToolCallingModel[WeatherBaseModel](
-            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
+            tool_calls=tool_calls,
+            structured_response=EXPECTED_WEATHER_PYDANTIC,
+            model_name="gpt-4.1",
+        )
+
+        agent = create_agent(
+            model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
+        )
+        response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+        assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+        assert len(response["messages"]) == 4
+
+    def test_unsupported_model_raises_error(self) -> None:
+        """Test that ProviderStrategy raises ValueError for unsupported models."""
+        tool_calls = [
+            [{"args": {}, "id": "1", "name": "get_weather"}],
+        ]
+
+        # Use a model name that doesn't support provider strategy
+        model = FakeToolCallingModel[WeatherBaseModel](
+            tool_calls=tool_calls,
+            structured_response=EXPECTED_WEATHER_PYDANTIC,
+            model_name="claude-3-5-sonnet",
+        )
+
+        agent = create_agent(
+            model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=(
+                r"Cannot use ProviderStrategy with claude-3-5-sonnet\. "
+                r"Supported models: OpenAI \(gpt-5, gpt-4\.1, gpt-oss, o3-pro, o3-mini\), "
+                r"X\.AI \(Grok\)\. "
+                r"Consider using a raw schema \(which auto-selects the best strategy\) or "
+                r"explicitly using `ToolStrategy` for unsupported providers\."
+            ),
+        ):
+            agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+    def test_supported_openai_models(self) -> None:
+        """Test that ProviderStrategy works with all supported OpenAI model variants."""
+        supported_models = ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"]
+
+        for model_name in supported_models:
+            tool_calls = [
+                [{"args": {}, "id": "1", "name": "get_weather"}],
+            ]
+
+            model = FakeToolCallingModel[WeatherBaseModel](
+                tool_calls=tool_calls,
+                structured_response=EXPECTED_WEATHER_PYDANTIC,
+                model_name=model_name,
+            )
+
+            agent = create_agent(
+                model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
+            )
+            response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+            assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+            assert len(response["messages"]) == 4
+
+    def test_supported_grok_model(self) -> None:
+        """Test that ProviderStrategy works with Grok models."""
+        tool_calls = [
+            [{"args": {}, "id": "1", "name": "get_weather"}],
+        ]
+
+        model = FakeToolCallingModel[WeatherBaseModel](
+            tool_calls=tool_calls,
+            structured_response=EXPECTED_WEATHER_PYDANTIC,
+            model_name="grok-beta",
         )
 
         agent = create_agent(
@@ -637,7 +711,9 @@ def test_dataclass(self) -> None:
         ]
 
         model = FakeToolCallingModel[WeatherDataclass](
-            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS
+            tool_calls=tool_calls,
+            structured_response=EXPECTED_WEATHER_DATACLASS,
+            model_name="gpt-4.1",
         )
 
         agent = create_agent(
@@ -657,7 +733,7 @@ def test_typed_dict(self) -> None:
         ]
 
         model = FakeToolCallingModel[WeatherTypedDict](
-            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
+            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT, model_name="gpt-4.1"
         )
 
         agent = create_agent(
@@ -675,7 +751,7 @@ def test_json_schema(self) -> None:
         ]
 
         model = FakeToolCallingModel[dict](
-            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
+            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT, model_name="gpt-4.1"
         )
 
         agent = create_agent(
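A detail in `test_unsupported_model_raises_error` worth flagging: `pytest.raises(match=...)` applies `re.search` to the stringified exception, which is why every literal `.`, `(`, and `)` in the expected message is escaped by hand. `re.escape` does the same mechanically, as in this stand-alone sketch (message abridged from the patch):

```python
import re

import pytest

msg = (
    "Cannot use ProviderStrategy with claude-3-5-sonnet. "
    "Supported models: OpenAI (gpt-5, gpt-4.1, gpt-oss, o3-pro, o3-mini), "
    "X.AI (Grok)."
)

# match= is a regex applied with re.search, so literal metacharacters must be escaped.
with pytest.raises(ValueError, match=re.escape(msg)):
    raise ValueError(msg + " Consider using a raw schema or ToolStrategy.")
```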
@@ -697,14 +773,14 @@ def test_middleware_model_swap_provider_to_tool_strategy(self) -> None:
         on the middleware-modified model (not the original), ensuring the correct
         strategy is selected based on the final model's capabilities.
         """
-        from unittest.mock import patch
 
         from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
         from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
 
-        # Custom model that we'll use to test whether the tool strategy is applied
-        # correctly at runtime.
+        # Custom model that we'll use to test whether the provider strategy is applied
+        # correctly at runtime. Use a model_name that supports provider strategy.
         class CustomModel(GenericFakeChatModel):
+            model_name: str = "gpt-4.1"
             tool_bindings: list[Any] = []
 
             def bind_tools(
@@ -736,14 +812,6 @@ def wrap_model_call(
                 request.model = model
                 return handler(request)
 
-        # Track which model is checked for provider strategy support
-        calls = []
-
-        def mock_supports_provider_strategy(model) -> bool:
-            """Track which model is checked and return True for ProviderStrategy."""
-            calls.append(model)
-            return True
-
         # Use raw Pydantic model (not wrapped in ToolStrategy or ProviderStrategy)
         # This should auto-detect strategy based on model capabilities
         agent = create_agent(
@@ -754,14 +822,7 @@ def mock_supports_provider_strategy(model) -> bool:
             middleware=[ModelSwappingMiddleware()],
         )
 
-        with patch(
-            "langchain.agents.factory._supports_provider_strategy",
-            side_effect=mock_supports_provider_strategy,
-        ):
-            response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
-
-        # Verify strategy resolution was deferred: check was called once during _get_bound_model
-        assert len(calls) == 1
+        response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
 
         # Verify successful parsing of JSON as structured output via ProviderStrategy
         assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
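Finally, a note on why `CustomModel` pins `model_name = "gpt-4.1"`: `GenericFakeChatModel` declares no `model_name` field, so the factory's `getattr` lookup yields an empty string, the provider check returns `False`, and a raw schema would fall back to `ToolStrategy`. A quick sketch, assuming `GenericFakeChatModel`'s messages-iterator constructor:

```python
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel

fake = GenericFakeChatModel(messages=iter([]))

# No model_name on the base fake: the factory sees "" and the provider check
# returns False, so only the subclass carrying model_name = "gpt-4.1"
# exercises the ProviderStrategy path end to end.
assert getattr(fake, "model_name", "") == ""
```

With the `unittest.mock.patch` of `_supports_provider_strategy` removed, the test now exercises the real name-based check end to end.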