Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions Server/src/transport/unity_instance_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,93 @@ def clear_active_instance(self, ctx) -> None:
with self._lock:
self._active_by_key.pop(key, None)

async def _maybe_autoselect_instance(self, ctx) -> str | None:
"""Auto-select sole Unity instance when no active instance is set."""
try:
from transport.unity_transport import _current_transport

transport = _current_transport()
if PluginHub.is_configured():
try:
sessions_data = await PluginHub.get_sessions()
sessions = sessions_data.sessions or {}
ids: list[str] = []
for session_info in sessions.values():
project = getattr(session_info, "project", None) or "Unknown"
hash_value = getattr(session_info, "hash", None)
if hash_value:
ids.append(f"{project}@{hash_value}")
if len(ids) == 1:
chosen = ids[0]
self.set_active_instance(ctx, chosen)
logger.info(
"Auto-selected sole Unity instance via PluginHub: %s",
chosen,
)
return chosen
except (ConnectionError, ValueError, KeyError, TimeoutError, AttributeError) as exc:
logger.debug(
"PluginHub auto-select probe failed (%s); falling back to stdio",
type(exc).__name__,
exc_info=True,
)
except Exception as exc:
if isinstance(exc, (SystemExit, KeyboardInterrupt)):
raise
logger.debug(
"PluginHub auto-select probe failed with unexpected error (%s); falling back to stdio",
type(exc).__name__,
exc_info=True,
)

if transport != "http":
try:
from transport.legacy.unity_connection import get_unity_connection_pool

pool = get_unity_connection_pool()
instances = pool.discover_all_instances(force_refresh=True)
ids = [getattr(inst, "id", None) for inst in instances]
ids = [inst_id for inst_id in ids if inst_id]
if len(ids) == 1:
chosen = ids[0]
self.set_active_instance(ctx, chosen)
logger.info(
"Auto-selected sole Unity instance via stdio discovery: %s",
chosen,
)
return chosen
except (ConnectionError, ValueError, KeyError, TimeoutError, AttributeError) as exc:
logger.debug(
"Stdio auto-select probe failed (%s)",
type(exc).__name__,
exc_info=True,
)
except Exception as exc:
if isinstance(exc, (SystemExit, KeyboardInterrupt)):
raise
logger.debug(
"Stdio auto-select probe failed with unexpected error (%s)",
type(exc).__name__,
exc_info=True,
)
except Exception as exc:
if isinstance(exc, (SystemExit, KeyboardInterrupt)):
raise
logger.debug(
"Auto-select path encountered an unexpected error (%s)",
type(exc).__name__,
exc_info=True,
)

return None

async def _inject_unity_instance(self, context: MiddlewareContext) -> None:
"""Inject active Unity instance into context if available."""
ctx = context.fastmcp_context

active_instance = self.get_active_instance(ctx)
if not active_instance:
active_instance = await self._maybe_autoselect_instance(ctx)
if active_instance:
# If using HTTP transport (PluginHub configured), validate session
# But for stdio transport (no PluginHub needed or maybe partially configured),
Expand Down
39 changes: 39 additions & 0 deletions Server/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
SERVER_ROOT = Path(__file__).resolve().parents[2]
if str(SERVER_ROOT) not in sys.path:
sys.path.insert(0, str(SERVER_ROOT))
SERVER_SRC = SERVER_ROOT / "src"
if str(SERVER_SRC) not in sys.path:
sys.path.insert(0, str(SERVER_SRC))

# Ensure telemetry is disabled during test collection and execution to avoid
# any background network or thread startup that could slow or block pytest.
Expand Down Expand Up @@ -86,3 +89,39 @@ class _DummyMiddlewareContext:
fastmcp_server.middleware = fastmcp_server_middleware
sys.modules.setdefault("fastmcp.server", fastmcp_server)
sys.modules.setdefault("fastmcp.server.middleware", fastmcp_server_middleware)

# Stub minimal starlette modules to avoid optional dependency imports.
starlette = types.ModuleType("starlette")
starlette_endpoints = types.ModuleType("starlette.endpoints")
starlette_websockets = types.ModuleType("starlette.websockets")
starlette_requests = types.ModuleType("starlette.requests")
starlette_responses = types.ModuleType("starlette.responses")


class _DummyWebSocketEndpoint:
pass


class _DummyWebSocket:
pass


class _DummyRequest:
pass


class _DummyJSONResponse:
def __init__(self, *args, **kwargs):
pass


