Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add optional SSE middleware to obtain Grafana info from headers #21

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ mcp-grafana = "mcp_grafana.cli:app"

[dependency-groups]
dev = [
"httpx-sse>=0.4.0",
"pytest>=8.3.4",
"pytest-asyncio>=0.25.2",
"pytest-httpserver>=1.1.1",
]
lint = [
"ruff>=0.8.5",
Expand Down
12 changes: 11 additions & 1 deletion src/mcp_grafana/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import enum
from types import MethodType

import typer

Expand All @@ -13,5 +14,14 @@ class Transport(enum.StrEnum):


@app.command()
def run(transport: Transport = Transport.stdio):
def run(transport: Transport = Transport.stdio, header_auth: bool = False):
if transport == Transport.sse and header_auth:
from .middleware import run_sse_async_with_middleware

# Monkeypatch the run_sse_async method to inject a Grafana middleware.
# This is a bit of a hack, but fastmcp doesn't have a way of adding
# middleware. It's not unreasonable to do this really, since fastmcp
# is just a thin wrapper around the low level mcp server.
mcp.run_sse_async = MethodType(run_sse_async_with_middleware, mcp)

mcp.run(transport.value)
27 changes: 25 additions & 2 deletions src/mcp_grafana/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
We should separate HTTP types from tool types.
"""

import contextvars
import math
from datetime import datetime
from typing import Any

import httpx
from pydantic import UUID4

from .settings import grafana_settings
from .settings import grafana_settings, GrafanaSettings
from .grafana_types import (
AddActivityToIncidentArguments,
CreateIncidentArguments,
Expand Down Expand Up @@ -52,6 +53,27 @@ def __init__(self, url: str, api_key: str | None = None) -> None:
base_url=url, auth=auth, timeout=httpx.Timeout(timeout=30.0)
)

@classmethod
def from_settings(cls, settings: GrafanaSettings) -> "GrafanaClient":
"""
Create a Grafana client from the given settings.
"""
return cls(settings.url, settings.api_key)

@classmethod
def for_current_request(cls) -> "GrafanaClient":
"""
Create a Grafana client for the current request.

This will use the Grafana settings from the current contextvar.
If running with the stdio transport then these settings will be
the ones in the MCP server's settings. If running with the SSE
transport then the settings will be extracted from the request
headers if possible, falling back to the defaults in the MCP
server's settings.
"""
return cls.from_settings(grafana_settings.get())

async def get(self, path: str, params: dict[str, str] | None = None) -> bytes:
r = await self.c.get(path, params=params)
if not r.is_success:
Expand Down Expand Up @@ -217,4 +239,5 @@ async def list_prometheus_label_values(
)


grafana_client = GrafanaClient(grafana_settings.url, api_key=grafana_settings.api_key)
grafana_client = contextvars.ContextVar("grafana_client")
grafana_client.set(GrafanaClient.for_current_request())
2 changes: 1 addition & 1 deletion src/mcp_grafana/grafana_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class SearchDashboardsArguments(BaseModel):
)
starred: bool = Field(default=False, description="Only include starred dashboards")
limit: int = Field(
default=grafana_settings.tools.search.limit,
default=grafana_settings.get().tools.search.limit,
description="Limit the number of returned results",
)
page: int = Field(default=1)
Expand Down
106 changes: 106 additions & 0 deletions src/mcp_grafana/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from dataclasses import dataclass

from mcp.server import FastMCP
from starlette.datastructures import Headers

from .client import GrafanaClient, grafana_client
from .settings import GrafanaSettings, grafana_settings


@dataclass
class GrafanaInfo:
"""
Simple container for the Grafana URL and API key.
"""

api_key: str
url: str

@classmethod
def from_headers(cls, headers: Headers) -> "GrafanaInfo | None":
if (url := headers.get("X-Grafana-URL")) is not None and (
key := headers.get("X-Grafana-API-Key")
) is not None:
return cls(api_key=key, url=url)
return None


class GrafanaMiddleware:
"""
Middleware that sets up Grafana info for the current request.

Grafana info will be stored in the `grafana_info` contextvar, which can be
used by tools/resources etc to access the Grafana configuration for the
current request, if it was provided.

This should be used as a context manager before handling the /sse request.
"""

def __init__(self, request):
self.request = request
self.settings_token = None
self.client_token = None

async def __aenter__(self):
if (info := GrafanaInfo.from_headers(self.request.headers)) is not None:
current_settings = grafana_settings.get()
new_settings = GrafanaSettings(
url=info.url,
api_key=info.api_key,
tools=current_settings.tools,
)
self.settings_token = grafana_settings.set(new_settings)
self.client_token = grafana_client.set(
GrafanaClient.from_settings(new_settings)
)

async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.settings_token is not None:
grafana_settings.reset(self.settings_token)
if self.client_token is not None:
grafana_client.reset(self.client_token)


async def run_sse_async_with_middleware(self: FastMCP) -> None:
"""
Run the server using SSE transport, with a middleware that extracts
Grafana authentication information from the request headers.

The vast majority of this code is the same as the original run_sse_async
method (see https://github.com/modelcontextprotocol/python-sdk/blob/44c0004e6c69e336811bb6793b7176e1eda50015/src/mcp/server/fastmcp/server.py#L436-L468).
"""

from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.routing import Mount, Route
import uvicorn

sse = SseServerTransport("/messages/")

async def handle_sse(request):
async with GrafanaMiddleware(request):
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await self._mcp_server.run(
streams[0],
streams[1],
self._mcp_server.create_initialization_options(),
)

starlette_app = Starlette(
debug=self.settings.debug,
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
)

config = uvicorn.Config(
starlette_app,
host=self.settings.host,
port=self.settings.port,
log_level=self.settings.log_level.lower(),
)
server = uvicorn.Server(config)
await server.serve()
11 changes: 10 additions & 1 deletion src/mcp_grafana/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import contextvars

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

Expand Down Expand Up @@ -69,4 +71,11 @@ class GrafanaSettings(BaseSettings):
tools: ToolSettings = Field(default_factory=ToolSettings)


grafana_settings = GrafanaSettings()
# This contextvar can be updated by middleware to reflect the Grafana settings
# for the current request.

# If the middleware is not used, the default settings will be used.
grafana_settings: contextvars.ContextVar[GrafanaSettings] = contextvars.ContextVar(
"grafana_settings"
)
grafana_settings.set(GrafanaSettings())
9 changes: 5 additions & 4 deletions src/mcp_grafana/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ def add_tools(mcp: FastMCP):
"""
Add all enabled tools to the MCP server.
"""
if grafana_settings.tools.search.enabled:
settings = grafana_settings.get()
if settings.tools.search.enabled:
search.add_tools(mcp)
if grafana_settings.tools.datasources.enabled:
if settings.tools.datasources.enabled:
datasources.add_tools(mcp)
if grafana_settings.tools.incident.enabled:
if settings.tools.incident.enabled:
incident.add_tools(mcp)
if grafana_settings.tools.prometheus.enabled:
if settings.tools.prometheus.enabled:
prometheus.add_tools(mcp)
6 changes: 3 additions & 3 deletions src/mcp_grafana/tools/datasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@ async def list_datasources() -> bytes:
"""
List datasources in the Grafana instance.
"""
return await grafana_client.list_datasources()
return await grafana_client.get().list_datasources()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would love to do away with all of these annoying .get() calls, tools definitely shouldn't be able to modify this contextvar...



async def get_datasource_by_uid(uid: str) -> bytes:
"""
Get a datasource by uid.
"""
return await grafana_client.get_datasource(uid=uid)
return await grafana_client.get().get_datasource(uid=uid)


async def get_datasource_by_name(name: str) -> bytes:
"""
Get a datasource by name.
"""
return await grafana_client.get_datasource(name=name)
return await grafana_client.get().get_datasource(name=name)


def add_tools(mcp: FastMCP):
Expand Down
10 changes: 5 additions & 5 deletions src/mcp_grafana/tools/incident.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from pydantic import BaseModel, Field

from ..client import (
grafana_client,
AddActivityToIncidentArguments,
CreateIncidentArguments,
grafana_client,
)
from ..grafana_types import QueryIncidentPreviewsRequest, IncidentPreviewsQuery

Expand Down Expand Up @@ -49,14 +49,14 @@ async def list_incidents(arguments: ListIncidentsArguments) -> bytes:
),
includeCustomFieldValues=True,
)
return await grafana_client.list_incidents(body)
return await grafana_client.get().list_incidents(body)


async def create_incident(arguments: CreateIncidentArguments) -> bytes:
"""
Create an incident in the Grafana Incident incident management tool.
"""
return await grafana_client.create_incident(arguments)
return await grafana_client.get().create_incident(arguments)


async def add_activity_to_incident(
Expand All @@ -70,7 +70,7 @@ async def add_activity_to_incident(
:param event_time: The time that the activity occurred. If not provided, the current time will be used.
If provided, it must be in RFC3339 format.
"""
return await grafana_client.add_activity_to_incident(
return await grafana_client.get().add_activity_to_incident(
AddActivityToIncidentArguments(
incidentId=incident_id,
body=body,
Expand All @@ -89,7 +89,7 @@ async def resolve_incident(incident_id: str, summary: str) -> bytes:
This should be succint but thorough and informative,
enough to serve as a mini post-incident-report.
"""
return await grafana_client.close_incident(incident_id, summary)
return await grafana_client.get().close_incident(incident_id, summary)


def add_tools(mcp: FastMCP):
Expand Down
8 changes: 4 additions & 4 deletions src/mcp_grafana/tools/prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def query_prometheus(
expr=expr, # type: ignore
intervalMs=interval_ms,
)
response = await grafana_client.query(start, end, [query])
response = await grafana_client.get().query(start, end, [query])
return DSQueryResponse.model_validate_json(response)


Expand All @@ -81,7 +81,7 @@ async def list_prometheus_metric_metadata(

A mapping from metric name to all available metadata for that metric.
"""
response = await grafana_client.list_prometheus_metric_metadata(
response = await grafana_client.get().list_prometheus_metric_metadata(
datasource_uid,
limit=limit,
limit_per_metric=limit_per_metric,
Expand Down Expand Up @@ -144,7 +144,7 @@ async def list_prometheus_label_names(
end: Optionally, the end time of the time range to filter the results by.
limit: Optionally, the maximum number of results to return. Defaults to 100.
"""
response = await grafana_client.list_prometheus_label_names(
response = await grafana_client.get().list_prometheus_label_names(
datasource_uid,
matches=matches,
start=start,
Expand Down Expand Up @@ -174,7 +174,7 @@ async def list_prometheus_label_values(
end: Optionally, the end time of the query.
limit: Optionally, the maximum number of results to return. Defaults to 100.
"""
response = await grafana_client.list_prometheus_label_values(
response = await grafana_client.get().list_prometheus_label_values(
datasource_uid,
label_name,
matches=matches,
Expand Down
2 changes: 1 addition & 1 deletion src/mcp_grafana/tools/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ async def search_dashboards(arguments: SearchDashboardsArguments) -> bytes:
"""
Search dashboards in the Grafana instance.
"""
return await grafana_client.search_dashboards(arguments)
return await grafana_client.get().search_dashboards(arguments)


def add_tools(mcp: FastMCP):
Expand Down
Loading
Loading