src/keep_gpu/mcp/server.py (95 additions, 1 deletion)
@@ -52,6 +52,22 @@ def start_keep(
busy_threshold: int = -1,
job_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Start a KeepGPU session that periodically reserves the specified amount of VRAM on one or more GPUs.

Parameters:
gpu_ids (Optional[List[int]]): List of GPU indices to target; if `None`, all available GPUs may be considered.
vram (str): Amount of VRAM to reserve (human-readable, e.g. "1GiB").
interval (int): Time in seconds between controller checks/actions.
busy_threshold (int): Numeric threshold that determines what the underlying controller treats as "busy" (exact semantics are defined by the controller).
job_id (Optional[str]): Identifier for the session; when `None`, a new UUID is generated.

Returns:
dict: A dictionary with the started session's `job_id`, e.g. `{"job_id": "<id>"}`.

Raises:
ValueError: If `job_id` is provided and already exists.
"""
job_id = job_id or str(uuid.uuid4())
if job_id in self._sessions:
raise ValueError(f"job_id {job_id} already exists")
@@ -78,6 +94,19 @@ def start_keep(
def stop_keep(
self, job_id: Optional[str] = None, quiet: bool = False
) -> Dict[str, Any]:
"""
Stop one or all active keep sessions.

If `job_id` is provided, stops and removes that session if it exists; otherwise stops and removes all sessions. When `quiet` is True, informational logging about stopped sessions is suppressed.

Parameters:
job_id (Optional[str]): Identifier of the session to stop. If omitted, all sessions are stopped.
quiet (bool): If True, do not emit informational logs about stopped sessions.

Returns:
dict: A dictionary with a "stopped" key listing the stopped job IDs. If a specific
`job_id` was requested but not found, the dictionary also includes a "message" explaining that.
"""
if job_id:
session = self._sessions.pop(job_id, None)
if session:
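As a quick illustration of the two call shapes described above (payload and result values are illustrative):

```python
# Stop a single session by id; the result lists which jobs were actually stopped.
stop_one = {"id": 2, "method": "stop_keep", "params": {"job_id": "demo"}}
# Expected result shape: {"stopped": ["demo"]}, or a "message" key alongside
# "stopped" if "demo" is not a known session (per the docstring above).

# Omit job_id to stop every active session.
stop_all = {"id": 3, "method": "stop_keep", "params": {"quiet": True}}
# Expected result shape: {"stopped": ["<job-id-1>", "<job-id-2>", ...]}
```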
@@ -118,6 +147,11 @@ def list_gpus(self) -> Dict[str, Any]:
return {"gpus": infos}

def shutdown(self) -> None:
"""
Stop all active sessions and release resources.

This stops every session quietly and ignores exceptions, so that no noisy errors surface while the interpreter is shutting down.
"""
try:
self.stop_keep(None, quiet=True)
except Exception: # pragma: no cover - defensive
@@ -126,6 +160,22 @@ def shutdown(self) -> None:


def _handle_request(server: KeepGPUServer, payload: Dict[str, Any]) -> Dict[str, Any]:
"""
Dispatch a JSON-RPC-like request payload to the corresponding KeepGPUServer method and return a JSON-RPC-style response object.

Parameters:
server (KeepGPUServer): The server instance whose methods will be invoked.
payload (dict): The incoming request object; expected keys:
- "method" (str): RPC method name ("start_keep", "stop_keep", "status", "list_gpus").
- "params" (dict, optional): Keyword arguments for the method.
- "id" (any, optional): Caller-provided request identifier preserved in the response.

Returns:
dict: A JSON-RPC-style response containing:
- "id": the original request id (or None if not provided).
- "result": the method's return value on success.
- OR "error": an object with a "message" string describing the failure.
"""
method = payload.get("method")
params = payload.get("params", {}) or {}
req_id = payload.get("id")
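To make the response envelope concrete, a success and an error reply built by this dispatcher would look roughly as follows (the job_id and message text are illustrative):

```python
# Success: "id" echoes the request id, "result" carries the method's return value.
success_response = {"id": 1, "result": {"job_id": "<generated-uuid>"}}

# Failure: "error" carries an object with a human-readable message.
error_response = {"id": 2, "error": {"message": "job_id demo already exists"}}
```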
@@ -150,6 +200,11 @@ class _JSONRPCHandler(BaseHTTPRequestHandler):
server_version = "KeepGPU-MCP/0.1"

def do_POST(self): # noqa: N802
"""
Handle HTTP POST requests containing a JSON-RPC payload and send a JSON response.

Reads the request body using the Content-Length header, parses it as JSON, dispatches the payload to the internal JSON-RPC dispatcher, and writes the dispatcher result as an application/json response. If the request body cannot be decoded or parsed, responds with HTTP 400 and a JSON error object describing the parsing error.
"""
try:
length = int(self.headers.get("content-length", "0"))
body = self.rfile.read(length).decode("utf-8")
@@ -167,10 +222,28 @@ def do_POST(self): # noqa: N802
self.wfile.write(data)

def log_message(self, format, *args): # noqa: A003
"""
Suppress BaseHTTPRequestHandler's default per-request logging; this override intentionally does nothing.

Parameters:
format (str): The format string provided by BaseHTTPRequestHandler.
*args: Values to interpolate into `format`.
"""
return


def run_stdio(server: KeepGPUServer) -> None:
"""
Read line-delimited JSON-RPC requests from stdin, dispatch each request to the server, and write the JSON response to stdout.

Parameters:
server (KeepGPUServer): Server instance used to handle JSON-RPC requests.

Description:
- Processes each non-empty line from stdin as a JSON payload.
- On successful handling, writes the JSON-RPC response followed by a newline to stdout and flushes.
- If parsing or handling raises an exception, writes an error object with the exception message as the response.
"""
for line in sys.stdin:
line = line.strip()
if not line:
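As a usage sketch of the stdio transport: a client spawns the server and exchanges one JSON object per line. This assumes the module is importable as `keep_gpu.mcp.server`, which is inferred from this file's path and may differ depending on how the package is installed:

```python
import json
import subprocess
import sys

# Launch the server in stdio mode (module path assumed, see note above).
proc = subprocess.Popen(
    [sys.executable, "-m", "keep_gpu.mcp.server", "--mode", "stdio"],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
)

# One request per line in, one response per line out.
proc.stdin.write(json.dumps({"id": 1, "method": "list_gpus", "params": {}}) + "\n")
proc.stdin.flush()
print(proc.stdout.readline().strip())  # e.g. {"id": 1, "result": {"gpus": [...]}}

proc.stdin.close()
proc.wait()
```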
@@ -185,13 +258,29 @@ def run_stdio(server: KeepGPUServer) -> None:


def run_http(server: KeepGPUServer, host: str = "127.0.0.1", port: int = 8765) -> None:
"""
Start a lightweight HTTP JSON-RPC server that exposes the given KeepGPUServer on the specified host and port.

Runs a TCP HTTP server serving _JSONRPCHandler in a background thread, logs the listening address, and waits for the thread to finish. On interruption or shutdown, it cleanly shuts down the HTTP server and calls server.shutdown() to release resources.

Parameters:
server (KeepGPUServer): The KeepGPUServer instance whose RPC methods will be exposed over HTTP.
host (str): Host address to bind the HTTP server to.
port (int): TCP port to bind the HTTP server to.
"""
class _Server(TCPServer):
allow_reuse_address = True

httpd = _Server((host, port), _JSONRPCHandler)
httpd.keepgpu_server = server # type: ignore[attr-defined]

def _serve():
"""
Run the HTTP server's request loop until the server is shut down.

Blocks the current thread and processes incoming HTTP requests for the
server instance until the server is stopped.
"""
httpd.serve_forever()

thread = threading.Thread(target=_serve)
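For reference, a minimal standard-library client sketch for the HTTP transport, assuming the server is running with `run_http`'s default host and port; the request path "/" is an arbitrary choice for this sketch:

```python
import json
import urllib.request

# Illustrative JSON-RPC call over HTTP; host and port match run_http's defaults.
payload = {"id": 1, "method": "status", "params": {}}
req = urllib.request.Request(
    "http://127.0.0.1:8765/",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read().decode("utf-8")))
```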
@@ -210,6 +299,11 @@ def _serve():


def main() -> None:
"""
Entry point for the KeepGPU MCP server that parses command-line arguments and starts the chosen transport.

Parses --mode (stdio or http), --host and --port (for http mode), instantiates a KeepGPUServer, and runs either the stdio loop or the HTTP server based on the selected mode.
"""
parser = argparse.ArgumentParser(description="KeepGPU MCP server")
parser.add_argument(
"--mode",
@@ -229,4 +323,4 @@ def main() -> None:


if __name__ == "__main__":
main()