starlette_endpoints.WebSocketEndpoint = _DummyWebSocketEndpoint
starlette_websockets.WebSocket = _DummyWebSocket
starlette_requests.Request = _DummyRequest
starlette_responses.JSONResponse = _DummyJSONResponse

sys.modules.setdefault("starlette", starlette)
sys.modules.setdefault("starlette.endpoints", starlette_endpoints)
sys.modules.setdefault("starlette.websockets", starlette_websockets)
sys.modules.setdefault("starlette.requests", starlette_requests)
sys.modules.setdefault("starlette.responses", starlette_responses)
108 changes: 108 additions & 0 deletions Server/tests/integration/test_instance_autoselect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import asyncio
import sys
import types
from types import SimpleNamespace

from .test_helpers import DummyContext


class DummyMiddlewareContext:
def __init__(self, ctx):
self.fastmcp_context = ctx


def test_auto_selects_single_instance_via_pluginhub(monkeypatch):
plugin_hub = types.ModuleType("transport.plugin_hub")

class PluginHub:
@classmethod
def is_configured(cls) -> bool:
return True

@classmethod
async def get_sessions(cls):
raise AssertionError("get_sessions should be stubbed in test")

plugin_hub.PluginHub = PluginHub
monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub)
unity_transport = types.ModuleType("transport.unity_transport")
unity_transport._current_transport = lambda: "http"
monkeypatch.setitem(sys.modules, "transport.unity_transport", unity_transport)
monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False)

from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub
assert ImportedPluginHub is plugin_hub.PluginHub

monkeypatch.setenv("UNITY_MCP_TRANSPORT", "http")

middleware = UnityInstanceMiddleware()
ctx = DummyContext()
ctx.client_id = "client-1"
middleware_context = DummyMiddlewareContext(ctx)

call_count = {"sessions": 0}

async def fake_get_sessions():
call_count["sessions"] += 1
return SimpleNamespace(
sessions={
"session-1": SimpleNamespace(project="Ramble", hash="deadbeef"),
}
)

monkeypatch.setattr(plugin_hub.PluginHub, "get_sessions", fake_get_sessions)

selected = asyncio.run(middleware._maybe_autoselect_instance(ctx))

assert selected == "Ramble@deadbeef"
assert middleware.get_active_instance(ctx) == "Ramble@deadbeef"
assert call_count["sessions"] == 1

asyncio.run(middleware._inject_unity_instance(middleware_context))

assert ctx.get_state("unity_instance") == "Ramble@deadbeef"
assert call_count["sessions"] == 1


def test_auto_selects_single_instance_via_stdio(monkeypatch):
plugin_hub = types.ModuleType("transport.plugin_hub")

class PluginHub:
@classmethod
def is_configured(cls) -> bool:
return False

plugin_hub.PluginHub = PluginHub
monkeypatch.setitem(sys.modules, "transport.plugin_hub", plugin_hub)
unity_transport = types.ModuleType("transport.unity_transport")
unity_transport._current_transport = lambda: "stdio"
monkeypatch.setitem(sys.modules, "transport.unity_transport", unity_transport)
monkeypatch.delitem(sys.modules, "transport.unity_instance_middleware", raising=False)

from transport.unity_instance_middleware import UnityInstanceMiddleware, PluginHub as ImportedPluginHub
assert ImportedPluginHub is plugin_hub.PluginHub

monkeypatch.setenv("UNITY_MCP_TRANSPORT", "stdio")

middleware = UnityInstanceMiddleware()
ctx = DummyContext()
ctx.client_id = "client-1"
middleware_context = DummyMiddlewareContext(ctx)

class PoolStub:
def discover_all_instances(self, force_refresh=False):
assert force_refresh is True
return [SimpleNamespace(id="UnityMCPTests@cc8756d4")]

unity_connection = types.ModuleType("transport.legacy.unity_connection")
unity_connection.get_unity_connection_pool = lambda: PoolStub()
monkeypatch.setitem(sys.modules, "transport.legacy.unity_connection", unity_connection)

selected = asyncio.run(middleware._maybe_autoselect_instance(ctx))

assert selected == "UnityMCPTests@cc8756d4"
assert middleware.get_active_instance(ctx) == "UnityMCPTests@cc8756d4"

asyncio.run(middleware._inject_unity_instance(middleware_context))

assert ctx.get_state("unity_instance") == "UnityMCPTests@cc8756d4"