diff --git a/src/huggingface_hub/inference/_mcp/mcp_client.py b/src/huggingface_hub/inference/_mcp/mcp_client.py index 834da0d39a..0374c285bb 100644 --- a/src/huggingface_hub/inference/_mcp/mcp_client.py +++ b/src/huggingface_hub/inference/_mcp/mcp_client.py @@ -1,10 +1,11 @@ import json import logging from contextlib import AsyncExitStack +from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, AsyncIterable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Union, overload -from typing_extensions import TypeAlias +from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack from ...utils._runtime import get_hf_hub_version from .._generated._async_client import AsyncInferenceClient @@ -26,6 +27,30 @@ # Type alias for tool names ToolName: TypeAlias = str +ServerType: TypeAlias = Literal["stdio", "sse", "http"] + + +class StdioServerParameters_T(TypedDict): + command: str + args: NotRequired[List[str]] + env: NotRequired[Dict[str, str]] + cwd: NotRequired[Union[str, Path, None]] + + +class SSEServerParameters_T(TypedDict): + url: str + headers: NotRequired[Dict[str, Any]] + timeout: NotRequired[float] + sse_read_timeout: NotRequired[float] + + +class StreamableHTTPParameters_T(TypedDict): + url: str + headers: NotRequired[dict[str, Any]] + timeout: NotRequired[timedelta] + sse_read_timeout: NotRequired[timedelta] + terminate_on_close: NotRequired[bool] + class MCPClient: """ @@ -64,39 +89,84 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self.client.__aexit__(exc_type, exc_val, exc_tb) await self.cleanup() - async def add_mcp_server( - self, - *, - command: str, - args: Optional[List[str]] = None, - env: Optional[Dict[str, str]] = None, - cwd: Union[str, Path, None] = None, - ): + @overload + async def add_mcp_server(self, type: Literal["stdio"], **params: Unpack[StdioServerParameters_T]): ... + + @overload + async def add_mcp_server(self, type: Literal["sse"], **params: Unpack[SSEServerParameters_T]): ... + + @overload + async def add_mcp_server(self, type: Literal["http"], **params: Unpack[StreamableHTTPParameters_T]): ... + + async def add_mcp_server(self, type: ServerType, **params: Any): """Connect to an MCP server Args: - command (str): - The command to run the MCP server. - args (List[str], optional): - Arguments for the command. - env (Dict[str, str], optional): - Environment variables for the command. Default is to inherit the parent environment. - cwd (Union[str, Path, None], optional): - Working directory for the command. Default to current directory. + type (`str`): + Type of the server to connect to. Can be one of: + - "stdio": Standard input/output server (local) + - "sse": Server-sent events (SSE) server + - "http": StreamableHTTP server + **params: Server parameters that can be either: + - For stdio servers: + - command (str): The command to run the MCP server + - args (List[str], optional): Arguments for the command + - env (Dict[str, str], optional): Environment variables for the command + - cwd (Union[str, Path, None], optional): Working directory for the command + - For SSE servers: + - url (str): The URL of the SSE server + - headers (Dict[str, Any], optional): Headers for the SSE connection + - timeout (float, optional): Connection timeout + - sse_read_timeout (float, optional): SSE read timeout + - For StreamableHTTP servers: + - url (str): The URL of the StreamableHTTP server + - headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection + - timeout (timedelta, optional): Connection timeout + - sse_read_timeout (timedelta, optional): SSE read timeout + - terminate_on_close (bool, optional): Whether to terminate on close """ from mcp import ClientSession, StdioServerParameters from mcp import types as mcp_types - from mcp.client.stdio import stdio_client - - logger.info(f"Connecting to MCP server with command: {command} {args}") - server_params = StdioServerParameters( - command=command, - args=args if args is not None else [], - env=env, - cwd=cwd, - ) - read, write = await self.exit_stack.enter_async_context(stdio_client(server_params)) + # Determine server type and create appropriate parameters + if type == "stdio": + # Handle stdio server + from mcp.client.stdio import stdio_client + + logger.info(f"Connecting to stdio MCP server with command: {params['command']} {params.get('args', [])}") + + client_kwargs = {"command": params["command"]} + for key in ["args", "env", "cwd"]: + if params.get(key) is not None: + client_kwargs[key] = params[key] + server_params = StdioServerParameters(**client_kwargs) + read, write = await self.exit_stack.enter_async_context(stdio_client(server_params)) + elif type == "sse": + # Handle SSE server + from mcp.client.sse import sse_client + + logger.info(f"Connecting to SSE MCP server at: {params['url']}") + + client_kwargs = {"url": params["url"]} + for key in ["headers", "timeout", "sse_read_timeout"]: + if params.get(key) is not None: + client_kwargs[key] = params[key] + read, write = await self.exit_stack.enter_async_context(sse_client(**client_kwargs)) + elif type == "http": + # Handle StreamableHTTP server + from mcp.client.streamable_http import streamablehttp_client + + logger.info(f"Connecting to StreamableHTTP MCP server at: {params['url']}") + + client_kwargs = {"url": params["url"]} + for key in ["headers", "timeout", "sse_read_timeout", "terminate_on_close"]: + if params.get(key) is not None: + client_kwargs[key] = params[key] + read, write, _ = await self.exit_stack.enter_async_context(streamablehttp_client(**client_kwargs)) + # ^ TODO: should be handle `get_session_id_callback`? (function to retrieve the current session ID) + else: + raise ValueError(f"Unsupported server type: {type}") + session = await self.exit_stack.enter_async_context( ClientSession( read_stream=read,