Skip to content

Commit ff95d0e

Browse files
committed
Introduces , a new client for connecting to MCP
servers using the WebSocket transport
1 parent 3430187 commit ff95d0e

File tree

1 file changed

+79
-3
lines changed
  • pydantic_ai_slim/pydantic_ai

1 file changed

+79
-3
lines changed

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from mcp.client.sse import sse_client
3232
from mcp.client.stdio import StdioServerParameters, stdio_client
3333
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
34+
from mcp.client.websocket import websocket_client
3435
from mcp.shared.context import RequestContext
3536
from mcp.shared.exceptions import McpError
3637
from mcp.shared.message import SessionMessage
@@ -866,6 +867,75 @@ def __eq__(self, value: object, /) -> bool:
866867
return self.url == value.url
867868

868869

870+
class MCPServerWebSocket(_MCPServerHTTP):
871+
"""An MCP server that connects over a WebSocket connection.
872+
873+
This class implements the WebSocket transport from the MCP specification.
874+
See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#websocket> for more information.
875+
876+
!!! note
877+
Using this class as an async context manager will create a WebSocket connection
878+
to a server which should already be running.
879+
880+
Example:
881+
```python {py="3.10"}
882+
from pydantic_ai import Agent
883+
from pydantic_ai.mcp import MCPServerWebSocket
884+
885+
server = MCPServerWebSocket('ws://localhost:8000/mcp-ws')
886+
agent = Agent('openai:gpt-4o', toolsets=[server])
887+
888+
async def main():
889+
async with agent: # This will connect to the WebSocket server.
890+
...
891+
```
892+
"""
893+
894+
@classmethod
895+
def __get_pydantic_core_schema__(cls, _: Any, __: Any) -> CoreSchema:
896+
return core_schema.no_info_after_validator_function(
897+
lambda dct: MCPServerWebSocket(**dct),
898+
core_schema.typed_dict_schema(
899+
{
900+
'url': core_schema.typed_dict_field(core_schema.str_schema()),
901+
'headers': core_schema.typed_dict_field(
902+
core_schema.dict_schema(core_schema.str_schema(), core_schema.str_schema()), required=False
903+
),
904+
}
905+
),
906+
)
907+
908+
@property
909+
def _transport_client(self):
910+
return websocket_client # pragma: no cover
911+
912+
@asynccontextmanager
913+
async def client_streams(
914+
self,
915+
) -> AsyncIterator[
916+
tuple[
917+
MemoryObjectReceiveStream[SessionMessage | Exception],
918+
MemoryObjectSendStream[SessionMessage],
919+
]
920+
]:
921+
if self.http_client:
922+
raise ValueError('`http_client` is not supported for WebSocket connections.')
923+
if self.headers:
924+
warnings.warn(
925+
'The provided `websocket_client` does not support `headers`. They will be ignored.',
926+
UserWarning,
927+
stacklevel=2,
928+
)
929+
930+
async with websocket_client(url=self.url) as (read_stream, write_stream):
931+
yield read_stream, write_stream
932+
933+
def __eq__(self, value: object, /) -> bool:
934+
if not isinstance(value, MCPServerWebSocket):
935+
return False # pragma: no cover
936+
return self.url == value.url
937+
938+
869939
ToolResult = (
870940
str
871941
| messages.BinaryContent
@@ -898,7 +968,10 @@ def __eq__(self, value: object, /) -> bool:
898968

899969
def _mcp_server_discriminator(value: dict[str, Any]) -> str | None:
900970
if 'url' in value:
901-
if value['url'].endswith('/sse'):
971+
url: str = value['url']
972+
if url.startswith('ws://') or url.startswith('wss://'):
973+
return 'websocket'
974+
if url.endswith('/sse'):
902975
return 'sse'
903976
return 'streamable-http'
904977
return 'stdio'
@@ -913,15 +986,18 @@ class MCPServerConfig(BaseModel):
913986
Annotated[
914987
Annotated[MCPServerStdio, Tag('stdio')]
915988
| Annotated[MCPServerStreamableHTTP, Tag('streamable-http')]
916-
| Annotated[MCPServerSSE, Tag('sse')],
989+
| Annotated[MCPServerSSE, Tag('sse')]
990+
| Annotated[MCPServerWebSocket, Tag('websocket')],
917991
Discriminator(_mcp_server_discriminator),
918992
],
919993
],
920994
Field(alias='mcpServers'),
921995
]
922996

923997

924-
def load_mcp_servers(config_path: str | Path) -> list[MCPServerStdio | MCPServerStreamableHTTP | MCPServerSSE]:
998+
def load_mcp_servers(
999+
config_path: str | Path,
1000+
) -> list[MCPServerStdio | MCPServerStreamableHTTP | MCPServerSSE | MCPServerWebSocket]:
9251001
"""Load MCP servers from a configuration file.
9261002
9271003
Args:

0 commit comments

Comments
 (0)