Support returning multi-modal content from tools #1517

Status: Open · wants to merge 10 commits into base `main`
67 changes: 64 additions & 3 deletions docs/tools.md
@@ -15,6 +15,8 @@ There are a number of ways to register tools with an agent:
* via the [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] decorator — for tools that do not need access to the agent [context][pydantic_ai.tools.RunContext]
* via the [`tools`][pydantic_ai.Agent.__init__] keyword argument to `Agent` which can take either plain functions, or instances of [`Tool`][pydantic_ai.tools.Tool]

## Registering Function Tools via Decorator

`@agent.tool` is considered the default decorator since in the majority of cases tools will need access to the agent context.

Here's an example using both:
@@ -188,7 +190,7 @@ sequenceDiagram
Note over Agent: Game session complete
```

## Registering Function Tools via kwarg
## Registering Function Tools via Agent Argument

As well as using the decorators, we can register tools via the `tools` argument to the [`Agent` constructor][pydantic_ai.Agent.__init__]. This is useful when you want to reuse tools, and can also give more fine-grained control over the tools.

@@ -244,6 +246,67 @@ print(dice_result['b'].output)

_(This example is complete, it can be run "as is")_

## Function Tool Output

Tools can return anything that Pydantic can serialize to JSON, as well as audio, video, image, or document content, depending on the types of [multi-modal input](input.md) the model supports:

```python {title="function_tool_output.py"}
from datetime import datetime

from pydantic import BaseModel

from pydantic_ai import Agent, DocumentUrl, ImageUrl
from pydantic_ai.models.openai import OpenAIResponsesModel


class User(BaseModel):
name: str
age: int


agent = Agent(model=OpenAIResponsesModel('gpt-4o'))


@agent.tool_plain
def get_current_time() -> datetime:
return datetime.now()


@agent.tool_plain
def get_user() -> User:
return User(name='John', age=30)


@agent.tool_plain
def get_company_logo() -> ImageUrl:
return ImageUrl(url='https://iili.io/3Hs4FMg.png')


@agent.tool_plain
def get_document() -> DocumentUrl:
return DocumentUrl(url='https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf')


result = agent.run_sync('What time is it?')
print(result.output)
#> The current time is 10:45 PM on April 17, 2025.

result = agent.run_sync('What is the user name?')
print(result.output)
#> The user's name is John.

result = agent.run_sync('What is the company name in the logo?')
print(result.output)
#> The company name in the logo is "Pydantic."

result = agent.run_sync('What is the main content of the document?')
print(result.output)
#> The document contains just the text "Dummy PDF file."
```

_(This example is complete, it can be run "as is")_

Some models (e.g. Gemini) natively support semi-structured return values, while others (e.g. OpenAI) expect text but are generally just as capable of extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.
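A minimal sketch of that text fallback, using pydantic's own serialization helpers (`to_jsonable_python` from `pydantic_core`); the exact serialization pydantic-ai applies internally may differ in detail:

```python
import json

from pydantic import BaseModel
from pydantic_core import to_jsonable_python


class User(BaseModel):
    name: str
    age: int


# A tool may return an arbitrary Python object; when the model expects
# plain text, the object is converted to a JSON string before being sent.
tool_return = User(name='John', age=30)
serialized = json.dumps(to_jsonable_python(tool_return))
print(serialized)
#> {"name": "John", "age": 30}
```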

## Function Tools vs. Structured Outputs

As the name suggests, function tools use the model's "tools" or "functions" API to let the model know what is available to call. Tools or functions are also used to define the schema(s) for structured responses, thus a model might have access to many tools, some of which call function tools while others end the run and produce a final output.
@@ -307,8 +370,6 @@ agent.run_sync('hello', model=FunctionModel(print_schema))

_(This example is complete, it can be run "as is")_

The return type of tool can be anything which Pydantic can serialize to JSON as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.

If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object.

Here's an example where we use [`TestModel.last_model_request_parameters`][pydantic_ai.models.test.TestModel.last_model_request_parameters] to inspect the tool schema that would be passed to the model.
26 changes: 24 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -576,7 +576,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
)


async def process_function_tools(
async def process_function_tools(  # noqa: C901
tool_calls: list[_messages.ToolCallPart],
output_tool_name: str | None,
output_tool_call_id: str | None,
@@ -662,6 +664,8 @@ async def process_function_tools(
if not calls_to_run:
return

user_parts: list[_messages.UserPromptPart] = []

# Run all tool tasks in parallel
results_by_index: dict[int, _messages.ModelRequestPart] = {}
with ctx.deps.tracer.start_as_current_span(
@@ -675,14 +677,32 @@
asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer), name=call.tool_name)
for tool, call in calls_to_run
]

file_index = 1

pending = tasks
while pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for task in done:
index = tasks.index(task)
result = task.result()
yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):

if isinstance(result, _messages.RetryPromptPart):
results_by_index[index] = result
elif isinstance(result, _messages.ToolReturnPart):
if isinstance(result.content, _messages.MultiModalContentTypes):
user_parts.append(
_messages.UserPromptPart(
content=[f'This is file {file_index}:', result.content],
timestamp=result.timestamp,
part_kind='user-prompt',
)
)

result.content = f'See file {file_index}.'
file_index += 1

results_by_index[index] = result
else:
assert_never(result)
@@ -692,6 +712,8 @@ async def process_function_tools(
for k in sorted(results_by_index):
output_parts.append(results_by_index[k])

output_parts.extend(user_parts)


async def _tool_from_mcp_server(
tool_name: str,
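The diff above replaces multi-modal tool return content with a short text reference and forwards the actual file to the model as a separate user part. A dependency-free sketch of that indirection pattern (the class and function names here are illustrative, not the library's API):

```python
from dataclasses import dataclass


@dataclass
class ToolReturn:
    tool_name: str
    content: object


@dataclass
class ImageUrl:
    url: str


def split_multi_modal(results: list[ToolReturn]) -> tuple[list[ToolReturn], list[list[object]]]:
    """Replace multi-modal tool content with a 'See file N.' marker and
    collect the real content as user-prompt parts to append afterwards."""
    user_parts: list[list[object]] = []
    file_index = 1
    for result in results:
        if isinstance(result.content, ImageUrl):
            user_parts.append([f'This is file {file_index}:', result.content])
            result.content = f'See file {file_index}.'
            file_index += 1
    return results, user_parts


results, user_parts = split_multi_modal(
    [ToolReturn('get_company_logo', ImageUrl('https://iili.io/3Hs4FMg.png'))]
)
print(results[0].content)
#> See file 1.
```

The key design point mirrored from the diff: the tool-return message stays text-only (so every provider accepts it), while the binary/URL content rides along in a user message the model can actually see.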
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
@@ -253,6 +253,9 @@ def format(self) -> str:

UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent'

# Ideally this would be a Union of types, but Python 3.9 requires it to be a string, and strings don't work with `isinstance`.
MultiModalContentTypes = (ImageUrl, AudioUrl, DocumentUrl, VideoUrl, BinaryContent)


def _document_format(media_type: str) -> DocumentFormat:
if media_type == 'application/pdf':
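The tuple exists because `isinstance` accepts a tuple of classes at runtime, whereas the string-based type alias (required for the union on Python 3.9) does not. A dependency-free illustration with stand-in classes:

```python
# Stand-ins for the real content classes in pydantic_ai.messages.
class ImageUrl: ...
class AudioUrl: ...


# A runtime tuple of classes works directly in isinstance checks;
# a string annotation like 'ImageUrl | AudioUrl' would raise TypeError here.
MultiModalContentTypes = (ImageUrl, AudioUrl)

assert isinstance(ImageUrl(), MultiModalContentTypes)
assert not isinstance('plain text', MultiModalContentTypes)
```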
15 changes: 14 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -483,7 +483,20 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]:
assert_never(message)
if instructions := self._get_instructions(messages):
mistral_messages.insert(0, MistralSystemMessage(content=instructions))
return mistral_messages

# Post-process messages to insert fake assistant message after tool message if followed by user message
# to work around `Unexpected role 'user' after role 'tool'` error.
processed_messages: list[MistralMessages] = []
for i, current_message in enumerate(mistral_messages):
processed_messages.append(current_message)

if isinstance(current_message, MistralToolMessage) and i + 1 < len(mistral_messages):
next_message = mistral_messages[i + 1]
if isinstance(next_message, MistralUserMessage):
# Insert a dummy assistant message
processed_messages.append(MistralAssistantMessage(content=[MistralTextChunk(text='OK')]))

return processed_messages

def _map_user_prompt(self, part: UserPromptPart) -> MistralUserMessage:
content: str | list[MistralContentChunk]
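The Mistral workaround above can be sketched independently of the Mistral SDK: scan the message list and splice a placeholder assistant message between any tool message and a directly following user message (the dict-based message shape and role strings here are illustrative):

```python
def insert_dummy_assistant(messages: list[dict]) -> list[dict]:
    """Work around `Unexpected role 'user' after role 'tool'` by inserting
    a placeholder assistant message between a tool and a user message."""
    processed: list[dict] = []
    for i, message in enumerate(messages):
        processed.append(message)
        next_is_user = i + 1 < len(messages) and messages[i + 1]['role'] == 'user'
        if message['role'] == 'tool' and next_is_user:
            processed.append({'role': 'assistant', 'content': 'OK'})
    return processed


messages = [
    {'role': 'user', 'content': 'What is in the logo?'},
    {'role': 'assistant', 'content': '', 'tool_calls': ['get_company_logo']},
    {'role': 'tool', 'content': 'See file 1.'},
    {'role': 'user', 'content': 'This is file 1: <image>'},
]
print([m['role'] for m in insert_dummy_assistant(messages)])
#> ['user', 'assistant', 'tool', 'assistant', 'user']
```

This situation only arises because of the multi-modal change in `_agent_graph.py`: the forwarded file content becomes a user message immediately after a tool message, an ordering Mistral's API rejects.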