diff --git a/examples/clients/mqtt-clients/client_apis_demo.py b/examples/clients/mqtt-clients/client_apis_demo.py new file mode 100644 index 0000000000..0a21eb5a6f --- /dev/null +++ b/examples/clients/mqtt-clients/client_apis_demo.py @@ -0,0 +1,51 @@ +import logging + +import anyio + +import mcp.client.mqtt as mcp_mqtt +from mcp.shared.mqtt import configure_logging + +configure_logging(level="DEBUG") +logger = logging.getLogger(__name__) + +async def on_mcp_server_discovered(client, server_name): + logger.info(f"Discovered {server_name}, connecting ...") + await client.initialize_mcp_server(server_name) + +async def on_mcp_connect(client, server_name, connect_result): + capabilities = client.get_session(server_name).server_info.capabilities + logger.info(f"Capabilities of {server_name}: {capabilities}") + if capabilities.prompts: + prompts = await client.list_prompts(server_name) + logger.info(f"Prompts of {server_name}: {prompts}") + if capabilities.resources: + resources = await client.list_resources(server_name) + logger.info(f"Resources of {server_name}: {resources}") + resource_templates = await client.list_resource_templates(server_name) + logger.info(f"Resources templates of {server_name}: {resource_templates}") + if capabilities.tools: + toolsResult = await client.list_tools(server_name) + tools = toolsResult.tools + logger.info(f"Tools of {server_name}: {tools}") + +async def on_mcp_disconnect(client, server_name): + logger.info(f"Disconnected from {server_name}") + +async def main(): + async with mcp_mqtt.MqttTransportClient( + "test_client", + auto_connect_to_mcp_server = True, + on_mcp_server_discovered = on_mcp_server_discovered, + on_mcp_connect = on_mcp_connect, + on_mcp_disconnect = on_mcp_disconnect, + mqtt_options = mcp_mqtt.MqttOptions( + host="broker.emqx.io", + ) + ) as client: + client.start() + while True: + logger.info("Other works while the MQTT transport client is running in the background...") + await anyio.sleep(10) + +if __name__ == "__main__": + anyio.run(main) diff --git a/examples/fastmcp/mqtt_simple_echo.py b/examples/fastmcp/mqtt_simple_echo.py new file mode 100644 index 0000000000..1f50a9d611 --- /dev/null +++ b/examples/fastmcp/mqtt_simple_echo.py @@ -0,0 +1,27 @@ +""" +FastMCP Echo Server +""" + +from mcp.server.fastmcp import FastMCP +from mcp.shared.mqtt import MqttOptions + +# Create server +mcp = FastMCP( + "demo_server/echo", + log_level="DEBUG", +) + +mcp.settings.mqtt_options = MqttOptions( + host="broker.emqx.io", + verify_connack_properties=True, # Change to False if broker is Mosquitto +) + + +@mcp.tool() +def echo(text: str) -> str: + """Echo the input text""" + return text + + +if __name__ == "__main__": + mcp.run(transport="mqtt") diff --git a/pyproject.toml b/pyproject.toml index c6119867ef..7457e135e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "uvicorn>=0.31.1; sys_platform != 'emscripten'", "jsonschema>=4.20.0", "pywin32>=310; sys_platform == 'win32'", + "paho-mqtt>=2.1.0", ] [project.optional-dependencies] @@ -59,6 +60,9 @@ dev = [ "pytest-pretty>=1.2.0", "inline-snapshot>=0.23.0", "dirty-equals>=0.9.0", + "typer>=0.17.4", + "python-dotenv>=1.1.1", + "websockets>=15.0.1", ] docs = [ "mkdocs>=1.6.1", @@ -134,7 +138,7 @@ max-returns = 13 # Default is 6 max-statements = 102 # Default is 50 [tool.uv.workspace] -members = ["examples/servers/*", "examples/snippets"] +members = ["examples/servers/*", "examples/snippets", "examples/clients/mqtt-clients/smart-home"] [tool.uv.sources] mcp = { workspace = true } diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py new file mode 100644 index 0000000000..c8401348a7 --- /dev/null +++ b/src/mcp/client/mqtt.py @@ -0,0 +1,549 @@ +""" +This module implements the MQTT transport for the MCP server. +""" + +import asyncio +import json +import logging +import random +import traceback +from collections.abc import Awaitable, Callable +from contextlib import AsyncExitStack +from datetime import timedelta +from typing import Any, Literal, TypeAlias +from uuid import uuid4 + +import anyio +import anyio.from_thread as anyio_from_thread +import anyio.to_thread as anyio_to_thread +import paho.mqtt.client as mqtt +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from paho.mqtt.properties import Properties +from paho.mqtt.reasoncodes import ReasonCode +from paho.mqtt.subscribeoptions import SubscribeOptions +from pydantic import AnyUrl, BaseModel + +import mcp.shared.mqtt_topic as mqtt_topic +import mcp.types as types +from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.shared.exceptions import McpError +from mcp.shared.message import SessionMessage +from mcp.shared.mqtt import MCP_SERVER_NAME_FILTERS, QOS, MqttOptions, MqttTransportBase + +RcvStream: TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage] +SndStream: TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage] +RcvStreamEx: TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] +SndStreamEX: TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage | Exception] +ServerRun: TypeAlias = Callable[[RcvStreamEx, SndStream], Awaitable[Any]] + +ServerName: TypeAlias = str +ServerId: TypeAlias = str +InitializeResult: TypeAlias = Literal["ok"] | Literal["already_connected"] | tuple[Literal["error"], str] +ConnectResult: TypeAlias = tuple[Literal["ok"], types.InitializeResult] | tuple[Literal["error"], Any] + +logger = logging.getLogger(__name__) + + +class ServerDefinition(BaseModel): + description: str + meta: dict[str, Any] = {} + + +class ServerOnlineNotification(BaseModel): + jsonrpc: Literal["2.0"] + method: str = "notifications/server/online" + params: ServerDefinition + + +class MqttClientSession(ClientSession): + def __init__( + self, + server_id: ServerId, + server_name: ServerName, + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], + read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + ) -> None: + super().__init__( + read_stream, + write_stream, + read_timeout_seconds, + sampling_callback, + elicitation_callback, + list_roots_callback, + logging_callback, + message_handler, + ) + self.server_id = server_id + self.server_name = server_name + self.server_info: types.InitializeResult | None = None + + +class MqttTransportClient(MqttTransportBase): + def __init__( + self, + mcp_client_name: str, + client_id: str | None = None, + server_name_filter: str | list[str] = "#", + auto_connect_to_mcp_server: bool = False, + on_mcp_connect: Callable[["MqttTransportClient", ServerName, ConnectResult], Awaitable[Any]] | None = None, + on_mcp_disconnect: Callable[["MqttTransportClient", ServerName], Awaitable[Any]] | None = None, + on_mcp_server_discovered: Callable[["MqttTransportClient", ServerName], Awaitable[Any]] | None = None, + mqtt_options: MqttOptions = MqttOptions(), + ): + uuid = uuid4().hex + mqtt_clientid = client_id if client_id else uuid + self._current_server_id: dict[ServerName, ServerId] = {} + self.server_list: dict[ServerName, dict[ServerId, ServerDefinition]] = {} + self.client_sessions: dict[ServerName, MqttClientSession] = {} + self.mcp_client_id = mqtt_clientid + self.mcp_client_name = mcp_client_name + if isinstance(server_name_filter, str): + self.server_name_filter = [server_name_filter] + else: + self.server_name_filter = server_name_filter + self.auto_connect_to_mcp_server = auto_connect_to_mcp_server + self.on_mcp_connect = on_mcp_connect + self.on_mcp_disconnect = on_mcp_disconnect + self.on_mcp_server_discovered = on_mcp_server_discovered + self.client_capability_change_topic = mqtt_topic.get_client_capability_change_topic(self.mcp_client_id) + ## Send disconnected notification when disconnects + self._disconnected_msg = types.JSONRPCMessage( + types.JSONRPCNotification(jsonrpc="2.0", method="notifications/disconnected") + ) + super().__init__( + "mcp-client", + mqtt_clientid=mqtt_clientid, + mqtt_options=mqtt_options, + disconnected_msg=self._disconnected_msg, + disconnected_msg_retain=False, + ) + + def get_presence_topic(self) -> str: + return mqtt_topic.get_client_presence_topic(self.mcp_client_id) + + async def start(self, timeout: timedelta | None = None) -> bool | str: + try: + connect_result = self.connect() + asyncio.create_task(anyio_to_thread.run_sync(self.client.loop_forever)) + if connect_result != mqtt.MQTT_ERR_SUCCESS: + logger.error(f"Failed to connect to MQTT broker, error code: {connect_result}") + return mqtt.error_string(connect_result) + # test if the client is connected and wait until it is connected + if timeout: + while not self.is_connected(): + await asyncio.sleep(0.1) + if timeout.total_seconds() <= 0: + last_fail_reason = self.get_last_connect_fail_reason() + if last_fail_reason: + return last_fail_reason.getName() + return "timeout" + timeout -= timedelta(seconds=0.1) + return True + except asyncio.CancelledError: + logger.debug("MQTT transport (MCP client) got cancelled") + return "cancelled" + except ConnectionRefusedError as exc: + logger.error(f"MQTT transport (MCP client) failed to connect: {exc}") + return "connection_refused" + except TimeoutError as exc: + logger.error(f"MQTT transport (MCP client) timed out: {exc}") + return "timeout" + except Exception as exc: + logger.error(f"MQTT transport (MCP client) failed: {exc}") + return f"connect mqtt error: {str(exc)}" + + def get_session(self, server_name: ServerName) -> MqttClientSession | None: + return self.client_sessions.get(server_name, None) + + async def initialize_mcp_server( + self, + server_name: str, + read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + ) -> InitializeResult: + if server_name in self.client_sessions: + return "already_connected" + if server_name not in self.server_list: + logger.error(f"MCP server not found, server name: {server_name}") + return ("error", "MCP server not found") + server_id = self.pick_server_id(server_name) + + async def after_subscribed(subscribe_result: Literal["success", "error"]): + if subscribe_result == "error": + if self.on_mcp_connect: + self._task_group.start_soon( + self.on_mcp_connect, self, server_name, ("error", "subscribe_mcp_server_topics_failed") + ) + client_session = self._create_session( + server_id, + server_name, + read_timeout_seconds, + sampling_callback, + None, # elicitation_callback + list_roots_callback, + logging_callback, + message_handler, + ) + self.client_sessions[server_name] = client_session + try: + logger.debug(f"before initialize: {server_name}") + + async def after_initialize(): + exit_stack = AsyncExitStack() + try: + session = await exit_stack.enter_async_context(client_session) + init_result = await session.initialize() + session.server_info = init_result + if self.on_mcp_connect: + self._task_group.start_soon(self.on_mcp_connect, self, server_name, ("ok", init_result)) + except Exception as e: + self.client_sessions.pop(server_name) + logging.error(f"Failed to initialize server {server_name}: {e}") + await exit_stack.aclose() + + self._task_group.start_soon(after_initialize) + logger.debug(f"after initialize: {server_name}") + except McpError as exc: + self.client_sessions.pop(server_name) + logger.error(f"Failed to connect to MCP server: {exc}") + if self.on_mcp_connect: + self._task_group.start_soon(self.on_mcp_connect, self, server_name, ("error", McpError)) + + if self._subscribe_mcp_server_topics(server_id, server_name, after_subscribed): + return "ok" + else: + return ("error", "send_subscribe_request_failed") + + async def deinitialize_mcp_server(self, server_name: ServerName) -> None: + server_id = self._current_server_id[server_name] + topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name) + self.publish_json_rpc_message(topic, message=self._disconnected_msg, retain=False) + self._remove_server(server_id, server_name) + + async def send_ping(self, server_name: ServerName) -> bool | types.EmptyResult: + return await self._with_session(server_name, lambda s: s.send_ping()) + + async def send_progress_notification( + self, server_name: ServerName, progress_token: str | int, progress: float, total: float | None = None + ) -> bool | None: + return await self._with_session( + server_name, lambda s: s.send_progress_notification(progress_token, progress, total) + ) + + async def set_logging_level(self, server_name: ServerName, level: types.LoggingLevel) -> bool | types.EmptyResult: + return await self._with_session(server_name, lambda s: s.set_logging_level(level)) + + async def list_resources(self, server_name: ServerName) -> bool | types.ListResourcesResult: + return await self._with_session(server_name, lambda s: s.list_resources()) + + async def list_resource_templates(self, server_name: ServerName) -> bool | types.ListResourceTemplatesResult: + return await self._with_session(server_name, lambda s: s.list_resource_templates()) + + async def read_resource(self, server_name: ServerName, uri: AnyUrl) -> bool | types.ReadResourceResult: + return await self._with_session(server_name, lambda s: s.read_resource(uri)) + + async def subscribe_resource(self, server_name: ServerName, uri: AnyUrl) -> bool | types.EmptyResult: + return await self._with_session(server_name, lambda s: s.subscribe_resource(uri)) + + async def unsubscribe_resource(self, server_name: ServerName, uri: AnyUrl) -> bool | types.EmptyResult: + return await self._with_session(server_name, lambda s: s.unsubscribe_resource(uri)) + + async def call_tool( + self, server_name: ServerName, name: str, arguments: dict[str, Any] | None = None + ) -> bool | types.CallToolResult: + return await self._with_session(server_name, lambda s: s.call_tool(name, arguments)) + + async def list_prompts(self, server_name: ServerName) -> bool | types.ListPromptsResult: + return await self._with_session(server_name, lambda s: s.list_prompts()) + + async def get_prompt( + self, server_name: ServerName, name: str, arguments: dict[str, str] | None = None + ) -> bool | types.GetPromptResult: + return await self._with_session(server_name, lambda s: s.get_prompt(name, arguments)) + + async def complete( + self, + server_name: ServerName, + ref: types.ResourceTemplateReference | types.PromptReference, + argument: dict[str, str], + ) -> bool | types.CompleteResult: + return await self._with_session(server_name, lambda s: s.complete(ref, argument)) + + async def list_tools(self, server_name: ServerName) -> bool | types.ListToolsResult: + return await self._with_session(server_name, lambda s: s.list_tools()) + + async def send_roots_list_changed(self, server_name: ServerName) -> bool | None: + return await self._with_session(server_name, lambda s: s.send_roots_list_changed()) + + async def _with_session( + self, server_name: ServerName, async_callback: Callable[[MqttClientSession], Awaitable[bool | Any]] + ) -> bool | Any: + if not (client_session := self.client_sessions.get(server_name, None)): + logger.error(f"No session for server_name: {server_name}") + return False + return await async_callback(client_session) + + def _create_session( + self, + server_id: ServerId, + server_name: ServerName, + read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + ): + ## Streams are used to communicate between the MqttTransportClient and the MCPSession: + ## 1. MQTT --> Client -->[raw_read]-- conversion -->[session_read]--> MCPSession + ## 2. MQTT <-- Client <--[raw_write]<-- conversion <--[session_write]<-- MCPSession + # Create raw streams for JSONRPCMessage + raw_read_stream_writer, raw_read_stream = anyio.create_memory_object_stream[types.JSONRPCMessage | Exception](0) + raw_write_stream, raw_write_stream_reader = anyio.create_memory_object_stream[types.JSONRPCMessage](0) + + # Create SessionMessage streams for the session + session_read_stream_writer, session_read_stream = anyio.create_memory_object_stream[SessionMessage | Exception]( + 0 + ) + session_write_stream, session_write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + + # Start conversion tasks + self._task_group.start_soon(self._convert_jsonrpc_to_session, raw_read_stream, session_read_stream_writer) + self._task_group.start_soon(self._convert_session_to_jsonrpc, session_write_stream_reader, raw_write_stream) + self._read_stream_writers[server_id] = raw_read_stream_writer + self._task_group.start_soon(self._receieved_from_session, server_id, server_name, raw_write_stream_reader) + logger.debug(f"Created new session for server_id: {server_id}") + return MqttClientSession( + server_id, + server_name, + session_read_stream, + session_write_stream, + read_timeout_seconds, + sampling_callback, + elicitation_callback, + list_roots_callback, + logging_callback, + message_handler, + ) + + def _on_connect( + self, + client: mqtt.Client, + userdata: Any, + connect_flags: mqtt.ConnectFlags, + reason_code: ReasonCode, + properties: Properties | None, + ): + super()._on_connect(client, userdata, connect_flags, reason_code, properties) + if properties and hasattr(properties, "UserProperty"): + user_properties: dict[str, Any] = dict(properties.UserProperty) # type: ignore + if MCP_SERVER_NAME_FILTERS in user_properties: + self.server_name_filter = json.loads(user_properties[MCP_SERVER_NAME_FILTERS]) + logger.debug(f"Use broker suggested server name filters: {self.server_name_filter}") + if reason_code == 0: + ## Subscribe to the MCP server's presence topic + for snf in self.server_name_filter: + logger.debug(f"Subscribing to server presence topic for server_name_filter: {snf}") + client.subscribe(mqtt_topic.get_server_presence_topic("+", snf), qos=QOS) + + def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): + logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") + match msg.topic: + case str() as t if t.startswith(mqtt_topic.SERVER_PRESENCE_BASE): + self._handle_server_presence_message(msg) + case str() as t if t.startswith(mqtt_topic.RPC_BASE): + self._handle_rpc_message(msg) + case str() as t if t.startswith(mqtt_topic.SERVER_CAPABILITY_CHANGE_BASE): + self._handle_server_capability_message(msg) + case _: + logger.error(f"Received message on unexpected topic: {msg.topic}") + + def _on_subscribe( + self, + client: mqtt.Client, + userdata: Any, + mid: int, + reason_code_list: list[ReasonCode], + properties: Properties | None, + ): + if mid in userdata.get("pending_subs", {}): + server_name, server_id, after_subscribed = userdata["pending_subs"].pop(mid) + ## only create session if all topic subscribed successfully + if all(rc.value == QOS for rc in reason_code_list): + logger.debug(f"Subscribed to topics for server_name: {server_name}, server_id: {server_id}") + anyio_from_thread.run(after_subscribed, "success") + else: + anyio_from_thread.run(after_subscribed, "error") + logger.error( + f"Failed to subscribe to topics for server_name: {server_name}, " + f"server_id: {server_id}, reason_codes: {reason_code_list}" + ) + + def _handle_server_presence_message(self, msg: mqtt.MQTTMessage) -> None: + topic_words = msg.topic.split("/") + server_id = topic_words[2] + server_name = "/".join(topic_words[3:]) + if msg.payload: + newly_added_server = False if server_name in self.server_list else True + server_notif = ServerOnlineNotification.model_validate_json(msg.payload.decode()) + self.server_list.setdefault(server_name, {})[server_id] = server_notif.params + logger.debug(f"Server {server_name} with id {server_id} is online") + if newly_added_server: + if self.auto_connect_to_mcp_server: + logger.debug(f"Auto connecting to MCP server {server_name}") + anyio_from_thread.run(self.initialize_mcp_server, server_name) + if self.on_mcp_server_discovered: + anyio_from_thread.run(self.on_mcp_server_discovered, self, server_name) + else: + # server is offline if the payload is empty + logger.debug(f"Server {server_name} with id {server_id} is offline") + self._remove_server(server_id, server_name) + + def _remove_server(self, server_id: ServerId, server_name: ServerName) -> None: + if server_id in self.server_list.get(server_name, {}): + if server_id in self._read_stream_writers: + logger.debug(f"Closing stream writer for server_id: {server_id}") + self._read_stream_writers[server_id].close() + + def _handle_rpc_message(self, msg: mqtt.MQTTMessage) -> None: + server_name = "/".join(msg.topic.split("/")[3:]) + anyio_from_thread.run(self._send_message_to_session, server_name, msg) + + def _handle_server_capability_message(self, msg: mqtt.MQTTMessage) -> None: + server_name = "/".join(msg.topic.split("/")[4:]) + anyio_from_thread.run(self._send_message_to_session, server_name, msg) + + def _subscribe_mcp_server_topics( + self, server_id: ServerId, server_name: ServerName, after_subscribed: Callable[[Any], Awaitable[None]] + ): + topic_filters = [ + (mqtt_topic.get_server_capability_change_topic(server_id, server_name), SubscribeOptions(qos=QOS)), + ( + mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name), + SubscribeOptions(qos=QOS, noLocal=True), + ), + ] + ret, mid = self.client.subscribe(topic=topic_filters) + if ret != mqtt.MQTT_ERR_SUCCESS: + logger.error(f"Failed to subscribe to topics for server_name: {server_name}") + return False + userdata = self.client.user_data_get() + pending_subs = userdata.get("pending_subs", {}) + pending_subs[mid] = (server_name, server_id, after_subscribed) + userdata["pending_subs"] = pending_subs + return True + + async def _send_message_to_session(self, server_name: ServerName, msg: mqtt.MQTTMessage): + if not (client_session := self.client_sessions.get(server_name, None)): + logger.error(f"_send_message_to_session: No session for server_name: {server_name}") + return + payload = msg.payload.decode() + server_id = client_session.server_id + if server_id not in self._read_stream_writers: + logger.error(f"No session for server_id: {server_id}") + return + read_stream_writer = self._read_stream_writers[server_id] + try: + message = types.JSONRPCMessage.model_validate_json(payload) + logger.debug(f"Sending msg to session for server_id: {server_id}, msg: {message}") + with anyio.fail_after(3): + await read_stream_writer.send(message) + except Exception as exc: + logger.error(f"Failed to send msg to session for server_id: {server_id}, exception: {exc}") + traceback.print_exc() + ## TODO: the session does not handle exceptions for now + # await read_stream_writer.send(exc) + + async def _receieved_from_session( + self, server_id: ServerId, server_name: ServerName, write_stream_reader: RcvStream + ): + async with write_stream_reader: + async for msg in write_stream_reader: + logger.debug(f"Got msg from session for server_id: {server_id}, msg: {msg}") + match msg.model_dump(): + case {"method": method} if method == "notifications/initialized": + logger.debug(f"Session initialized for server_id: {server_id}") + topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name) + case {"method": method} if method.endswith("/list_changed"): + topic = None + logger.warning("Resource updates should not be sent from the session. Ignoring.") + case {"method": method} if method == "initialize": + topic = mqtt_topic.get_server_control_topic(server_id, server_name) + case _: + topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name) + if topic: + self.publish_json_rpc_message(topic, message=msg) + # cleanup + if server_id in self._read_stream_writers: + logger.debug(f"Removing session for server_id: {server_id}") + stream = self._read_stream_writers.pop(server_id) + await stream.aclose() + + # unsubscribe from the topics + logger.debug(f"Unsubscribing from topics for server_id: {server_id}, server_name: {server_name}") + topic_filters = [ + mqtt_topic.get_server_capability_change_topic(server_id, server_name), + mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name), + ] + self.client.unsubscribe(topic=topic_filters) + + if server_id in self.server_list.get(server_name, {}): + _ = self.server_list[server_name].pop(server_id) + if not self.server_list[server_name]: + _ = self.server_list.pop(server_name) + if self.on_mcp_disconnect: + self._task_group.start_soon(self.on_mcp_disconnect, self, server_name) + + if server_name in self.client_sessions: + _ = self.client_sessions.pop(server_name) + + if server_name in self._current_server_id: + _ = self._current_server_id.pop(server_name) + logger.debug(f"Session stream closed for server_id: {server_id}") + + def pick_server_id(self, server_name: str) -> ServerId: + server_id = random.choice(list(self.server_list[server_name].keys())) + self._current_server_id[server_name] = server_id + return server_id + + async def _convert_jsonrpc_to_session( + self, + jsonrpc_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + session_writer: MemoryObjectSendStream[SessionMessage | Exception], + ) -> None: + """Convert JSONRPCMessage stream to SessionMessage stream.""" + async with jsonrpc_stream, session_writer: + async for message in jsonrpc_stream: + if isinstance(message, Exception): + await session_writer.send(message) + else: + session_message = SessionMessage(message=message) + await session_writer.send(session_message) + + async def _convert_session_to_jsonrpc( + self, + session_stream: MemoryObjectReceiveStream[SessionMessage], + jsonrpc_writer: MemoryObjectSendStream[types.JSONRPCMessage], + ) -> None: + """Convert SessionMessage stream to JSONRPCMessage stream.""" + async with session_stream, jsonrpc_writer: + async for session_message in session_stream: + await jsonrpc_writer.send(session_message.message) + + +def validate_server_name(name: str): + if "/" not in name: + raise ValueError(f"Invalid server name: {name}, must contain a '/'") + elif ("+" in name) or ("#" in name): + raise ValueError(f"Invalid server name: {name}, must not contain '+' or '#'") + elif name[0] == "/": + raise ValueError(f"Invalid server name: {name}, must not start with '/'") diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index d86fa85e32..8e0778ed2e 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -36,6 +36,7 @@ from mcp.server.lowlevel.server import LifespanResultT from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan +from mcp.server.mqtt import MqttOptions, start_mqtt, validate_server_name from mcp.server.session import ServerSession, ServerSessionT from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server @@ -85,6 +86,12 @@ class Settings(BaseSettings, Generic[LifespanResultT]): stateless_http: bool """Define if the server should create a new transport per request.""" + # MQTT settings + mqtt_server_description: str = "" + mqtt_server_meta: dict[str, Any] = {} + mqtt_client_id: str | None = None + mqtt_options: MqttOptions = MqttOptions() + # resource settings warn_on_duplicate_resources: bool @@ -232,16 +239,16 @@ def session_manager(self) -> StreamableHTTPSessionManager: def run( self, - transport: Literal["stdio", "sse", "streamable-http"] = "stdio", + transport: Literal["stdio", "sse", "streamable-http", "mqtt"] = "stdio", mount_path: str | None = None, ) -> None: """Run the FastMCP server. Note this is a synchronous function. Args: - transport: Transport protocol to use ("stdio", "sse", or "streamable-http") + transport: Transport protocol to use ("stdio", "sse", "streamable-http", or "mqtt") mount_path: Optional mount path for SSE transport """ - TRANSPORTS = Literal["stdio", "sse", "streamable-http"] + TRANSPORTS = Literal["stdio", "sse", "streamable-http", "mqtt"] if transport not in TRANSPORTS.__args__: # type: ignore raise ValueError(f"Unknown transport: {transport}") @@ -252,6 +259,9 @@ def run( anyio.run(lambda: self.run_sse_async(mount_path)) case "streamable-http": anyio.run(self.run_streamable_http_async) + case "mqtt": + validate_server_name(self._mcp_server.name) + anyio.run(self.run_mqtt_async) def _setup_handlers(self) -> None: """Set up core MCP protocol handlers.""" @@ -724,6 +734,25 @@ def _normalize_path(self, mount_path: str, endpoint: str) -> str: # Combine paths return mount_path + endpoint + async def run_mqtt_async(self) -> None: + """Run the server using MQTT transport.""" + + def server_session_run(read_stream: Any, write_stream: Any): + return self._mcp_server.run( + read_stream, + write_stream, + self._mcp_server.create_initialization_options(), + ) + + await start_mqtt( + server_session_run, + server_name=self._mcp_server.name, + server_description=self.settings.mqtt_server_description, + server_meta=self.settings.mqtt_server_meta, + client_id=self.settings.mqtt_client_id, + mqtt_options=self.settings.mqtt_options, + ) + def sse_app(self, mount_path: str | None = None) -> Starlette: """Return an instance of the SSE server app.""" from starlette.middleware import Middleware diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 3076e283e3..b424b00aff 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -592,6 +592,7 @@ async def run( lifespan_context, raise_exceptions, ) + logger.debug("Server closed") async def _handle_message( self, diff --git a/src/mcp/server/mqtt.py b/src/mcp/server/mqtt.py new file mode 100644 index 0000000000..20ec867e27 --- /dev/null +++ b/src/mcp/server/mqtt.py @@ -0,0 +1,371 @@ +""" +This module implements the MQTT transport for the MCP server. +""" + +import asyncio +import json +import logging +import traceback +from collections.abc import Awaitable, Callable +from typing import Any, TypeAlias +from uuid import uuid4 + +import anyio +import anyio.from_thread as anyio_from_thread +import anyio.to_thread as anyio_to_thread +import paho.mqtt.client as mqtt +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from paho.mqtt.properties import Properties +from paho.mqtt.reasoncodes import ReasonCode +from paho.mqtt.subscribeoptions import SubscribeOptions + +import mcp.shared.mqtt_topic as mqtt_topic +import mcp.types as types +from mcp.shared.message import SessionMessage +from mcp.shared.mqtt import MCP_SERVER_NAME, PROPERTY_K_MQTT_CLIENT_ID, QOS, MqttOptions, MqttTransportBase + +# Raw MQTT streams (JSONRPCMessage) +RcvStream: TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage] +SndStream: TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage] +RcvStreamEx: TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] +SndStreamEX: TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage | Exception] + +# Session streams (SessionMessage) +SessionRcvStream: TypeAlias = MemoryObjectReceiveStream[SessionMessage] +SessionSndStream: TypeAlias = MemoryObjectSendStream[SessionMessage] +SessionRcvStreamEx: TypeAlias = MemoryObjectReceiveStream[SessionMessage | Exception] +SessionSndStreamEx: TypeAlias = MemoryObjectSendStream[SessionMessage | Exception] + +ServerSessionRun: TypeAlias = Callable[[SessionRcvStreamEx, SessionSndStream], Awaitable[Any]] + +logger = logging.getLogger(__name__) + + +class MqttTransportServer(MqttTransportBase): + def __init__( + self, + server_session_run: ServerSessionRun, + server_name: str, + server_description: str, + server_meta: dict[str, Any], + client_id: str | None = None, + mqtt_options: MqttOptions = MqttOptions(), + ): + uuid = uuid4().hex + mqtt_clientid = client_id if client_id else uuid + self.server_id = mqtt_clientid + self.server_name = server_name + self.server_description = server_description + self.server_meta = server_meta + self.server_session_run = server_session_run + super().__init__( + "mcp-server", + mqtt_clientid=mqtt_clientid, + mqtt_options=mqtt_options, + disconnected_msg=None, + disconnected_msg_retain=True, + ) + + def get_presence_topic(self) -> str: + return mqtt_topic.get_server_presence_topic(self.server_id, self.server_name) + + def _on_connect( + self, + client: mqtt.Client, + userdata: Any, + connect_flags: mqtt.ConnectFlags, + reason_code: ReasonCode, + properties: Properties | None, + ): + super()._on_connect(client, userdata, connect_flags, reason_code, properties) + if reason_code == 0: + if properties and hasattr(properties, "UserProperty"): + user_properties: dict[str, Any] = dict(properties.UserProperty) # type: ignore + if MCP_SERVER_NAME in user_properties: + broker_suggested_server_name = user_properties[MCP_SERVER_NAME] + self.server_name = broker_suggested_server_name + logger.debug(f"Used broker suggested server name: {broker_suggested_server_name}") + else: + logger.error(f"No {PROPERTY_K_MQTT_CLIENT_ID} in UserProperties") + self.server_control_topic = mqtt_topic.get_server_control_topic(self.server_id, self.server_name) + ## Subscribe to the server control topic + client.subscribe(self.server_control_topic, QOS) + ## Reister the server on the presence topic + online_msg = types.JSONRPCMessage( + types.JSONRPCNotification( + jsonrpc="2.0", + method="notifications/server/online", + params={"description": self.server_description, "meta": self.server_meta}, + ) + ) + self.publish_json_rpc_message(self.get_presence_topic(), message=online_msg, retain=True) + + def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): + logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") + match msg.topic: + case str() as t if t == self.server_control_topic: + self.handle_server_contorl_message(msg) + case str() as t if t.startswith(mqtt_topic.CLIENT_CAPABILITY_CHANGE_BASE): + self.handle_client_capability_change_message(msg) + case str() as t if t.startswith(mqtt_topic.RPC_BASE): + self.handle_rpc_message(msg) + case str() as t if t.startswith(mqtt_topic.CLIENT_PRESENCE_BASE): + self.handle_client_presence_message(msg) + case _: + logger.error(f"Received message on unexpected topic: {msg.topic}") + + def _on_subscribe( + self, + client: mqtt.Client, + userdata: Any, + mid: int, + reason_code_list: list[ReasonCode], + properties: Properties | None, + ): + if mid in userdata.get("pending_subs", {}): + mcp_client_id, msg, rpc_msg_id = userdata["pending_subs"].pop(mid) + ## only create session if all topic subscribed successfully + if all(rc.value == QOS for rc in reason_code_list): + logger.debug(f"Subscribed to topics for mcp_client_id: {mcp_client_id}") + anyio_from_thread.run(self.create_session, mcp_client_id, msg) + else: + logger.error( + f"Failed to subscribe to topics for mcp_client_id: {mcp_client_id}, " + f"reason_codes: {reason_code_list}" + ) + err = types.JSONRPCError( + jsonrpc="2.0", + id=rpc_msg_id, + error=types.ErrorData(code=types.INTERNAL_ERROR, message="Failed to subscribe to client topics"), + ) + self.publish_json_rpc_message( + mqtt_topic.get_rpc_topic(mcp_client_id, self.server_id, self.server_name), + message=types.JSONRPCMessage(err), + ) + + def handle_server_contorl_message(self, msg: mqtt.MQTTMessage): + if msg.properties and hasattr(msg.properties, "UserProperty"): + user_properties: dict[str, Any] = dict(msg.properties.UserProperty) # type: ignore + if PROPERTY_K_MQTT_CLIENT_ID in user_properties: + mcp_client_id = user_properties[PROPERTY_K_MQTT_CLIENT_ID] + if mcp_client_id in self._read_stream_writers: + anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) + else: + self.maybe_subscribe_to_client(mcp_client_id, msg) + else: + logger.error(f"No {PROPERTY_K_MQTT_CLIENT_ID} in UserProperties") + else: + logger.error("No UserProperties in control message") + + def handle_client_capability_change_message(self, msg: mqtt.MQTTMessage) -> None: + mcp_client_id = msg.topic.split("/")[-1] + anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) + + def handle_rpc_message(self, msg: mqtt.MQTTMessage) -> None: + mcp_client_id = msg.topic.split("/")[1] + try: + json_msg = json.loads(msg.payload.decode()) + if "method" in json_msg: + if json_msg["method"] == "notifications/disconnected": + stream = self._read_stream_writers[mcp_client_id] + anyio_from_thread.run(stream.aclose) + logger.debug(f"Closed read_stream for mcp_client_id: {mcp_client_id}") + return + else: + anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) + else: + anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) + except json.JSONDecodeError: + logger.error(f"Invalid JSON in RPC message for mcp_client_id: {mcp_client_id}") + + def handle_client_presence_message(self, msg: mqtt.MQTTMessage) -> None: + mcp_client_id = msg.topic.split("/")[-1] + if mcp_client_id not in self._read_stream_writers: + logger.error(f"No session for mcp_client_id: {mcp_client_id}") + return + try: + json_msg = json.loads(msg.payload.decode()) + if "method" in json_msg: + if json_msg["method"] == "notifications/disconnected": + stream = self._read_stream_writers[mcp_client_id] + anyio_from_thread.run(stream.aclose) + logger.debug(f"Closed read_stream for mcp_client_id: {mcp_client_id}") + else: + logger.error(f"Unknown method in presence message for mcp_client_id: {mcp_client_id}") + else: + logger.error(f"No method in presence message for mcp_client_id: {mcp_client_id}") + except json.JSONDecodeError: + logger.error(f"Invalid JSON in presence message for mcp_client_id: {mcp_client_id}") + + async def create_session(self, mcp_client_id: str, msg: mqtt.MQTTMessage): + ## Streams are used to communicate between the MqttTransportServer and the MCPSession: + ## 1. (msg) --> MqttBroker --> MqttTransportServer --[raw_read_stream]--> conversion + ## --[session_read_stream]--> MCPSession + ## 2. MqttBroker <-- MqttTransportServer <--[raw_write_stream_reader]-- conversion + ## <--[session_write_stream]-- MCPSession <-- (msg) + + # Create raw MQTT streams (JSONRPCMessage) + raw_read_stream: RcvStreamEx + raw_read_stream_writer: SndStreamEX + raw_write_stream: SndStream + raw_write_stream_reader: RcvStream + raw_read_stream_writer, raw_read_stream = anyio.create_memory_object_stream(0) # type: ignore + raw_write_stream, raw_write_stream_reader = anyio.create_memory_object_stream(0) # type: ignore + + # Create session streams (SessionMessage) + session_read_stream_writer: SessionSndStreamEx + session_read_stream: SessionRcvStreamEx + session_write_stream: SessionSndStream + session_write_stream_reader: SessionRcvStream + session_read_stream_writer, session_read_stream = anyio.create_memory_object_stream(0) # type: ignore + session_write_stream, session_write_stream_reader = anyio.create_memory_object_stream(0) # type: ignore + + self._read_stream_writers[mcp_client_id] = raw_read_stream_writer + + # Start conversion tasks + self._task_group.start_soon(self._convert_jsonrpc_to_session, raw_read_stream, session_read_stream_writer) + self._task_group.start_soon(self._convert_session_to_jsonrpc, session_write_stream_reader, raw_write_stream) + + # Start session with SessionMessage streams + self._task_group.start_soon(self.server_session_run, session_read_stream, session_write_stream) + self._task_group.start_soon(self._receieved_from_session, mcp_client_id, raw_write_stream_reader) + logger.debug(f"Created new session for mcp_client_id: {mcp_client_id}") + await self._send_message_to_session(mcp_client_id, msg) + + def maybe_subscribe_to_client(self, mcp_client_id: str, msg: mqtt.MQTTMessage): + try: + json_msg = json.loads(msg.payload.decode()) + if "id" in json_msg: + rpc_msg_id = json_msg["id"] + self.subscribe_to_client(mcp_client_id, msg, rpc_msg_id) + else: + logger.error(f"No id in control message for mcp_client_id: {mcp_client_id}") + except json.JSONDecodeError: + logger.error(f"Invalid JSON in control message for mcp_client_id: {mcp_client_id}") + return + + def subscribe_to_client(self, mcp_client_id: str, msg: mqtt.MQTTMessage, rcp_msg_id: Any): + topic_filters = [ + (mqtt_topic.get_client_presence_topic(mcp_client_id), SubscribeOptions(qos=QOS)), + (mqtt_topic.get_client_capability_change_topic(mcp_client_id), SubscribeOptions(qos=QOS)), + ( + mqtt_topic.get_rpc_topic(mcp_client_id, self.server_id, self.server_name), + SubscribeOptions(qos=QOS, noLocal=True), + ), + ] + ret, mid = self.client.subscribe(topic=topic_filters) + if ret != mqtt.MQTT_ERR_SUCCESS: + logger.error(f"Failed to subscribe to topics for mcp_client_id: {mcp_client_id}") + return + userdata = self.client.user_data_get() + pending_subs = userdata.get("pending_subs", {}) + pending_subs[mid] = (mcp_client_id, msg, rcp_msg_id) + userdata["pending_subs"] = pending_subs + + async def _send_message_to_session(self, mcp_client_id: str, msg: mqtt.MQTTMessage): + payload = msg.payload.decode() + if mcp_client_id not in self._read_stream_writers: + logger.error(f"No session for mcp_client_id: {mcp_client_id}") + return + read_stream_writer = self._read_stream_writers[mcp_client_id] + try: + message = types.JSONRPCMessage.model_validate_json(payload) + logger.debug(f"Sending msg to session for mcp_client_id: {mcp_client_id}, msg: {message}") + with anyio.fail_after(3): + await read_stream_writer.send(message) + except Exception as exc: + logger.error(f"Failed to send msg to session for mcp_client_id: {mcp_client_id}, exception: {exc}") + traceback.print_exc() + ## TODO: the session does not handle exceptions for now + # await read_stream_writer.send(exc) + + async def _receieved_from_session(self, mcp_client_id: str, write_stream_reader: RcvStream): + async with write_stream_reader: + async for msg in write_stream_reader: + logger.debug(f"Got msg from session for mcp_client_id: {mcp_client_id}, msg: {msg}") + match msg.model_dump(): + case {"method": "notifications/resources/updated"}: + logger.warning("Resource updates should not be sent from the session. Ignoring.") + case {"method": method} if method.endswith("/list_changed"): + logger.warning("Resource updates should not be sent from the session. Ignoring.") + case _: + topic = mqtt_topic.get_rpc_topic(mcp_client_id, self.server_id, self.server_name) + self.publish_json_rpc_message(topic, message=msg) + # cleanup + if mcp_client_id in self._read_stream_writers: + logger.debug(f"Removing session for mcp_client_id: {mcp_client_id}") + stream = self._read_stream_writers.pop(mcp_client_id) + await stream.aclose() + + # unsubscribe from the client topics + logger.debug(f"Unsubscribing from topics for mcp_client_id: {mcp_client_id}") + topic_filters = [ + mqtt_topic.get_client_presence_topic(mcp_client_id), + mqtt_topic.get_client_capability_change_topic(mcp_client_id), + mqtt_topic.get_rpc_topic(mcp_client_id, self.server_id, self.server_name), + ] + self.client.unsubscribe(topic=topic_filters) + + logger.debug(f"Session stream closed for mcp_client_id: {mcp_client_id}") + + async def _convert_jsonrpc_to_session( + self, + jsonrpc_stream: RcvStreamEx, + session_writer: SessionSndStreamEx, + ) -> None: + """Convert JSONRPCMessage stream to SessionMessage stream.""" + async with jsonrpc_stream, session_writer: + async for message in jsonrpc_stream: + if isinstance(message, Exception): + await session_writer.send(message) + else: + session_message = SessionMessage(message=message) + await session_writer.send(session_message) + + async def _convert_session_to_jsonrpc( + self, + session_stream: SessionRcvStream, + jsonrpc_writer: SndStream, + ) -> None: + """Convert SessionMessage stream to JSONRPCMessage stream.""" + async with session_stream, jsonrpc_writer: + async for session_message in session_stream: + await jsonrpc_writer.send(session_message.message) + + +async def start_mqtt( + server_session_run: ServerSessionRun, + server_name: str, + server_description: str, + server_meta: dict[str, Any], + client_id: str | None = None, + mqtt_options: MqttOptions = MqttOptions(), +): + async with MqttTransportServer( + server_session_run, + server_name=server_name, + server_description=server_description, + server_meta=server_meta, + client_id=client_id, + mqtt_options=mqtt_options, + ) as mqtt_trans: + + def start(): + mqtt_trans.connect() + mqtt_trans.client.loop_forever() + + try: + await anyio_to_thread.run_sync(start) + except asyncio.CancelledError: + logger.debug("MQTT transport (MCP server) got cancelled") + except Exception as exc: + logger.error(f"MQTT transport (MCP server) failed with exception: {exc}") + + +def validate_server_name(name: str): + if "/" not in name: + raise ValueError(f"Invalid server name: {name}, must contain a '/'") + elif ("+" in name) or ("#" in name): + raise ValueError(f"Invalid server name: {name}, must not contain '+' or '#'") + elif name[0] == "/": + raise ValueError(f"Invalid server name: {name}, must not start with '/'") diff --git a/src/mcp/shared/mqtt.py b/src/mcp/shared/mqtt.py new file mode 100644 index 0000000000..f674f5eeb6 --- /dev/null +++ b/src/mcp/shared/mqtt.py @@ -0,0 +1,256 @@ +""" +MQTT Transport Base Module + +""" + +import logging +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable +from types import TracebackType +from typing import Any, Literal, TypeAlias + +import anyio +import anyio.from_thread as anyio_from_thread +import paho.mqtt.client as mqtt +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from paho.mqtt.enums import CallbackAPIVersion +from paho.mqtt.packettypes import PacketTypes +from paho.mqtt.properties import Properties +from paho.mqtt.reasoncodes import ReasonCode +from pydantic import BaseModel, SecretStr +from typing_extensions import Self + +import mcp.types as types + +DEFAULT_LOG_FORMAT = "%(asctime)s - %(message)s" +QOS = 0 +MCP_SERVER_NAME = "MCP-SERVER-NAME" +MCP_SERVER_NAME_FILTERS = "MCP-SERVER-NAME-FILTERS" +MCP_AUTH_ROLE = "MCP-AUTH-ROLE" +PROPERTY_K_MCP_COMPONENT = "MCP-COMPONENT-TYPE" +PROPERTY_K_MQTT_CLIENT_ID = "MCP-MQTT-CLIENT-ID" +logger = logging.getLogger(__name__) + +RcvStream: TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage] +SndStream: TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage] +RcvStreamEx: TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] +SndStreamEX: TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage | Exception] +ServerRun: TypeAlias = Callable[[RcvStreamEx, SndStream], Awaitable[Any]] + + +class MqttOptions(BaseModel): + host: str = "localhost" + port: int = 1883 + transport: Literal["tcp", "websockets", "unix"] = "tcp" + keepalive: int = 60 + bind_address: str = "" + bind_port: int = 0 + username: str | None = None + password: SecretStr | None = None + tls_enabled: bool = False + tls_version: int | None = None + tls_insecure: bool = False + ca_certs: str | None = None + certfile: str | None = None + keyfile: str | None = None + ciphers: str | None = None + keyfile_password: str | None = None + alpn_protocols: list[str] | None = None + websocket_path: str = "/mqtt" + websocket_headers: dict[str, str] | None = None + verify_connack_properties: bool = True + + +class MqttTransportBase(ABC): + _read_stream_writers: dict[str, SndStreamEX] + + def __init__( + self, + mcp_component_type: Literal["mcp-client", "mcp-server"], + mqtt_clientid: str | None = None, + mqtt_options: MqttOptions = MqttOptions(), + disconnected_msg: types.JSONRPCMessage | None = None, + disconnected_msg_retain: bool = True, + ): + self._read_stream_writers = {} + self._last_connect_fail_reason = None + self.mqtt_clientid = mqtt_clientid + self.mcp_component_type = mcp_component_type + self.mqtt_options = mqtt_options + self.disconnected_msg = disconnected_msg + self.disconnected_msg_retain = disconnected_msg_retain + client = mqtt.Client( + callback_api_version=CallbackAPIVersion.VERSION2, + client_id=mqtt_clientid, + protocol=mqtt.MQTTv5, + userdata={}, + transport=mqtt_options.transport, + reconnect_on_failure=True, + ) + client.reconnect_delay_set(min_delay=1, max_delay=120) + client.username_pw_set( + mqtt_options.username, mqtt_options.password.get_secret_value() if mqtt_options.password else None + ) + if mqtt_options.tls_enabled: + client.tls_set( # type: ignore + ca_certs=mqtt_options.ca_certs, + certfile=mqtt_options.certfile, + keyfile=mqtt_options.keyfile, + tls_version=mqtt_options.tls_version, + ciphers=mqtt_options.ciphers, + keyfile_password=mqtt_options.keyfile_password, + alpn_protocols=mqtt_options.alpn_protocols, + ) + client.tls_insecure_set(mqtt_options.tls_insecure) + if mqtt_options.transport == "websockets": + client.ws_set_options(path=mqtt_options.websocket_path, headers=mqtt_options.websocket_headers) + client.on_connect = self._on_connect + client.on_message = self._on_message + client.on_subscribe = self._on_subscribe + ## We need to set an empty will message to clean the retained presence + ## message when the MCP server goes offline. + ## Note that if the broker suggested a new server name, it's the broker's + ## responsibility to clean the retained presence message and send the + ## last will message on the changed presence topic. + client.will_set( + topic=self.get_presence_topic(), + payload=disconnected_msg.model_dump_json() if disconnected_msg else None, + qos=QOS, + retain=disconnected_msg_retain, + properties=self.get_publish_properties(), + ) + logger.info( + f"MCP component type: {mcp_component_type}, MQTT clientid: {mqtt_clientid}, MQTT settings: {mqtt_options}" + ) + self.client = client + + async def __aenter__(self) -> Self: + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self.stop_mqtt() + self._task_group.cancel_scope.cancel() + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + def _on_connect( + self, + client: mqtt.Client, + userdata: Any, + connect_flags: mqtt.ConnectFlags, + reason_code: ReasonCode, + properties: Properties | None, + ): + if reason_code == 0: + logger.debug(f"Connected to MQTT broker_host at {self.mqtt_options.host}:{self.mqtt_options.port}") + if self.mqtt_options.verify_connack_properties: + self.assert_property(properties, "RetainAvailable", 1) + self.assert_property(properties, "WildcardSubscriptionAvailable", 1) + else: + self._last_connect_fail_reason = reason_code + logger.error(f"Failed to connect, return code {reason_code}") + + def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): + pass + + def _on_subscribe( + self, + client: mqtt.Client, + userdata: Any, + mid: int, + reason_code_list: list[ReasonCode], + properties: Properties | None, + ): + pass + + def is_connected(self) -> bool: + return self.client.is_connected() + + def get_last_connect_fail_reason(self) -> ReasonCode | None: + return self._last_connect_fail_reason + + def publish_json_rpc_message(self, topic: str, message: types.JSONRPCMessage | None, retain: bool = False): + props = self.get_publish_properties() + payload = message.model_dump_json(by_alias=True, exclude_none=True) if message else None + result = self.client.publish(topic=topic, payload=payload, qos=QOS, retain=retain, properties=props) + return result + + def get_publish_properties(self): + props = Properties(PacketTypes.PUBLISH) + props.UserProperty = [ + (PROPERTY_K_MCP_COMPONENT, self.mcp_component_type), + (PROPERTY_K_MQTT_CLIENT_ID, self.mqtt_clientid), + ] + return props + + def connect(self): + logger.debug("Setting up MQTT connection") + props = Properties(PacketTypes.CONNECT) + props.UserProperty = [(PROPERTY_K_MCP_COMPONENT, self.mcp_component_type)] + return self.client.connect( + host=self.mqtt_options.host, + port=self.mqtt_options.port, + keepalive=self.mqtt_options.keepalive, + bind_address=self.mqtt_options.bind_address, + bind_port=self.mqtt_options.bind_port, + clean_start=True, + properties=props, + ) + + def assert_property(self, properties: Properties | None, property_name: str, expected_value: Any): + if get_property(properties, property_name) == expected_value: + pass + else: + anyio_from_thread.run(self.stop_mqtt) + raise ValueError(f"{property_name} not available") + + @abstractmethod + def get_presence_topic(self) -> str: + pass + + async def stop_mqtt(self): + self.publish_json_rpc_message( + self.get_presence_topic(), message=self.disconnected_msg, retain=self.disconnected_msg_retain + ) + self.client.disconnect() + self.client.loop_stop() + for stream in self._read_stream_writers.values(): + await stream.aclose() + self._read_stream_writers = {} + logger.debug("Disconnected from MQTT broker_host") + + +def get_property(properties: Properties | None, property_name: str): + if properties and hasattr(properties, property_name): + return getattr(properties, property_name) + else: + return False + + +def configure_logging( + level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", + format: str = DEFAULT_LOG_FORMAT, +) -> None: + handlers: list[logging.Handler] = [] + try: + from rich.console import Console + from rich.logging import RichHandler + + handlers.append(RichHandler(console=Console(stderr=True), rich_tracebacks=True)) + except ImportError: + pass + + if not handlers: + handlers.append(logging.StreamHandler()) + + logging.basicConfig( + level=level, + format=format, + handlers=handlers, + ) diff --git a/src/mcp/shared/mqtt_topic.py b/src/mcp/shared/mqtt_topic.py new file mode 100644 index 0000000000..8b2d9c5719 --- /dev/null +++ b/src/mcp/shared/mqtt_topic.py @@ -0,0 +1,30 @@ +SERVER_CONTROL_BASE: str = "$mcp-server" +SERVER_CAPABILITY_CHANGE_BASE: str = "$mcp-server/capability" +SERVER_PRESENCE_BASE: str = "$mcp-server/presence" +CLIENT_PRESENCE_BASE: str = "$mcp-client/presence" +CLIENT_CAPABILITY_CHANGE_BASE: str = "$mcp-client/capability" +RPC_BASE: str = "$mcp-rpc" + + +def get_server_control_topic(server_id: str, server_name: str) -> str: + return f"{SERVER_CONTROL_BASE}/{server_id}/{server_name}" + + +def get_server_capability_change_topic(server_id: str, server_name: str) -> str: + return f"{SERVER_CAPABILITY_CHANGE_BASE}/{server_id}/{server_name}" + + +def get_server_presence_topic(server_id: str, server_name: str) -> str: + return f"{SERVER_PRESENCE_BASE}/{server_id}/{server_name}" + + +def get_client_presence_topic(mcp_clientid: str) -> str: + return f"{CLIENT_PRESENCE_BASE}/{mcp_clientid}" + + +def get_client_capability_change_topic(mcp_clientid: str) -> str: + return f"{CLIENT_CAPABILITY_CHANGE_BASE}/{mcp_clientid}" + + +def get_rpc_topic(mcp_clientid: str, server_id: str, server_name: str) -> str: + return f"{RPC_BASE}/{mcp_clientid}/{server_id}/{server_name}" diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b2f49fc8bc..fd79d7b86a 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,6 +1,6 @@ import logging from collections.abc import Callable -from contextlib import AsyncExitStack +from contextlib import AsyncExitStack, asynccontextmanager from datetime import timedelta from types import TracebackType from typing import Any, Generic, Protocol, TypeVar @@ -196,6 +196,8 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} + self._receive_loop_alive = None + self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: @@ -215,7 +217,9 @@ async def __aexit__( # would be very surprising behavior), so make sure to cancel the tasks # in the task group. self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + if self._receive_loop_alive: + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + return False async def send_request( self, @@ -269,7 +273,8 @@ async def send_request( try: with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() + async with response_stream_reader: + response_or_error = await response_stream_reader.receive() except TimeoutError: raise McpError( ErrorData( @@ -329,10 +334,15 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er await self._write_stream.send(session_message) async def _receive_loop(self) -> None: - async with ( - self._read_stream, - self._write_stream, - ): + @asynccontextmanager + async def receive_loop_status(): + try: + self._receive_loop_alive = True + yield + finally: + self._receive_loop_alive = False + + async with self._read_stream, self._write_stream, self._exit_stack, receive_loop_status(): try: async for message in self._read_stream: if isinstance(message, Exception): diff --git a/uv.lock b/uv.lock index 7979f9aab4..fad5c16145 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" [manifest] @@ -608,6 +608,7 @@ dependencies = [ { name = "httpx" }, { name = "httpx-sse" }, { name = "jsonschema" }, + { name = "paho-mqtt" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "python-multipart" }, @@ -639,8 +640,11 @@ dev = [ { name = "pytest-flakefinder" }, { name = "pytest-pretty" }, { name = "pytest-xdist" }, + { name = "python-dotenv" }, { name = "ruff" }, { name = "trio" }, + { name = "typer" }, + { name = "websockets" }, ] docs = [ { name = "mkdocs" }, @@ -655,6 +659,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.27.1" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" }, + { name = "paho-mqtt", specifier = ">=2.1.0" }, { name = "pydantic", specifier = ">=2.11.0,<3.0.0" }, { name = "pydantic-settings", specifier = ">=2.5.2" }, { name = "python-dotenv", marker = "extra == 'cli'", specifier = ">=1.0.0" }, @@ -679,8 +684,11 @@ dev = [ { name = "pytest-flakefinder", specifier = ">=1.1.0" }, { name = "pytest-pretty", specifier = ">=1.2.0" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, + { name = "python-dotenv", specifier = ">=1.1.1" }, { name = "ruff", specifier = ">=0.8.5" }, { name = "trio", specifier = ">=0.26.2" }, + { name = "typer", specifier = ">=0.17.4" }, + { name = "websockets", specifier = ">=15.0.1" }, ] docs = [ { name = "mkdocs", specifier = ">=1.6.1" }, @@ -1114,6 +1122,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/90/96/04b8e52da071d28f5e21a805b19cb9390aa17a47462ac87f5e2696b9566d/paginate-0.5.7-py2.py3-none-any.whl", hash = "sha256:b885e2af73abcf01d9559fd5216b57ef722f8c42affbb63942377668e35c7591", size = 13746, upload-time = "2024-08-25T14:17:22.55Z" }, ] +[[package]] +name = "paho-mqtt" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/39/15/0a6214e76d4d32e7f663b109cf71fb22561c2be0f701d67f93950cd40542/paho_mqtt-2.1.0.tar.gz", hash = "sha256:12d6e7511d4137555a3f6ea167ae846af2c7357b10bc6fa4f7c3968fc1723834", size = 148848, upload-time = "2024-04-29T19:52:55.591Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c4/cb/00451c3cf31790287768bb12c6bec834f5d292eaf3022afc88e14b8afc94/paho_mqtt-2.1.0-py3-none-any.whl", hash = "sha256:6db9ba9b34ed5bc6b6e3812718c7e06e2fd7444540df2455d2c51bd58808feee", size = 67219, upload-time = "2024-04-29T19:52:48.345Z" }, +] + [[package]] name = "pathspec" version = "0.12.1"