Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 98 additions & 28 deletions src/huggingface_hub/inference/_mcp/mcp_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
import logging
from contextlib import AsyncExitStack
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, AsyncIterable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Union, overload

from typing_extensions import TypeAlias
from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack

from ...utils._runtime import get_hf_hub_version
from .._generated._async_client import AsyncInferenceClient
Expand All @@ -26,6 +27,30 @@
# Type alias for tool names
ToolName: TypeAlias = str

# Kind of MCP server to connect to: "stdio" (local standard input/output),
# "sse" (server-sent events) or "http" (StreamableHTTP).
ServerType: TypeAlias = Literal["stdio", "sse", "http"]


class StdioServerParameters_T(TypedDict):
command: str
args: NotRequired[List[str]]
env: NotRequired[Dict[str, str]]
cwd: NotRequired[Union[str, Path, None]]


class SSEServerParameters_T(TypedDict):
url: str
headers: NotRequired[Dict[str, Any]]
timeout: NotRequired[float]
sse_read_timeout: NotRequired[float]


class StreamableHTTPParameters_T(TypedDict):
url: str
headers: NotRequired[dict[str, Any]]
timeout: NotRequired[timedelta]
sse_read_timeout: NotRequired[timedelta]
terminate_on_close: NotRequired[bool]


class MCPClient:
"""
Expand Down Expand Up @@ -64,39 +89,84 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.client.__aexit__(exc_type, exc_val, exc_tb)
await self.cleanup()

async def add_mcp_server(
self,
*,
command: str,
args: Optional[List[str]] = None,
env: Optional[Dict[str, str]] = None,
cwd: Union[str, Path, None] = None,
):
@overload
async def add_mcp_server(self, type: Literal["stdio"], **params: Unpack[StdioServerParameters_T]): ...

@overload
async def add_mcp_server(self, type: Literal["sse"], **params: Unpack[SSEServerParameters_T]): ...

@overload
async def add_mcp_server(self, type: Literal["http"], **params: Unpack[StreamableHTTPParameters_T]): ...

async def add_mcp_server(self, type: ServerType, **params: Any):
"""Connect to an MCP server

Args:
command (str):
The command to run the MCP server.
args (List[str], optional):
Arguments for the command.
env (Dict[str, str], optional):
Environment variables for the command. Default is to inherit the parent environment.
cwd (Union[str, Path, None], optional):
Working directory for the command. Default to current directory.
type (`str`):
Type of the server to connect to. Can be one of:
- "stdio": Standard input/output server (local)
- "sse": Server-sent events (SSE) server
- "http": StreamableHTTP server
**params: Server parameters that can be either:
- For stdio servers:
- command (str): The command to run the MCP server
- args (List[str], optional): Arguments for the command
- env (Dict[str, str], optional): Environment variables for the command
- cwd (Union[str, Path, None], optional): Working directory for the command
- For SSE servers:
- url (str): The URL of the SSE server
- headers (Dict[str, Any], optional): Headers for the SSE connection
- timeout (float, optional): Connection timeout
- sse_read_timeout (float, optional): SSE read timeout
- For StreamableHTTP servers:
- url (str): The URL of the StreamableHTTP server
- headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection
- timeout (timedelta, optional): Connection timeout
- sse_read_timeout (timedelta, optional): SSE read timeout
- terminate_on_close (bool, optional): Whether to terminate on close
"""
from mcp import ClientSession, StdioServerParameters
from mcp import types as mcp_types
from mcp.client.stdio import stdio_client

logger.info(f"Connecting to MCP server with command: {command} {args}")
server_params = StdioServerParameters(
command=command,
args=args if args is not None else [],
env=env,
cwd=cwd,
)

read, write = await self.exit_stack.enter_async_context(stdio_client(server_params))
# Determine server type and create appropriate parameters
if type == "stdio":
# Handle stdio server
from mcp.client.stdio import stdio_client

logger.info(f"Connecting to stdio MCP server with command: {params['command']} {params.get('args', [])}")

client_kwargs = {"command": params["command"]}
for key in ["args", "env", "cwd"]:
if params.get(key) is not None:
client_kwargs[key] = params[key]
server_params = StdioServerParameters(**client_kwargs)
read, write = await self.exit_stack.enter_async_context(stdio_client(server_params))
elif type == "sse":
# Handle SSE server
from mcp.client.sse import sse_client

logger.info(f"Connecting to SSE MCP server at: {params['url']}")

client_kwargs = {"url": params["url"]}
for key in ["headers", "timeout", "sse_read_timeout"]:
if params.get(key) is not None:
client_kwargs[key] = params[key]
read, write = await self.exit_stack.enter_async_context(sse_client(**client_kwargs))
elif type == "http":
# Handle StreamableHTTP server
from mcp.client.streamable_http import streamablehttp_client

logger.info(f"Connecting to StreamableHTTP MCP server at: {params['url']}")

client_kwargs = {"url": params["url"]}
for key in ["headers", "timeout", "sse_read_timeout", "terminate_on_close"]:
if params.get(key) is not None:
client_kwargs[key] = params[key]
read, write, _ = await self.exit_stack.enter_async_context(streamablehttp_client(**client_kwargs))
# ^ TODO: should we handle `get_session_id_callback`? (function to retrieve the current session ID)
else:
raise ValueError(f"Unsupported server type: {type}")

session = await self.exit_stack.enter_async_context(
ClientSession(
read_stream=read,
Expand Down