31
31
from mcp .client .sse import sse_client
32
32
from mcp .client .stdio import StdioServerParameters , stdio_client
33
33
from mcp .client .streamable_http import GetSessionIdCallback , streamablehttp_client
34
+ from mcp .client .websocket import websocket_client
34
35
from mcp .shared .context import RequestContext
35
36
from mcp .shared .exceptions import McpError
36
37
from mcp .shared .message import SessionMessage
@@ -866,6 +867,75 @@ def __eq__(self, value: object, /) -> bool:
866
867
return self .url == value .url
867
868
868
869
870
+ class MCPServerWebSocket (_MCPServerHTTP ):
871
+ """An MCP server that connects over a WebSocket connection.
872
+
873
+ This class implements the WebSocket transport from the MCP specification.
874
+ See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#websocket> for more information.
875
+
876
+ !!! note
877
+ Using this class as an async context manager will create a WebSocket connection
878
+ to a server which should already be running.
879
+
880
+ Example:
881
+ ```python {py="3.10"}
882
+ from pydantic_ai import Agent
883
+ from pydantic_ai.mcp import MCPServerWebSocket
884
+
885
+ server = MCPServerWebSocket('ws://localhost:8000/mcp-ws')
886
+ agent = Agent('openai:gpt-4o', toolsets=[server])
887
+
888
+ async def main():
889
+ async with agent: # This will connect to the WebSocket server.
890
+ ...
891
+ ```
892
+ """
893
+
894
+ @classmethod
895
+ def __get_pydantic_core_schema__ (cls , _ : Any , __ : Any ) -> CoreSchema :
896
+ return core_schema .no_info_after_validator_function (
897
+ lambda dct : MCPServerWebSocket (** dct ),
898
+ core_schema .typed_dict_schema (
899
+ {
900
+ 'url' : core_schema .typed_dict_field (core_schema .str_schema ()),
901
+ 'headers' : core_schema .typed_dict_field (
902
+ core_schema .dict_schema (core_schema .str_schema (), core_schema .str_schema ()), required = False
903
+ ),
904
+ }
905
+ ),
906
+ )
907
+
908
+ @property
909
+ def _transport_client (self ):
910
+ return websocket_client # pragma: no cover
911
+
912
+ @asynccontextmanager
913
+ async def client_streams (
914
+ self ,
915
+ ) -> AsyncIterator [
916
+ tuple [
917
+ MemoryObjectReceiveStream [SessionMessage | Exception ],
918
+ MemoryObjectSendStream [SessionMessage ],
919
+ ]
920
+ ]:
921
+ if self .http_client :
922
+ raise ValueError ('`http_client` is not supported for WebSocket connections.' )
923
+ if self .headers :
924
+ warnings .warn (
925
+ 'The provided `websocket_client` does not support `headers`. They will be ignored.' ,
926
+ UserWarning ,
927
+ stacklevel = 2 ,
928
+ )
929
+
930
+ async with websocket_client (url = self .url ) as (read_stream , write_stream ):
931
+ yield read_stream , write_stream
932
+
933
+ def __eq__ (self , value : object , / ) -> bool :
934
+ if not isinstance (value , MCPServerWebSocket ):
935
+ return False # pragma: no cover
936
+ return self .url == value .url
937
+
938
+
869
939
ToolResult = (
870
940
str
871
941
| messages .BinaryContent
@@ -898,7 +968,10 @@ def __eq__(self, value: object, /) -> bool:
898
968
899
969
def _mcp_server_discriminator (value : dict [str , Any ]) -> str | None :
900
970
if 'url' in value :
901
- if value ['url' ].endswith ('/sse' ):
971
+ url : str = value ['url' ]
972
+ if url .startswith ('ws://' ) or url .startswith ('wss://' ):
973
+ return 'websocket'
974
+ if url .endswith ('/sse' ):
902
975
return 'sse'
903
976
return 'streamable-http'
904
977
return 'stdio'
@@ -913,15 +986,18 @@ class MCPServerConfig(BaseModel):
913
986
Annotated [
914
987
Annotated [MCPServerStdio , Tag ('stdio' )]
915
988
| Annotated [MCPServerStreamableHTTP , Tag ('streamable-http' )]
916
- | Annotated [MCPServerSSE , Tag ('sse' )],
989
+ | Annotated [MCPServerSSE , Tag ('sse' )]
990
+ | Annotated [MCPServerWebSocket , Tag ('websocket' )],
917
991
Discriminator (_mcp_server_discriminator ),
918
992
],
919
993
],
920
994
Field (alias = 'mcpServers' ),
921
995
]
922
996
923
997
924
- def load_mcp_servers (config_path : str | Path ) -> list [MCPServerStdio | MCPServerStreamableHTTP | MCPServerSSE ]:
998
+ def load_mcp_servers (
999
+ config_path : str | Path ,
1000
+ ) -> list [MCPServerStdio | MCPServerStreamableHTTP | MCPServerSSE | MCPServerWebSocket ]:
925
1001
"""Load MCP servers from a configuration file.
926
1002
927
1003
Args:
0 commit comments