26 changes: 25 additions & 1 deletion README.md
@@ -92,10 +92,14 @@ with GlobalGPUController(gpu_ids=[0, 1], vram_to_keep="750MB", interval=90, busy

### MCP endpoint (experimental)

- Start a simple JSON-RPC server on stdin/stdout:
- Start a simple JSON-RPC server on stdin/stdout (default):
```bash
keep-gpu-mcp-server
```
- Or expose it over HTTP (JSON-RPC 2.0 over POST):
```bash
keep-gpu-mcp-server --mode http --host 0.0.0.0 --port 8765
```
- Example request (one per line):
```json
{"id": 1, "method": "start_keep", "params": {"gpu_ids": [0], "vram": "512MB", "interval": 60, "busy_threshold": 20}}
@@ -108,6 +112,26 @@ with GlobalGPUController(gpu_ids=[0, 1], vram_to_keep="750MB", interval=90, busy
command: ["keep-gpu-mcp-server"]
adapter: stdio
```
- Minimal client config (HTTP MCP):
```yaml
servers:
  keepgpu:
    url: http://127.0.0.1:8765/
    adapter: http
```
- Remote/SSH tunnel example (HTTP):
```bash
keep-gpu-mcp-server --mode http --host 0.0.0.0 --port 8765
```
Client config (replace hostname/tunnel as needed):
```yaml
servers:
  keepgpu:
    url: http://gpu-box.example.com:8765/
    adapter: http
```
For untrusted networks, put the server behind your own auth/reverse-proxy or
tunnel over SSH (for example, `ssh -L 8765:localhost:8765 gpu-box`). A minimal
Python client sketch follows below.
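
As a rough sketch (not part of KeepGPU's API; it assumes the HTTP server above
is running on `127.0.0.1:8765` and uses only the Python standard library), a
client could drive the endpoint like this:

```python
import json
import urllib.request

def rpc(method, params=None):
    """POST one JSON-RPC request to the KeepGPU MCP server and decode the reply."""
    payload = json.dumps({"id": 1, "method": method, "params": params or {}}).encode("utf-8")
    req = urllib.request.Request(
        "http://127.0.0.1:8765/",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))

print(rpc("list_gpus"))
print(rpc("start_keep", {"gpu_ids": [0], "vram": "512MB",
                         "interval": 60, "busy_threshold": 20}))
```

Over stdio the same payloads apply, written one JSON object per line.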

## Contributing

4 changes: 4 additions & 0 deletions docs/contributing.md
@@ -42,11 +42,15 @@ expectations so you can get productive quickly and avoid surprises in CI.
## MCP server (experimental)

- Start: `keep-gpu-mcp-server` (stdin/stdout JSON-RPC)
- HTTP option: `keep-gpu-mcp-server --mode http --host 0.0.0.0 --port 8765`
- Methods: `start_keep`, `stop_keep`, `status`, `list_gpus`
- Example request (a scripted smoke test follows this list):
```json
{"id":1,"method":"start_keep","params":{"gpu_ids":[0],"vram":"512MB","interval":60,"busy_threshold":20}}
```
- Remote tip: for shared clusters, prefer HTTP behind your own auth/reverse-proxy
or tunnel with SSH (`ssh -L 8765:localhost:8765 gpu-box`), then point your MCP
client at `http://127.0.0.1:8765/`.
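
For a quick scripted smoke test of the stdio transport, a sketch along these
lines works (it assumes `keep-gpu-mcp-server` is on `PATH`):

```python
import json
import subprocess

# Launch the stdio server; it reads one JSON request per line from stdin
# and writes one JSON response per line to stdout.
proc = subprocess.Popen(
    ["keep-gpu-mcp-server"],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
)
proc.stdin.write(json.dumps({"id": 1, "method": "list_gpus", "params": {}}) + "\n")
proc.stdin.flush()
print(json.loads(proc.stdout.readline()))
proc.terminate()
```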

## Pull requests

30 changes: 30 additions & 0 deletions docs/getting-started.md
@@ -77,6 +77,36 @@ servers:
Tools exposed: `start_keep`, `stop_keep`, `status`, `list_gpus`. Each request is
a single JSON line; see above for an example payload.

### HTTP transport

Prefer TCP instead of stdio? Run:

```bash
keep-gpu-mcp-server --mode http --host 0.0.0.0 --port 8765
```

Then point your MCP client at `http://127.0.0.1:8765/` (JSON-RPC 2.0 over POST).

### Remote/cluster usage

- Start on the GPU host:
```bash
keep-gpu-mcp-server --mode http --host 0.0.0.0 --port 8765
```
- Point your agent at the host:
```yaml
servers:
  keepgpu:
    url: http://gpu-box.example.com:8765/
    adapter: http
```
- If the host is not on a trusted network, tunnel instead of exposing the port:
```bash
ssh -L 8765:localhost:8765 gpu-box.example.com
```
Then use `http://127.0.0.1:8765/` in your MCP config. For multi-user clusters,
consider fronting the service with your own auth/reverse-proxy. A small Python
lifecycle sketch follows.
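
If you want a job's lifetime tied to a block of Python, a context-manager
wrapper over the HTTP transport is one option. This is an illustrative sketch,
not part of KeepGPU: the helper names are made up, and it assumes responses
follow the JSON-RPC shape `{"id": ..., "result": ...}` with `start_keep`
returning a `job_id`:

```python
import json
import urllib.request
from contextlib import contextmanager

def _rpc(url, method, params):
    # Hypothetical helper: POST one JSON-RPC request and decode the reply.
    body = json.dumps({"id": 1, "method": method, "params": params}).encode("utf-8")
    req = urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))

@contextmanager
def keep_gpus(gpu_ids, vram="512MB", url="http://127.0.0.1:8765/"):
    # Assumed response shape: {"id": 1, "result": {"job_id": "..."}}.
    job_id = _rpc(url, "start_keep", {"gpu_ids": gpu_ids, "vram": vram})["result"]["job_id"]
    try:
        yield job_id
    finally:
        _rpc(url, "stop_keep", {"job_id": job_id})

# Usage: hold GPU 0 while the block runs, release it afterwards.
with keep_gpus([0]) as job_id:
    print("keeping under job", job_id)
```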

=== "Editable dev install"
```bash
git clone https://github.com/Wangmerlyn/KeepGPU.git
89 changes: 82 additions & 7 deletions src/keep_gpu/mcp/server.py
@@ -1,11 +1,12 @@
"""
Minimal MCP-style JSON-RPC server for KeepGPU.

The server reads JSON lines from stdin and writes JSON responses to stdout.
Run over stdin/stdout (default) or a lightweight HTTP server.
Supported methods:
- start_keep(gpu_ids, vram, interval, busy_threshold, job_id)
- stop_keep(job_id=None) # None stops all
- status(job_id=None) # None lists all
- list_gpus()
"""

from __future__ import annotations
Expand All @@ -14,6 +15,10 @@
import json
import sys
import uuid
import argparse
import threading
from http.server import BaseHTTPRequestHandler
from socketserver import TCPServer
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

@@ -70,20 +75,23 @@ def start_keep(
logger.info("Started keep session %s on GPUs %s", job_id, gpu_ids)
return {"job_id": job_id}

def stop_keep(self, job_id: Optional[str] = None) -> Dict[str, Any]:
def stop_keep(
self, job_id: Optional[str] = None, quiet: bool = False
) -> Dict[str, Any]:
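        # quiet=True suppresses the info logs, e.g. when called from shutdown()
        # during interpreter teardown.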
        if job_id:
            session = self._sessions.pop(job_id, None)
            if session:
                session.controller.release()
                logger.info("Stopped keep session %s", job_id)
                if not quiet:
                    logger.info("Stopped keep session %s", job_id)
                return {"stopped": [job_id]}
            return {"stopped": [], "message": "job_id not found"}

        stopped_ids = list(self._sessions.keys())
        for job_id in stopped_ids:
            session = self._sessions.pop(job_id)
            session.controller.release()
        if stopped_ids:
        if stopped_ids and not quiet:
            logger.info("Stopped sessions: %s", stopped_ids)
        return {"stopped": stopped_ids}

@@ -111,7 +119,7 @@ def list_gpus(self) -> Dict[str, Any]:

    def shutdown(self) -> None:
        try:
            self.stop_keep(None)
            self.stop_keep(None, quiet=True)
        except Exception:  # pragma: no cover - defensive
            # Avoid noisy errors during interpreter teardown
            return
Expand All @@ -138,8 +146,31 @@ def _handle_request(server: KeepGPUServer, payload: Dict[str, Any]) -> Dict[str,
return {"id": req_id, "error": {"message": str(exc)}}


def main() -> None:
    server = KeepGPUServer()
class _JSONRPCHandler(BaseHTTPRequestHandler):
    server_version = "KeepGPU-MCP/0.1"

    def do_POST(self):  # noqa: N802
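        # Parse one JSON-RPC request from the POST body, dispatch it to the
        # shared KeepGPUServer, and write the JSON reply.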
        try:
            length = int(self.headers.get("content-length", "0"))
            body = self.rfile.read(length).decode("utf-8")
            payload = json.loads(body)
            response = _handle_request(self.server.keepgpu_server, payload)  # type: ignore[attr-defined]
            status = 200
        except (json.JSONDecodeError, ValueError, UnicodeDecodeError) as exc:
            response = {"error": {"message": f"Bad request: {exc}"}}
            status = 400
        data = json.dumps(response).encode()
        self.send_response(status)
        self.send_header("content-type", "application/json")
        self.send_header("content-length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def log_message(self, format, *args):  # noqa: A002
        return


def run_stdio(server: KeepGPUServer) -> None:
    for line in sys.stdin:
        line = line.strip()
        if not line:
@@ -153,5 +184,49 @@ def main() -> None:
        sys.stdout.flush()


def run_http(server: KeepGPUServer, host: str = "127.0.0.1", port: int = 8765) -> None:
    class _Server(TCPServer):
        allow_reuse_address = True

    httpd = _Server((host, port), _JSONRPCHandler)
    httpd.keepgpu_server = server  # type: ignore[attr-defined]

    def _serve():
        httpd.serve_forever()

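    # Serve in a background thread so the main thread can turn a
    # KeyboardInterrupt into an orderly shutdown of the server and sessions.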
    thread = threading.Thread(target=_serve)
    thread.start()
    logger.info(
        "MCP HTTP server listening on http://%s:%s", host, httpd.server_address[1]
    )
    try:
        thread.join()
    except KeyboardInterrupt:
        pass
    finally:
        httpd.shutdown()
        httpd.server_close()
        server.shutdown()


def main() -> None:
    parser = argparse.ArgumentParser(description="KeepGPU MCP server")
    parser.add_argument(
        "--mode",
        choices=["stdio", "http"],
        default="stdio",
        help="Transport mode (default: stdio)",
    )
    parser.add_argument("--host", default="127.0.0.1", help="HTTP host (http mode)")
    parser.add_argument("--port", type=int, default=8765, help="HTTP port (http mode)")
    args = parser.parse_args()

    server = KeepGPUServer()
    if args.mode == "stdio":
        run_stdio(server)
    else:
        run_http(server, host=args.host, port=args.port)


if __name__ == "__main__":
    main()