Skip to content

Commit 2d3dec7

Browse files
authored
Merge branch 'main' into multimodal-tool-output
2 parents c1c7392 + 35d6ed4 commit 2d3dec7

29 files changed

+523
-113
lines changed

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -291,13 +291,4 @@ async def __anext__(self) -> T:
291291

292292

293293
def get_traceparent(x: AgentRun | AgentRunResult | GraphRun | GraphRunResult) -> str:
294-
import logfire
295-
import logfire_api
296-
from logfire.experimental.annotations import get_traceparent
297-
298-
span: AbstractSpan | None = x._span(required=False) # type: ignore[reportPrivateUsage]
299-
if not span: # pragma: no cover
300-
return ''
301-
if isinstance(span, logfire_api.LogfireSpan): # pragma: no cover
302-
assert isinstance(span, logfire.LogfireSpan)
303-
return get_traceparent(span)
294+
return x._traceparent(required=False) or '' # type: ignore[reportPrivateUsage]

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
result,
2828
usage as _usage,
2929
)
30-
from ._utils import AbstractSpan
3130
from .models.instrumented import InstrumentationSettings, InstrumentedModel
3231
from .result import FinalResult, OutputDataT, StreamedRunResult, ToolOutput
3332
from .settings import ModelSettings, merge_model_settings
@@ -1683,14 +1682,14 @@ async def main():
16831682
]
16841683

16851684
@overload
1686-
def _span(self, *, required: Literal[False]) -> AbstractSpan | None: ...
1685+
def _traceparent(self, *, required: Literal[False]) -> str | None: ...
16871686
@overload
1688-
def _span(self) -> AbstractSpan: ...
1689-
def _span(self, *, required: bool = True) -> AbstractSpan | None:
1690-
span = self._graph_run._span(required=False) # type: ignore[reportPrivateUsage]
1691-
if span is None and required: # pragma: no cover
1692-
raise AttributeError('Span is not available for this agent run')
1693-
return span
1687+
def _traceparent(self) -> str: ...
1688+
def _traceparent(self, *, required: bool = True) -> str | None:
1689+
traceparent = self._graph_run._traceparent(required=False) # type: ignore[reportPrivateUsage]
1690+
if traceparent is None and required: # pragma: no cover
1691+
raise AttributeError('No span was created for this agent run')
1692+
return traceparent
16941693

16951694
@property
16961695
def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]:
@@ -1729,7 +1728,7 @@ def result(self) -> AgentRunResult[OutputDataT] | None:
17291728
graph_run_result.output.tool_name,
17301729
graph_run_result.state,
17311730
self._graph_run.deps.new_message_index,
1732-
self._graph_run._span(required=False), # type: ignore[reportPrivateUsage]
1731+
self._traceparent(required=False),
17331732
)
17341733

17351734
def __aiter__(
@@ -1847,16 +1846,16 @@ class AgentRunResult(Generic[OutputDataT]):
18471846
_output_tool_name: str | None = dataclasses.field(repr=False)
18481847
_state: _agent_graph.GraphAgentState = dataclasses.field(repr=False)
18491848
_new_message_index: int = dataclasses.field(repr=False)
1850-
_span_value: AbstractSpan | None = dataclasses.field(repr=False)
1849+
_traceparent_value: str | None = dataclasses.field(repr=False)
18511850

18521851
@overload
1853-
def _span(self, *, required: Literal[False]) -> AbstractSpan | None: ...
1852+
def _traceparent(self, *, required: Literal[False]) -> str | None: ...
18541853
@overload
1855-
def _span(self) -> AbstractSpan: ...
1856-
def _span(self, *, required: bool = True) -> AbstractSpan | None:
1857-
if self._span_value is None and required: # pragma: no cover
1858-
raise AttributeError('Span is not available for this agent run')
1859-
return self._span_value
1854+
def _traceparent(self) -> str: ...
1855+
def _traceparent(self, *, required: bool = True) -> str | None:
1856+
if self._traceparent_value is None and required: # pragma: no cover
1857+
raise AttributeError('No span was created for this agent run')
1858+
return self._traceparent_value
18601859

18611860
@property
18621861
@deprecated('`result.data` is deprecated, use `result.output` instead.')

pydantic_ai_slim/pydantic_ai/common_tools/duckduckgo.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ async def __call__(self, query: str) -> list[DuckDuckGoResult]:
5454
"""
5555
search = functools.partial(self.client.text, max_results=self.max_results)
5656
results = await anyio.to_thread.run_sync(search, query)
57-
if len(results) == 0:
58-
raise RuntimeError('No search results found.')
5957
return duckduckgo_ta.validate_python(results)
6058

6159

pydantic_ai_slim/pydantic_ai/common_tools/tavily.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ async def __call__(
6363
The search results.
6464
"""
6565
results = await self.client.search(query, search_depth=search_deep, topic=topic, time_range=time_range) # type: ignore[reportUnknownMemberType]
66-
if not results['results']:
67-
raise RuntimeError('No search results found.')
6866
return tavily_search_ta.validate_python(results['results'])
6967

7068

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Any
1010

1111
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
12-
from mcp.types import JSONRPCMessage
12+
from mcp.types import JSONRPCMessage, LoggingLevel
1313
from typing_extensions import Self
1414

1515
from pydantic_ai.tools import ToolDefinition
@@ -52,6 +52,11 @@ async def client_streams(
5252
raise NotImplementedError('MCP Server subclasses must implement this method.')
5353
yield
5454

55+
@abstractmethod
56+
def _get_log_level(self) -> LoggingLevel | None:
57+
"""Get the log level for the MCP server."""
58+
raise NotImplementedError('MCP Server subclasses must implement this method.')
59+
5560
async def list_tools(self) -> list[ToolDefinition]:
5661
"""Retrieve tools that are currently active on the server.
5762
@@ -89,6 +94,8 @@ async def __aenter__(self) -> Self:
8994
self._client = await self._exit_stack.enter_async_context(client)
9095

9196
await self._client.initialize()
97+
if log_level := self._get_log_level():
98+
await self._client.set_logging_level(log_level)
9299
self.is_running = True
93100
return self
94101

@@ -150,6 +157,13 @@ async def main():
150157
By default the subprocess will not inherit any environment variables from the parent process.
151158
If you want to inherit the environment variables from the parent process, use `env=os.environ`.
152159
"""
160+
log_level: LoggingLevel | None = None
161+
"""The log level to set when connecting to the server, if any.
162+
163+
See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
164+
165+
If `None`, no log level will be set.
166+
"""
153167

154168
cwd: str | Path | None = None
155169
"""The working directory to use when spawning the process."""
@@ -164,6 +178,9 @@ async def client_streams(
164178
async with stdio_client(server=server) as (read_stream, write_stream):
165179
yield read_stream, write_stream
166180

181+
def _get_log_level(self) -> LoggingLevel | None:
182+
return self.log_level
183+
167184

168185
@dataclass
169186
class MCPServerHTTP(MCPServer):
@@ -223,6 +240,13 @@ async def main():
223240
If no new messages are received within this time, the connection will be considered stale
224241
and may be closed. Defaults to 5 minutes (300 seconds).
225242
"""
243+
log_level: LoggingLevel | None = None
244+
"""The log level to set when connecting to the server, if any.
245+
246+
See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
247+
248+
If `None`, no log level will be set.
249+
"""
226250

227251
@asynccontextmanager
228252
async def client_streams(
@@ -234,3 +258,6 @@ async def client_streams(
234258
url=self.url, headers=self.headers, timeout=self.timeout, sse_read_timeout=self.sse_read_timeout
235259
) as (read_stream, write_stream):
236260
yield read_stream, write_stream
261+
262+
def _get_log_level(self) -> LoggingLevel | None:
263+
return self.log_level

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,8 @@ def args_as_dict(self) -> dict[str, Any]:
516516
"""
517517
if isinstance(self.args, dict):
518518
return self.args
519+
if isinstance(self.args, str) and not self.args:
520+
return {}
519521
args = pydantic_core.from_json(self.args)
520522
assert isinstance(args, dict), 'args should be a dict'
521523
return cast(dict[str, Any], args)

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@
106106
'google-gla:gemini-2.0-flash',
107107
'google-gla:gemini-2.0-flash-lite-preview-02-05',
108108
'google-gla:gemini-2.0-pro-exp-02-05',
109+
'google-gla:gemini-2.5-flash-preview-04-17',
109110
'google-gla:gemini-2.5-pro-exp-03-25',
110111
'google-gla:gemini-2.5-pro-preview-03-25',
111112
'google-vertex:gemini-1.0-pro',
@@ -118,6 +119,7 @@
118119
'google-vertex:gemini-2.0-flash',
119120
'google-vertex:gemini-2.0-flash-lite-preview-02-05',
120121
'google-vertex:gemini-2.0-pro-exp-02-05',
122+
'google-vertex:gemini-2.5-flash-preview-04-17',
121123
'google-vertex:gemini-2.5-pro-exp-03-25',
122124
'google-vertex:gemini-2.5-pro-preview-03-25',
123125
'gpt-3.5-turbo',
@@ -192,6 +194,8 @@
192194
'o1-mini-2024-09-12',
193195
'o1-preview',
194196
'o1-preview-2024-09-12',
197+
'o3',
198+
'o3-2025-04-16',
195199
'o3-mini',
196200
'o3-mini-2025-01-31',
197201
'openai:chatgpt-4o-latest',
@@ -241,8 +245,12 @@
241245
'openai:o1-mini-2024-09-12',
242246
'openai:o1-preview',
243247
'openai:o1-preview-2024-09-12',
248+
'openai:o3',
249+
'openai:o3-2025-04-16',
244250
'openai:o3-mini',
245251
'openai:o3-mini-2025-01-31',
252+
'openai:o4-mini',
253+
'openai:o4-mini-2025-04-16',
246254
'test',
247255
],
248256
)

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ async def _messages_create(
239239
timeout=model_settings.get('timeout', NOT_GIVEN),
240240
metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
241241
extra_headers={'User-Agent': get_user_agent()},
242+
extra_body=model_settings.get('extra_body'),
242243
)
243244
except APIStatusError as e:
244245
if (status_code := e.status_code) >= 400:

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
import functools
44
import typing
5-
from collections.abc import AsyncIterator, Iterable, Mapping
5+
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
66
from contextlib import asynccontextmanager
77
from dataclasses import dataclass, field
88
from datetime import datetime
9+
from itertools import count
910
from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast, overload
1011

1112
import anyio
@@ -369,13 +370,14 @@ async def _map_messages(
369370
"""Just maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`."""
370371
system_prompt: list[SystemContentBlockTypeDef] = []
371372
bedrock_messages: list[MessageUnionTypeDef] = []
373+
document_count: Iterator[int] = count(1)
372374
for m in messages:
373375
if isinstance(m, ModelRequest):
374376
for part in m.parts:
375377
if isinstance(part, SystemPromptPart):
376378
system_prompt.append({'text': part.content})
377379
elif isinstance(part, UserPromptPart):
378-
bedrock_messages.extend(await self._map_user_prompt(part))
380+
bedrock_messages.extend(await self._map_user_prompt(part, document_count))
379381
elif isinstance(part, ToolReturnPart):
380382
assert part.tool_call_id is not None
381383
bedrock_messages.append(
@@ -430,20 +432,18 @@ async def _map_messages(
430432
return system_prompt, bedrock_messages
431433

432434
@staticmethod
433-
async def _map_user_prompt(part: UserPromptPart) -> list[MessageUnionTypeDef]:
435+
async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) -> list[MessageUnionTypeDef]:
434436
content: list[ContentBlockUnionTypeDef] = []
435437
if isinstance(part.content, str):
436438
content.append({'text': part.content})
437439
else:
438-
document_count = 0
439440
for item in part.content:
440441
if isinstance(item, str):
441442
content.append({'text': item})
442443
elif isinstance(item, BinaryContent):
443444
format = item.format
444445
if item.is_document:
445-
document_count += 1
446-
name = f'Document {document_count}'
446+
name = f'Document {next(document_count)}'
447447
assert format in ('pdf', 'txt', 'csv', 'doc', 'docx', 'xls', 'xlsx', 'html', 'md')
448448
content.append({'document': {'name': name, 'format': format, 'source': {'bytes': item.data}}})
449449
elif item.is_image:
@@ -464,8 +464,7 @@ async def _map_user_prompt(part: UserPromptPart) -> list[MessageUnionTypeDef]:
464464
content.append({'image': image})
465465

466466
elif item.kind == 'document-url':
467-
document_count += 1
468-
name = f'Document {document_count}'
467+
name = f'Document {next(document_count)}'
469468
data = response.content
470469
content.append({'document': {'name': name, 'format': item.format, 'source': {'bytes': data}}})
471470

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
'gemini-2.0-flash',
5959
'gemini-2.0-flash-lite-preview-02-05',
6060
'gemini-2.0-pro-exp-02-05',
61+
'gemini-2.5-flash-preview-04-17',
6162
'gemini-2.5-pro-exp-03-25',
6263
'gemini-2.5-pro-preview-03-25',
6364
]

0 commit comments

Comments (0)