From 342266935ee83d350bbbcf844bdd4f57f3b3b9ce Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Sat, 12 Apr 2025 12:09:12 -0600 Subject: [PATCH] Add support for callable output tools --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 +- pydantic_ai_slim/pydantic_ai/_output.py | 71 +++++++++++++++----- pydantic_ai_slim/pydantic_ai/result.py | 59 +++++++++++----- 3 files changed, 97 insertions(+), 35 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 9cb436f86..80a28f51c 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -438,7 +438,7 @@ async def _handle_tool_calls( if output_schema is not None: for call, output_tool in output_schema.find_tool(tool_calls): try: - result_data = output_tool.validate(call) + result_data = await output_tool.execute(call) result_data = await _validate_output(result_data, ctx, call) except _output.ToolRetryError as e: # TODO: Should only increment retry stuff once per node execution, not for each tool call diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 2f10d3fe3..73345dc83 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -19,6 +19,7 @@ """An invariant TypeVar.""" +# TODO: Deprecate output validators in favor of ToolOutput with a call @dataclass class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]): function: OutputValidatorFunc[AgentDepsT, OutputDataT_inv] @@ -98,12 +99,14 @@ def build( if output_type is str: return None + call: Callable[..., T | Awaitable[T]] | None = None if isinstance(output_type, ToolOutput): # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads name = output_type.name description = output_type.description output_type_ = output_type.output_type strict = output_type.strict + call = output_type.call else: output_type_ = output_type @@ -115,6 +118,7 @@ def build( tools: dict[str, OutputSchemaTool[T]] = {} if args := get_union_args(output_type_): + # Note: this will not be hit if output_type_ is a ToolOutput since get_union_args will return () for i, arg in enumerate(args, start=1): tool_name = raw_tool_name = union_tool_name(name, arg) while tool_name in tools: @@ -122,7 +126,12 @@ def build( tools[tool_name] = cast( OutputSchemaTool[T], OutputSchemaTool( - output_type=arg, name=tool_name, description=description, multiple=True, strict=strict + output_type=arg, + call=call, + name=tool_name, + description=description, + multiple=True, + strict=strict, ), ) else: @@ -130,7 +139,12 @@ def build( tools[name] = cast( OutputSchemaTool[T], OutputSchemaTool( - output_type=output_type_, name=name, description=description, multiple=False, strict=strict + output_type=output_type_, + call=call, + name=name, + description=description, + multiple=False, + strict=strict, ), ) @@ -171,18 +185,26 @@ def tool_defs(self) -> list[ToolDefinition]: class OutputSchemaTool(Generic[OutputDataT]): tool_def: ToolDefinition type_adapter: TypeAdapter[Any] + call: Callable[..., OutputDataT | Awaitable[OutputDataT]] | None def __init__( - self, *, output_type: type[OutputDataT], name: str, description: str | None, multiple: bool, strict: bool | None + self, + *, + output_type: type[OutputDataT], + call: Callable[..., OutputDataT | Awaitable[OutputDataT]] | None, + name: str, + description: str | None, + multiple: bool, + strict: bool | None, ): """Build a OutputSchemaTool from a response type.""" - if _utils.is_model_like(output_type): + self.call = call + + outer_typed_dict_key: str | None = None + if call is not None: + self.type_adapter = TypeAdapter(call) + elif _utils.is_model_like(output_type): self.type_adapter = TypeAdapter(output_type) - outer_typed_dict_key: str | None = None - # noinspection PyArgumentList - parameters_json_schema = _utils.check_object_json_schema( - self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) - ) else: response_data_typed_dict = TypedDict( # noqa: UP013 'response_data_typed_dict', @@ -190,10 +212,12 @@ def __init__( ) self.type_adapter = TypeAdapter(response_data_typed_dict) outer_typed_dict_key = 'response' - # noinspection PyArgumentList - parameters_json_schema = _utils.check_object_json_schema( - self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) - ) + + # noinspection PyArgumentList + parameters_json_schema = _utils.check_object_json_schema( + self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) + ) + if outer_typed_dict_key is not None: # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM parameters_json_schema.pop('title') @@ -215,14 +239,14 @@ def __init__( strict=strict, ) - def validate( + async def execute( self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True ) -> OutputDataT: - """Validate an output message. + """Execute the tool call. In the case that the `call` attribute is None, this just amounts to validation. Args: - tool_call: The tool call from the LLM to validate. - allow_partial: If true, allow partial validation. + tool_call: The tool call from the LLM to execute. + allow_partial: If true, allow partial validation (prior to execution if there is a call). wrap_validation_errors: If true, wrap the validation errors in a retry message. Returns: @@ -234,6 +258,13 @@ def validate( output = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial) else: output = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial) + except ModelRetry as e: + m = _messages.RetryPromptPart( + tool_name=tool_call.tool_name, + content=e.message, + tool_call_id=tool_call.tool_call_id, + ) + raise ToolRetryError(m) from e except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -247,7 +278,11 @@ def validate( else: if k := self.tool_def.outer_typed_dict_key: output = output[k] - return output + + if self.call and inspect.isawaitable(output): + # The check for `self.call` is just there to skip the `isawaitable` check when no call is present + output = await output + return output def union_tool_name(base_name: str | None, union_arg: Any) -> str: diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index d074f3042..d1f87d046 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -5,7 +5,7 @@ from copy import copy from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Generic, Union, cast +from typing import TYPE_CHECKING, Generic, Union, cast, get_type_hints from typing_extensions import TypeVar, assert_type, deprecated, overload @@ -60,38 +60,65 @@ class ToolOutput(Generic[OutputDataT]): """Marker class to use tools for structured outputs, and customize the tool.""" output_type: type[OutputDataT] - # TODO: Add `output_call` support, for calling a function to get the output - # output_call: Callable[..., OutputDataT] | None + call: Callable[..., OutputDataT | Awaitable[OutputDataT]] | None name: str description: str | None max_retries: int | None strict: bool | None + @overload def __init__( self, *, type_: type[OutputDataT], - # call: Callable[..., OutputDataT] | None = None, + name: str = 'final_result', + description: str | None = None, + max_retries: int | None = None, + strict: bool | None = None, + ) -> None: ... + + @overload + def __init__( + self, + *, + call: Callable[..., OutputDataT | Awaitable[OutputDataT]], + type_: type[OutputDataT] | _utils.Unset = _utils.UNSET, + name: str = 'final_result', + description: str | None = None, + max_retries: int | None = None, + strict: bool | None = None, + ) -> None: ... + + def __init__( + self, + *, + type_: type[OutputDataT] | _utils.Unset = _utils.UNSET, + call: Callable[..., OutputDataT | Awaitable[OutputDataT]] | None = None, name: str = 'final_result', description: str | None = None, max_retries: int | None = None, strict: bool | None = None, ): - self.output_type = type_ self.name = name self.description = description self.max_retries = max_retries self.strict = strict - # TODO: add support for call and make type_ optional, with the following logic: - # if type_ is None and call is None: - # raise ValueError('Either type_ or call must be provided') - # if call is not None: - # if type_ is None: - # type_ = get_type_hints(call).get('return') - # if type_ is None: - # raise ValueError('Unable to determine type_ from call signature; please provide it explicitly') - # self.output_call = call + if not _utils.is_set(type_): + if call is None: + raise ValueError('Either type_ or call must be provided') + else: + try: + type_ = get_type_hints(call).get('return', _utils.UNSET) + if not _utils.is_set(type_): + raise ValueError('Unable to determine type_ from call signature; please provide it explicitly') + except Exception as e: + raise ValueError( + 'Unable to determine type_ from call signature; please provide it explicitly' + ) from e + + self.output_type = type_ + self.call = call @dataclass @@ -152,7 +179,7 @@ async def _validate_response( ) call, output_tool = match - result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) + result_data = await output_tool.execute(call, allow_partial=allow_partial, wrap_validation_errors=False) for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) @@ -466,7 +493,7 @@ async def validate_structured_output( ) call, output_tool = match - result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) + result_data = await output_tool.execute(call, allow_partial=allow_partial, wrap_validation_errors=False) for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx)