Skip to content

Add support for callable output tools #1463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ async def _handle_tool_calls(
if output_schema is not None:
for call, output_tool in output_schema.find_tool(tool_calls):
try:
result_data = output_tool.validate(call)
result_data = await output_tool.execute(call)
result_data = await _validate_output(result_data, ctx, call)
except _output.ToolRetryError as e:
# TODO: Should only increment retry stuff once per node execution, not for each tool call
Expand Down
71 changes: 53 additions & 18 deletions pydantic_ai_slim/pydantic_ai/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""An invariant TypeVar."""


# TODO: Deprecate output validators in favor of ToolOutput with a call
@dataclass
class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
function: OutputValidatorFunc[AgentDepsT, OutputDataT_inv]
Expand Down Expand Up @@ -98,12 +99,14 @@ def build(
if output_type is str:
return None

call: Callable[..., T | Awaitable[T]] | None = None
if isinstance(output_type, ToolOutput):
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
name = output_type.name
description = output_type.description
output_type_ = output_type.output_type
strict = output_type.strict
call = output_type.call
else:
output_type_ = output_type

Expand All @@ -115,22 +118,33 @@ def build(

tools: dict[str, OutputSchemaTool[T]] = {}
if args := get_union_args(output_type_):
# Note: this will not be hit if output_type_ is a ToolOutput since get_union_args will return ()
for i, arg in enumerate(args, start=1):
tool_name = raw_tool_name = union_tool_name(name, arg)
while tool_name in tools:
tool_name = f'{raw_tool_name}_{i}'
tools[tool_name] = cast(
OutputSchemaTool[T],
OutputSchemaTool(
output_type=arg, name=tool_name, description=description, multiple=True, strict=strict
output_type=arg,
call=call,
name=tool_name,
description=description,
multiple=True,
strict=strict,
),
)
else:
name = name or DEFAULT_OUTPUT_TOOL_NAME
tools[name] = cast(
OutputSchemaTool[T],
OutputSchemaTool(
output_type=output_type_, name=name, description=description, multiple=False, strict=strict
output_type=output_type_,
call=call,
name=name,
description=description,
multiple=False,
strict=strict,
),
)

Expand Down Expand Up @@ -171,29 +185,39 @@ def tool_defs(self) -> list[ToolDefinition]:
class OutputSchemaTool(Generic[OutputDataT]):
tool_def: ToolDefinition
type_adapter: TypeAdapter[Any]
call: Callable[..., OutputDataT | Awaitable[OutputDataT]] | None

def __init__(
self, *, output_type: type[OutputDataT], name: str, description: str | None, multiple: bool, strict: bool | None
self,
*,
output_type: type[OutputDataT],
call: Callable[..., OutputDataT | Awaitable[OutputDataT]] | None,
name: str,
description: str | None,
multiple: bool,
strict: bool | None,
):
"""Build a OutputSchemaTool from a response type."""
if _utils.is_model_like(output_type):
self.call = call

outer_typed_dict_key: str | None = None
if call is not None:
self.type_adapter = TypeAdapter(call)
elif _utils.is_model_like(output_type):
self.type_adapter = TypeAdapter(output_type)
outer_typed_dict_key: str | None = None
# noinspection PyArgumentList
parameters_json_schema = _utils.check_object_json_schema(
self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
)
else:
response_data_typed_dict = TypedDict( # noqa: UP013
'response_data_typed_dict',
{'response': output_type}, # pyright: ignore[reportInvalidTypeForm]
)
self.type_adapter = TypeAdapter(response_data_typed_dict)
outer_typed_dict_key = 'response'
# noinspection PyArgumentList
parameters_json_schema = _utils.check_object_json_schema(
self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
)

# noinspection PyArgumentList
parameters_json_schema = _utils.check_object_json_schema(
self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
)
if outer_typed_dict_key is not None:
# including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
parameters_json_schema.pop('title')

Expand All @@ -215,14 +239,14 @@ def __init__(
strict=strict,
)

def validate(
async def execute(
self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
) -> OutputDataT:
"""Validate an output message.
"""Execute the tool call. In the case that the `call` attribute is None, this just amounts to validation.

Args:
tool_call: The tool call from the LLM to validate.
allow_partial: If true, allow partial validation.
tool_call: The tool call from the LLM to execute.
allow_partial: If true, allow partial validation (prior to execution if there is a call).
wrap_validation_errors: If true, wrap the validation errors in a retry message.

Returns:
Expand All @@ -234,6 +258,13 @@ def validate(
output = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial)
else:
output = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial)
except ModelRetry as e:
m = _messages.RetryPromptPart(
tool_name=tool_call.tool_name,
content=e.message,
tool_call_id=tool_call.tool_call_id,
)
raise ToolRetryError(m) from e
except ValidationError as e:
if wrap_validation_errors:
m = _messages.RetryPromptPart(
Expand All @@ -247,7 +278,11 @@ def validate(
else:
if k := self.tool_def.outer_typed_dict_key:
output = output[k]
return output

if self.call and inspect.isawaitable(output):
# The check for `self.call` is just there to skip the `isawaitable` check when no call is present
output = await output
return output


def union_tool_name(base_name: str | None, union_arg: Any) -> str:
Expand Down
59 changes: 43 additions & 16 deletions pydantic_ai_slim/pydantic_ai/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from copy import copy
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Generic, Union, cast
from typing import TYPE_CHECKING, Generic, Union, cast, get_type_hints

from typing_extensions import TypeVar, assert_type, deprecated, overload

Expand Down Expand Up @@ -60,38 +60,65 @@ class ToolOutput(Generic[OutputDataT]):
"""Marker class to use tools for structured outputs, and customize the tool."""

output_type: type[OutputDataT]
# TODO: Add `output_call` support, for calling a function to get the output
# output_call: Callable[..., OutputDataT] | None
call: Callable[..., OutputDataT | Awaitable[OutputDataT]] | None
name: str
description: str | None
max_retries: int | None
strict: bool | None

@overload
def __init__(
self,
*,
type_: type[OutputDataT],
# call: Callable[..., OutputDataT] | None = None,
name: str = 'final_result',
description: str | None = None,
max_retries: int | None = None,
strict: bool | None = None,
) -> None: ...

@overload
def __init__(
self,
*,
call: Callable[..., OutputDataT | Awaitable[OutputDataT]],
type_: type[OutputDataT] | _utils.Unset = _utils.UNSET,
name: str = 'final_result',
description: str | None = None,
max_retries: int | None = None,
strict: bool | None = None,
) -> None: ...

def __init__(
self,
*,
type_: type[OutputDataT] | _utils.Unset = _utils.UNSET,
call: Callable[..., OutputDataT | Awaitable[OutputDataT]] | None = None,
name: str = 'final_result',
description: str | None = None,
max_retries: int | None = None,
strict: bool | None = None,
):
self.output_type = type_
self.name = name
self.description = description
self.max_retries = max_retries
self.strict = strict

# TODO: add support for call and make type_ optional, with the following logic:
# if type_ is None and call is None:
# raise ValueError('Either type_ or call must be provided')
# if call is not None:
# if type_ is None:
# type_ = get_type_hints(call).get('return')
# if type_ is None:
# raise ValueError('Unable to determine type_ from call signature; please provide it explicitly')
# self.output_call = call
if not _utils.is_set(type_):
if call is None:
raise ValueError('Either type_ or call must be provided')
else:
try:
type_ = get_type_hints(call).get('return', _utils.UNSET)
if not _utils.is_set(type_):
raise ValueError('Unable to determine type_ from call signature; please provide it explicitly')
except Exception as e:
raise ValueError(
'Unable to determine type_ from call signature; please provide it explicitly'
) from e

self.output_type = type_
self.call = call


@dataclass
Expand Down Expand Up @@ -152,7 +179,7 @@ async def _validate_response(
)

call, output_tool = match
result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
result_data = await output_tool.execute(call, allow_partial=allow_partial, wrap_validation_errors=False)

for validator in self._output_validators:
result_data = await validator.validate(result_data, call, self._run_ctx)
Expand Down Expand Up @@ -466,7 +493,7 @@ async def validate_structured_output(
)

call, output_tool = match
result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
result_data = await output_tool.execute(call, allow_partial=allow_partial, wrap_validation_errors=False)

for validator in self._output_validators:
result_data = await validator.validate(result_data, call, self._run_ctx)
Expand Down
Loading