Skip to content
This repository was archived by the owner on Oct 28, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"semgrep==1.131.0",
"opentelemetry-api>=1.25.0",
"opentelemetry-sdk>=1.25.0",
"pyjwt>=2.10.1",
]

[project.license]
Expand All @@ -65,6 +66,7 @@ dev-dependencies = [
"tomli-w>=1.0.0",
"pre-commit>=3.0.0",
"pyright>=1.1.0",
"cryptography>=45.0.6",
]

[tool.ruff]
Expand All @@ -74,6 +76,7 @@ extend-exclude = [
"build",
"dist",
]
exclude = ["src/semgrep_mcp/semgrep_interfaces"]

[tool.ruff.lint]
select = [
Expand Down
43 changes: 42 additions & 1 deletion src/semgrep_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import click
import httpx
from mcp.server.auth.middleware.auth_context import get_access_token
from mcp.server.auth.settings import AuthSettings
from mcp.server.fastmcp import Context, FastMCP
from mcp.shared.exceptions import McpError
from mcp.types import (
Expand All @@ -18,7 +20,7 @@
INVALID_REQUEST,
ErrorData,
)
from pydantic import Field, ValidationError
from pydantic import AnyHttpUrl, Field, ValidationError
from starlette.requests import Request
from starlette.responses import JSONResponse

Expand All @@ -31,6 +33,7 @@
set_semgrep_executable,
)
from semgrep_mcp.semgrep_interfaces.semgrep_output_v1 import CliOutput
from semgrep_mcp.token_verifier import JWKSTokenVerifier
from utilities.tracing import start_tracing, with_span

# ---------------------------------------------------------------------------------
Expand All @@ -43,6 +46,10 @@
SEMGREP_API_URL = f"{SEMGREP_URL}/api"
SEMGREP_API_VERSION = "v1"

AUTH_BASE_URL = os.getenv("SEMGREP_AUTH_URL", "https://login.semgrep.dev")
SERVER_URL = os.getenv("SEMGREP_MCP_URL", "http://localhost:8000") # mcp.semgrep.ai in prod
WORKOS_CLIENT_ID = os.getenv("WORKOS_CLIENT_ID", "client_01JWXZ4GZ3WP1BFWJ5YTE9JWK7") # not secret

# Field definitions for function parameters
CODE_FILES_FIELD = Field(description="List of dictionaries with 'filename' and 'content' keys")
LOCAL_CODE_FILES_FIELD = Field(
Expand Down Expand Up @@ -302,17 +309,51 @@ async def server_lifespan(_server: FastMCP) -> AsyncIterator[SemgrepContext | No
# Create a fast MCP server
mcp = FastMCP(
"Semgrep",
# Token verifier for authentication
token_verifier=JWKSTokenVerifier(
jwks_endpoint=f"{AUTH_BASE_URL}/oauth2/jwks",
issuer=AUTH_BASE_URL,
audience=WORKOS_CLIENT_ID,
),
# Auth settings for RFC 9728 Protected Resource Metadata
auth=AuthSettings(
# Authorization Server URL
issuer_url=AnyHttpUrl(AUTH_BASE_URL),
# This server's URL
resource_server_url=AnyHttpUrl(SERVER_URL),
required_scopes=["openid", "profile", "email"],
),
stateless_http=True,
json_response=True,
lifespan=server_lifespan,
)

http_client = httpx.AsyncClient()

@mcp.custom_route("/.well-known/oauth-authorization-server", methods=["GET"])
async def oauth_authorization_server(request: Request) -> JSONResponse:
"""
Get the OAuth authorization server configuration for legacy clients
"""
r = await http_client.get(f"{AUTH_BASE_URL}/.well-known/oauth-authorization-server")
return JSONResponse(content=r.json())

# ---------------------------------------------------------------------------------
# MCP Tools
# ---------------------------------------------------------------------------------

@mcp.tool()
async def semgrep_user_info() -> str:
"""
Get the user info for the current semgrep user.
"""

# TODO remove before production
# demo getting access token only

access_token = get_access_token()
print(f"User token: {access_token}")
return "has access" if access_token else "no access"

@mcp.tool()
async def semgrep_rule_schema() -> str:
Expand Down
86 changes: 86 additions & 0 deletions src/semgrep_mcp/token_verifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""JWT token verifier implementation using JWKS validation."""

import logging
from typing import Any

import jwt
from jwt import PyJWKClient
from mcp.server.auth.provider import AccessToken, TokenVerifier

logger = logging.getLogger(__name__)

class JWKSTokenVerifier(TokenVerifier):
"""JWT token verifier that uses JWKS for signature validation.

This implementation validates JWT tokens by:
1. Using PyJWKClient to automatically fetch and cache public keys from JWKS
2. Validating the JWT signature
3. Checking standard JWT claims (exp, iat, iss, aud)
"""

def __init__(
self,
jwks_endpoint: str,
issuer: str,
audience: str,
):
self.jwks_endpoint = jwks_endpoint
self.issuer = issuer
self.audience = audience

# Initialize PyJWKClient for automatic JWKS handling
self._jwk_client = PyJWKClient(self.jwks_endpoint)

async def verify_token(self, token: str) -> AccessToken | None:
"""Verify JWT token using JWKS validation."""

try:
# Log payload for debugging
logger.debug(f"JWT token: {token}")

# Get the signing key from JWKS
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
if not signing_key:
logger.warning("Could not retrieve signing key from JWKS")
return None

# Decode and validate the JWT
payload = self._decode_and_validate_jwt(token, signing_key.key)
if not payload:
return None

# Log payload for debugging
logger.debug(f"JWT payload: {payload}")

return AccessToken(
token=token,
client_id=payload.get("aud", "unknown"),
scopes=["openid", "profile", "email"],
expires_at=payload.get("exp"),
)

except Exception as e:
logger.warning(f"JWT validation failed: {e}")
return None

def _decode_and_validate_jwt(self, token: str, public_key: Any) -> dict | None:
"""Decode and validate JWT using the public key."""
try:
# Decode JWT with validation
payload = jwt.decode(
token,
public_key,
algorithms=["RS256"], # NEVER CHANGE THIS or read from the token
issuer=self.issuer, # Validate issuer
audience=self.audience, # Validate audience
options={
"require": ["exp", "iat", "iss", "aud"],
},
leeway=10,
)

return payload

except Exception as e:
logger.warning(f"JWT validation error: {e}")
return None
Loading
Loading