Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/guides/mcp.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ orchestrators can start/stop keep-alive jobs and inspect GPU state.

- You run KeepGPU from an agent (LangChain, custom orchestrator, etc.) instead of a shell.
- You want to keep GPUs alive on a remote box over TCP rather than stdio.
- You need a quick way to list GPU utilization/memory via the same interface.
- You need a quick way to list GPU utilization/memory by way of the same interface.

## Quick start

Expand All @@ -33,7 +33,7 @@ Supported methods:
- `start_keep(gpu_ids?, vram?, interval?, busy_threshold?, job_id?)`
- `stop_keep(job_id?)` (omit `job_id` to stop all)
- `status(job_id?)` (omit `job_id` to list active jobs)
- `list_gpus()` (detailed info via NVML/ROCm SMI/torch)
- `list_gpus()` (detailed info by way of NVML/ROCm SMI/torch)

## Client configs (MCP-style)

Expand Down
56 changes: 55 additions & 1 deletion src/keep_gpu/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@ def start_keep(
busy_threshold: int = -1,
job_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Start a KeepGPU session that reserves VRAM on one or more GPUs.

Args:
gpu_ids: GPU indices to target; None uses all available GPUs.
vram: Human-readable VRAM size to keep (for example, "1GiB").
interval: Seconds between controller checks/actions.
busy_threshold: Utilization above which the controller backs off.
job_id: Optional session identifier; a UUID is generated if omitted.

Returns:
Dict with the started session's job_id, e.g. ``{"job_id": "<id>"}``.

Raises:
ValueError: If the provided job_id already exists.
"""
job_id = job_id or str(uuid.uuid4())
if job_id in self._sessions:
raise ValueError(f"job_id {job_id} already exists")
Expand All @@ -78,6 +94,20 @@ def start_keep(
def stop_keep(
self, job_id: Optional[str] = None, quiet: bool = False
) -> Dict[str, Any]:
"""
Stop one or all active keep sessions.

If job_id is supplied, only that session is stopped; otherwise all active
sessions are released. When quiet=True, informational logging is skipped.

Args:
job_id: Session identifier to stop; None stops every session.
quiet: Suppress informational logs about stopped sessions.

Returns:
Dict with a "stopped" list of job ids. If a specific job_id was not
found, a "message" field explains the miss.
"""
if job_id:
session = self._sessions.pop(job_id, None)
if session:
Expand Down Expand Up @@ -118,6 +148,7 @@ def list_gpus(self) -> Dict[str, Any]:
return {"gpus": infos}

def shutdown(self) -> None:
"""Stop all sessions quietly; ignore errors during interpreter teardown."""
try:
self.stop_keep(None, quiet=True)
except Exception: # pragma: no cover - defensive
Expand All @@ -126,6 +157,16 @@ def shutdown(self) -> None:


def _handle_request(server: KeepGPUServer, payload: Dict[str, Any]) -> Dict[str, Any]:
"""
Dispatch a JSON-RPC payload to the server and return a response dict.

Args:
server: Target KeepGPUServer.
payload: Dict with "method", optional "params", and optional "id".

Returns:
JSON-RPC-style dict containing either "result" or "error" plus "id".
"""
method = payload.get("method")
params = payload.get("params", {}) or {}
req_id = payload.get("id")
Expand All @@ -150,11 +191,18 @@ class _JSONRPCHandler(BaseHTTPRequestHandler):
server_version = "KeepGPU-MCP/0.1"

def do_POST(self): # noqa: N802
"""
Handle an HTTP JSON-RPC request and write a JSON response.

Expects application/json bodies containing {"method", "params", "id"}.
Returns 400 with an error object if parsing fails.
"""
try:
length = int(self.headers.get("content-length", "0"))
body = self.rfile.read(length).decode("utf-8")
payload = json.loads(body)
response = _handle_request(self.server.keepgpu_server, payload) # type: ignore[attr-defined]
server_ref = self.server.keepgpu_server # type: ignore[attr-defined]
response = _handle_request(server_ref, payload)
status = 200
except (json.JSONDecodeError, ValueError, UnicodeDecodeError) as exc:
response = {"error": {"message": f"Bad request: {exc}"}}
Expand All @@ -167,10 +215,12 @@ def do_POST(self): # noqa: N802
self.wfile.write(data)

def log_message(self, format, *args):  # noqa: A003
    """No-op override: silence BaseHTTPRequestHandler's per-request stderr logging."""
    pass


def run_stdio(server: KeepGPUServer) -> None:
"""Serve JSON-RPC requests over stdin/stdout (one JSON object per line)."""
for line in sys.stdin:
line = line.strip()
if not line:
Expand All @@ -185,13 +235,16 @@ def run_stdio(server: KeepGPUServer) -> None:


def run_http(server: KeepGPUServer, host: str = "127.0.0.1", port: int = 8765) -> None:
"""Run a lightweight HTTP JSON-RPC server on the given host/port."""

class _Server(TCPServer):
allow_reuse_address = True

httpd = _Server((host, port), _JSONRPCHandler)
httpd.keepgpu_server = server # type: ignore[attr-defined]

def _serve():
"""Run the HTTP server loop until shutdown."""
httpd.serve_forever()

thread = threading.Thread(target=_serve)
Expand All @@ -210,6 +263,7 @@ def _serve():


def main() -> None:
"""CLI entry point for the KeepGPU MCP server."""
parser = argparse.ArgumentParser(description="KeepGPU MCP server")
parser.add_argument(
"--mode",
Expand Down