diff --git a/Server/src/transport/unity_instance_middleware.py b/Server/src/transport/unity_instance_middleware.py index 1df1984e1..0e8220ff3 100644 --- a/Server/src/transport/unity_instance_middleware.py +++ b/Server/src/transport/unity_instance_middleware.py @@ -83,11 +83,93 @@ def clear_active_instance(self, ctx) -> None: with self._lock: self._active_by_key.pop(key, None) + async def _maybe_autoselect_instance(self, ctx) -> str | None: + """Auto-select sole Unity instance when no active instance is set.""" + try: + from transport.unity_transport import _current_transport + + transport = _current_transport() + if PluginHub.is_configured(): + try: + sessions_data = await PluginHub.get_sessions() + sessions = sessions_data.sessions or {} + ids: list[str] = [] + for session_info in sessions.values(): + project = getattr(session_info, "project", None) or "Unknown" + hash_value = getattr(session_info, "hash", None) + if hash_value: + ids.append(f"{project}@{hash_value}") + if len(ids) == 1: + chosen = ids[0] + self.set_active_instance(ctx, chosen) + logger.info( + "Auto-selected sole Unity instance via PluginHub: %s", + chosen, + ) + return chosen + except (ConnectionError, ValueError, KeyError, TimeoutError, AttributeError) as exc: + logger.debug( + "PluginHub auto-select probe failed (%s); falling back to stdio", + type(exc).__name__, + exc_info=True, + ) + except Exception as exc: + if isinstance(exc, (SystemExit, KeyboardInterrupt)): + raise + logger.debug( + "PluginHub auto-select probe failed with unexpected error (%s); falling back to stdio", + type(exc).__name__, + exc_info=True, + ) + + if transport != "http": + try: + from transport.legacy.unity_connection import get_unity_connection_pool + + pool = get_unity_connection_pool() + instances = pool.discover_all_instances(force_refresh=True) + ids = [getattr(inst, "id", None) for inst in instances] + ids = [inst_id for inst_id in ids if inst_id] + if len(ids) == 1: + chosen = ids[0] + self.set_active_instance(ctx, chosen) + logger.info( + "Auto-selected sole Unity instance via stdio discovery: %s", + chosen, + ) + return chosen + except (ConnectionError, ValueError, KeyError, TimeoutError, AttributeError) as exc: + logger.debug( + "Stdio auto-select probe failed (%s)", + type(exc).__name__, + exc_info=True, + ) + except Exception as exc: + if isinstance(exc, (SystemExit, KeyboardInterrupt)): + raise + logger.debug( + "Stdio auto-select probe failed with unexpected error (%s)", + type(exc).__name__, + exc_info=True, + ) + except Exception as exc: + if isinstance(exc, (SystemExit, KeyboardInterrupt)): + raise + logger.debug( + "Auto-select path encountered an unexpected error (%s)", + type(exc).__name__, + exc_info=True, + ) + + return None + async def _inject_unity_instance(self, context: MiddlewareContext) -> None: """Inject active Unity instance into context if available.""" ctx = context.fastmcp_context active_instance = self.get_active_instance(ctx) + if not active_instance: + active_instance = await self._maybe_autoselect_instance(ctx) if active_instance: # If using HTTP transport (PluginHub configured), validate session # But for stdio transport (no PluginHub needed or maybe partially configured), diff --git a/Server/tests/integration/conftest.py b/Server/tests/integration/conftest.py index 798c6a605..1f409bbcb 100644 --- a/Server/tests/integration/conftest.py +++ b/Server/tests/integration/conftest.py @@ -6,6 +6,9 @@ SERVER_ROOT = Path(__file__).resolve().parents[2] if str(SERVER_ROOT) not in sys.path: sys.path.insert(0, str(SERVER_ROOT)) +SERVER_SRC = SERVER_ROOT / "src" +if str(SERVER_SRC) not in sys.path: + sys.path.insert(0, str(SERVER_SRC)) # Ensure telemetry is disabled during test collection and execution to avoid # any background network or thread startup that could slow or block pytest. @@ -86,3 +89,39 @@ class _DummyMiddlewareContext: fastmcp_server.middleware = fastmcp_server_middleware sys.modules.setdefault("fastmcp.server", fastmcp_server) sys.modules.setdefault("fastmcp.server.middleware", fastmcp_server_middleware) + +# Stub minimal starlette modules to avoid optional dependency imports. +starlette = types.ModuleType("starlette") +starlette_endpoints = types.ModuleType("starlette.endpoints") +starlette_websockets = types.ModuleType("starlette.websockets") +starlette_requests = types.ModuleType("starlette.requests") +starlette_responses = types.ModuleType("starlette.responses") + + +class _DummyWebSocketEndpoint: + pass + + +class _DummyWebSocket: + pass + + +class _DummyRequest: + pass + + +class _DummyJSONResponse: + def __init__(self, *args, **kwargs): + pass + + +starlette_endpoints.WebSocketEndpoint = _DummyWebSocketEndpoint +starlette_websockets.WebSocket = _DummyWebSocket +starlette_requests.Request = _DummyRequest +starlette_responses.JSONResponse = _DummyJSONResponse + +sys.modules.setdefault("starlette", starlette) +sys.modules.setdefault("starlette.endpoints", starlette_endpoints) +sys.modules.setdefault("starlette.websockets", starlette_websockets) +sys.modules.setdefault("starlette.requests", starlette_requests) +sys.modules.setdefault("starlette.responses", starlette_responses) diff --git a/Server/tests/integration/test_instance_autoselect.py b/Server/tests/integration/test_instance_autoselect.py new file mode 100644 index 000000000..0838ed207 --- /dev/null +++ b/Server/tests/integration/test_instance_autoselect.py @@ -0,0 +1,108 @@ +import asyncio +import sys +import types +from types import SimpleNamespace + +from .test_helpers import DummyContext + + +class DummyMiddlewareContext: + def __init__(self, ctx): + self.fastmcp_context = ctx + + +def test_auto_selects_single_instance_via_pluginhub(monkeypatch): + plugin_hub = types.ModuleType("transport.plugin_hub") + + class PluginHub: + @classmethod + def is_configured(cls) -> bool: + return True + + @classmethod + async def get_sessions(cls): + raise AssertionError("get_sessions should be stubbed in test") + + plugin_hub.PluginHub = PluginHub + monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub) + unity_transport = types.ModuleType("transport.unity_transport") + unity_transport._current_transport = lambda: "http" + monkeypatch.setitem(sys.modules, "transport.unity_transport", unity_transport) + monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False) + + from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub + assert ImportedPluginHub is plugin_hub.PluginHub + + monkeypatch.setenv("UNITY_MCP_TRANSPORT", "http") + + middleware = UnityInstanceMiddleware() + ctx = DummyContext() + ctx.client_id = "client-1" + middleware_context = DummyMiddlewareContext(ctx) + + call_count = {"sessions": 0} + + async def fake_get_sessions(): + call_count["sessions"] += 1 + return SimpleNamespace( + sessions={ + "session-1": SimpleNamespace(project="Ramble", hash="deadbeef"), + } + ) + + monkeypatch.setattr(plugin_hub.PluginHub, "get_sessions", fake_get_sessions) + + selected = asyncio.run(middleware._maybe_autoselect_instance(ctx)) + + assert selected == "Ramble@deadbeef" + assert middleware.get_active_instance(ctx) == "Ramble@deadbeef" + assert call_count["sessions"] == 1 + + asyncio.run(middleware._inject_unity_instance(middleware_context)) + + assert ctx.get_state("unity_instance") == "Ramble@deadbeef" + assert call_count["sessions"] == 1 + + +def test_auto_selects_single_instance_via_stdio(monkeypatch): + plugin_hub = types.ModuleType("transport.plugin_hub") + + class PluginHub: + @classmethod + def is_configured(cls) -> bool: + return False + + plugin_hub.PluginHub = PluginHub + monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub) + unity_transport = types.ModuleType("transport.unity_transport") + unity_transport._current_transport = lambda: "stdio" + monkeypatch.setitem(sys.modules, "transport.unity_transport", unity_transport) + monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False) + + from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub + assert ImportedPluginHub is plugin_hub.PluginHub + + monkeypatch.setenv("UNITY_MCP_TRANSPORT", "stdio") + + middleware = UnityInstanceMiddleware() + ctx = DummyContext() + ctx.client_id = "client-1" + middleware_context = DummyMiddlewareContext(ctx) + + class PoolStub: + def discover_all_instances(self, force_refresh=False): + assert force_refresh is True + return [SimpleNamespace(id="UnityMCPTests@cc8756d4")] + + unity_connection = types.ModuleType("transport.legacy.unity_connection") + unity_connection.get_unity_connection_pool = lambda: PoolStub() + monkeypatch.setitem(sys.modules, "transport.legacy.unity_connection", unity_connection) + + selected = asyncio.run(middleware._maybe_autoselect_instance(ctx)) + + assert selected == "UnityMCPTests@cc8756d4" + assert middleware.get_active_instance(ctx) == "UnityMCPTests@cc8756d4" + + asyncio.run(middleware._inject_unity_instance(middleware_context)) + + assert ctx.get_state("unity_instance") == "UnityMCPTests@cc8756d4"