diff --git a/docs/guides/mcp.md b/docs/guides/mcp.md
index ba05b68..f52d047 100644
--- a/docs/guides/mcp.md
+++ b/docs/guides/mcp.md
@@ -7,7 +7,7 @@ orchestrators can start/stop keep-alive jobs and inspect GPU state.
 
 - You run KeepGPU from an agent (LangChain, custom orchestrator, etc.) instead of a shell.
 - You want to keep GPUs alive on a remote box over TCP rather than stdio.
-- You need a quick way to list GPU utilization/memory via the same interface.
+- You need a quick way to list GPU utilization/memory by way of the same interface.
 
 ## Quick start
 
@@ -33,7 +33,7 @@ Supported methods:
 
 - `start_keep(gpu_ids?, vram?, interval?, busy_threshold?, job_id?)`
 - `stop_keep(job_id?)` (omit `job_id` to stop all)
 - `status(job_id?)` (omit `job_id` to list active jobs)
-- `list_gpus()` (detailed info via NVML/ROCm SMI/torch)
+- `list_gpus()` (detailed info by way of NVML/ROCm SMI/torch)
 
 ## Client configs (MCP-style)
 
diff --git a/src/keep_gpu/mcp/server.py b/src/keep_gpu/mcp/server.py
index 81ed71c..e4d255a 100644
--- a/src/keep_gpu/mcp/server.py
+++ b/src/keep_gpu/mcp/server.py
@@ -52,6 +52,22 @@ def start_keep(
         busy_threshold: int = -1,
         job_id: Optional[str] = None,
     ) -> Dict[str, Any]:
+        """
+        Start a KeepGPU session that reserves VRAM on one or more GPUs.
+
+        Args:
+            gpu_ids: GPU indices to target; None uses all available GPUs.
+            vram: Human-readable VRAM size to keep (for example, "1GiB").
+            interval: Seconds between controller checks/actions.
+            busy_threshold: Utilization above which the controller backs off.
+            job_id: Optional session identifier; a UUID is generated if omitted.
+
+        Returns:
+            Dict with the started session's job_id, e.g. ``{"job_id": "<uuid>"}``.
+
+        Raises:
+            ValueError: If the provided job_id already exists.
+        """
         job_id = job_id or str(uuid.uuid4())
         if job_id in self._sessions:
             raise ValueError(f"job_id {job_id} already exists")
@@ -78,6 +94,20 @@ def start_keep(
     def stop_keep(
         self, job_id: Optional[str] = None, quiet: bool = False
     ) -> Dict[str, Any]:
+        """
+        Stop one or all active keep sessions.
+
+        If job_id is supplied, only that session is stopped; otherwise all active
+        sessions are released. When quiet=True, informational logging is skipped.
+
+        Args:
+            job_id: Session identifier to stop; None stops every session.
+            quiet: Suppress informational logs about stopped sessions.
+
+        Returns:
+            Dict with a "stopped" list of job ids. If a specific job_id was not
+            found, a "message" field explains the miss.
+        """
         if job_id:
             session = self._sessions.pop(job_id, None)
             if session:
@@ -118,6 +148,7 @@ def list_gpus(self) -> Dict[str, Any]:
         return {"gpus": infos}
 
     def shutdown(self) -> None:
+        """Stop all sessions quietly; ignore errors during interpreter teardown."""
         try:
             self.stop_keep(None, quiet=True)
         except Exception:  # pragma: no cover - defensive
@@ -126,6 +157,16 @@
 
 
 def _handle_request(server: KeepGPUServer, payload: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Dispatch a JSON-RPC payload to the server and return a response dict.
+
+    Args:
+        server: Target KeepGPUServer.
+        payload: Dict with "method", optional "params", and optional "id".
+
+    Returns:
+        JSON-RPC-style dict containing either "result" or "error" plus "id".
+    """
     method = payload.get("method")
     params = payload.get("params", {}) or {}
     req_id = payload.get("id")
@@ -150,11 +191,18 @@ class _JSONRPCHandler(BaseHTTPRequestHandler):
     server_version = "KeepGPU-MCP/0.1"
 
     def do_POST(self):  # noqa: N802
+        """
+        Handle an HTTP JSON-RPC request and write a JSON response.
+
+        Expects application/json bodies containing {"method", "params", "id"}.
+        Returns 400 with an error object if parsing fails.
+        """
         try:
             length = int(self.headers.get("content-length", "0"))
             body = self.rfile.read(length).decode("utf-8")
             payload = json.loads(body)
-            response = _handle_request(self.server.keepgpu_server, payload)  # type: ignore[attr-defined]
+            server_ref = self.server.keepgpu_server  # type: ignore[attr-defined]
+            response = _handle_request(server_ref, payload)
             status = 200
         except (json.JSONDecodeError, ValueError, UnicodeDecodeError) as exc:
             response = {"error": {"message": f"Bad request: {exc}"}}
@@ -167,10 +215,12 @@ def do_POST(self):  # noqa: N802
         self.wfile.write(data)
 
     def log_message(self, format, *args):  # noqa: A003
+        """Suppress default request logging."""
         return
 
 
 def run_stdio(server: KeepGPUServer) -> None:
+    """Serve JSON-RPC requests over stdin/stdout (one JSON object per line)."""
     for line in sys.stdin:
         line = line.strip()
         if not line:
@@ -185,6 +235,8 @@ def run_stdio(server: KeepGPUServer) -> None:
 
 
 def run_http(server: KeepGPUServer, host: str = "127.0.0.1", port: int = 8765) -> None:
+    """Run a lightweight HTTP JSON-RPC server on the given host/port."""
+
     class _Server(TCPServer):
         allow_reuse_address = True
 
@@ -192,6 +244,7 @@ class _Server(TCPServer):
     httpd.keepgpu_server = server  # type: ignore[attr-defined]
 
     def _serve():
+        """Run the HTTP server loop until shutdown."""
         httpd.serve_forever()
 
     thread = threading.Thread(target=_serve)
@@ -210,6 +263,7 @@ def _serve():
 
 
 def main() -> None:
+    """CLI entry point for the KeepGPU MCP server."""
    parser = argparse.ArgumentParser(description="KeepGPU MCP server")
     parser.add_argument(
         "--mode",