Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ GREPTIMEDB_AUDIT_ENABLED=true # Enable audit logging
GREPTIMEDB_TRANSPORT=stdio # stdio, sse, or streamable-http
GREPTIMEDB_LISTEN_HOST=0.0.0.0 # HTTP server bind host
GREPTIMEDB_LISTEN_PORT=8080 # HTTP server bind port
GREPTIMEDB_ALLOWED_HOSTS= # DNS rebinding protection (comma-separated)
GREPTIMEDB_ALLOWED_ORIGINS= # CORS allowed origins (comma-separated)
```

### CLI Arguments
Expand Down Expand Up @@ -113,6 +115,28 @@ greptimedb-mcp-server --transport streamable-http --listen-port 8080
greptimedb-mcp-server --transport sse --listen-port 3000
```

#### DNS Rebinding Protection

By default, DNS rebinding protection is **disabled** for compatibility with proxies, gateways, and Kubernetes services. To enable it, use `--allowed-hosts`:

```bash
# Enable DNS rebinding protection with allowed hosts
greptimedb-mcp-server --transport streamable-http \
--allowed-hosts "localhost:*,127.0.0.1:*,my-service.namespace:*"

# With custom allowed origins for CORS
greptimedb-mcp-server --transport streamable-http \
--allowed-hosts "my-service.namespace:*" \
--allowed-origins "http://localhost:*,https://my-app.example.com"

# Or via environment variables
GREPTIMEDB_ALLOWED_HOSTS="localhost:*,my-service.namespace:*" \
GREPTIMEDB_ALLOWED_ORIGINS="http://localhost:*" \
greptimedb-mcp-server --transport streamable-http
```

If you encounter `421 Invalid Host Header` errors, either disable protection (default) or add your host to the allowed list.

## Security

### Read-Only Database User (Recommended)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "greptimedb-mcp-server"
version = "0.4.1"
version = "0.4.2"
description = "MCP server for GreptimeDB with SQL/TQL/PromQL support, sensitive data masking, and prompt templates for observability data analysis."
readme = "README.md"
license = {text = "MIT"}
Expand Down
45 changes: 45 additions & 0 deletions src/greptimedb_mcp_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,18 @@ class Config:
Enable audit logging for all tool calls
"""

allowed_hosts: list[str]
"""
Allowed hosts for DNS rebinding protection (for sse/streamable-http).
If empty, DNS rebinding protection is disabled.
"""

allowed_origins: list[str]
"""
Allowed origins for CORS (for sse/streamable-http).
Only used when DNS rebinding protection is enabled.
"""

@staticmethod
def from_env_arguments() -> "Config":
"""
Expand Down Expand Up @@ -198,7 +210,30 @@ def from_env_arguments() -> "Config":
default=os.getenv("GREPTIMEDB_AUDIT_ENABLED", "true"),
)

parser.add_argument(
"--allowed-hosts",
type=str,
help=(
"Allowed hosts for DNS rebinding protection (comma-separated). "
"If not set, DNS rebinding protection is disabled. "
"Example: localhost:*,127.0.0.1:*,my-service.namespace:*"
),
default=os.getenv("GREPTIMEDB_ALLOWED_HOSTS", ""),
)

parser.add_argument(
"--allowed-origins",
type=str,
help=(
"Allowed origins for CORS (comma-separated). "
"Only used when allowed-hosts is set. "
"Example: http://localhost:*,https://my-app.example.com"
),
default=os.getenv("GREPTIMEDB_ALLOWED_ORIGINS", ""),
)

args = parser.parse_args()

return Config(
host=args.host,
port=args.port,
Expand All @@ -215,4 +250,14 @@ def from_env_arguments() -> "Config":
listen_host=args.listen_host,
listen_port=args.listen_port,
audit_enabled=args.audit_enabled,
allowed_hosts=_parse_comma_separated(args.allowed_hosts),
allowed_origins=_parse_comma_separated(args.allowed_origins),
)


def _parse_comma_separated(value: str) -> list[str]:
"""Parse a comma-separated string into a list of trimmed non-empty strings."""
value = value.strip()
if not value:
return []
return [item.strip() for item in value.split(",") if item.strip()]
26 changes: 26 additions & 0 deletions src/greptimedb_mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import aiohttp
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.server import TransportSecuritySettings
from mysql.connector import connect, Error
from mysql.connector.pooling import MySQLConnectionPool

Expand Down Expand Up @@ -896,6 +897,31 @@ def main():
if _config.transport != "stdio":
mcp.settings.host = _config.listen_host
mcp.settings.port = _config.listen_port

# Configure DNS rebinding protection
# If allowed_hosts is empty, disable protection for compatibility
# with proxies, gateways, and Kubernetes services
if _config.allowed_hosts:
security_kwargs = {
"enable_dns_rebinding_protection": True,
"allowed_hosts": _config.allowed_hosts,
}
if _config.allowed_origins:
security_kwargs["allowed_origins"] = _config.allowed_origins
mcp.settings.transport_security = TransportSecuritySettings(
**security_kwargs
)
logger.info(
f"DNS rebinding protection: enabled "
f"(allowed_hosts: {_config.allowed_hosts}, "
f"allowed_origins: {_config.allowed_origins or 'default'})"
)
else:
mcp.settings.transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=False,
)
logger.info("DNS rebinding protection: disabled")

logger.info(
f"Starting MCP server with transport: {_config.transport} "
f"on {_config.listen_host}:{_config.listen_port}"
Expand Down
41 changes: 40 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from unittest.mock import patch
from greptimedb_mcp_server.config import Config
from greptimedb_mcp_server.config import Config, _parse_comma_separated


def test_config_default_values():
Expand All @@ -23,6 +23,8 @@ def test_config_default_values():
assert config.transport == "stdio"
assert config.listen_host == "0.0.0.0"
assert config.listen_port == 8080
assert config.allowed_hosts == []
assert config.allowed_origins == []


def test_config_env_variables():
Expand All @@ -42,6 +44,8 @@ def test_config_env_variables():
"GREPTIMEDB_TRANSPORT": "streamable-http",
"GREPTIMEDB_LISTEN_HOST": "127.0.0.1",
"GREPTIMEDB_LISTEN_PORT": "3000",
"GREPTIMEDB_ALLOWED_HOSTS": "localhost:*,127.0.0.1:*",
"GREPTIMEDB_ALLOWED_ORIGINS": "http://localhost:*,https://example.com",
}

with patch.dict(os.environ, env_vars):
Expand All @@ -60,6 +64,11 @@ def test_config_env_variables():
assert config.transport == "streamable-http"
assert config.listen_host == "127.0.0.1"
assert config.listen_port == 3000
assert config.allowed_hosts == ["localhost:*", "127.0.0.1:*"]
assert config.allowed_origins == [
"http://localhost:*",
"https://example.com",
]


def test_config_cli_arguments():
Expand Down Expand Up @@ -92,6 +101,10 @@ def test_config_cli_arguments():
"192.168.1.1",
"--listen-port",
"9090",
"--allowed-hosts",
"my-service.namespace:*",
"--allowed-origins",
"http://my-app.example.com",
]

with patch.dict(os.environ, {}, clear=True):
Expand All @@ -110,6 +123,8 @@ def test_config_cli_arguments():
assert config.transport == "sse"
assert config.listen_host == "192.168.1.1"
assert config.listen_port == 9090
assert config.allowed_hosts == ["my-service.namespace:*"]
assert config.allowed_origins == ["http://my-app.example.com"]


def test_config_precedence():
Expand Down Expand Up @@ -175,3 +190,27 @@ def test_config_precedence():
assert config.transport == "streamable-http"
assert config.listen_host == "cli-listen-host"
assert config.listen_port == 2222


class TestParseCommaSeparated:
"""Tests for _parse_comma_separated helper function."""

def test_empty_and_whitespace(self):
assert _parse_comma_separated("") == []
assert _parse_comma_separated(" ") == []

def test_single_value(self):
assert _parse_comma_separated("localhost:*") == ["localhost:*"]
assert _parse_comma_separated(" localhost:* ") == ["localhost:*"]

def test_multiple_values(self):
assert _parse_comma_separated(" localhost:* , 127.0.0.1:* ") == [
"localhost:*",
"127.0.0.1:*",
]

def test_empty_items_filtered(self):
assert _parse_comma_separated("localhost:*,, ,127.0.0.1:*") == [
"localhost:*",
"127.0.0.1:*",
]
107 changes: 107 additions & 0 deletions tests/test_http_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,110 @@ def test_config_invalid_transport_rejected(self):
from greptimedb_mcp_server.config import Config

Config.from_env_arguments()


class TestDnsRebindingProtection:
"""Tests for DNS rebinding protection configuration."""

@pytest.mark.asyncio
async def test_protection_disabled_by_default(self, free_port, mock_db_connection):
"""Test that DNS rebinding protection is disabled when allowed_hosts is empty."""
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.server import TransportSecuritySettings

test_mcp = FastMCP("test_server", host="127.0.0.1", port=free_port)

# Simulate our server.py logic: empty allowed_hosts = disabled
test_mcp.settings.transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=False,
)

async def run_server():
await test_mcp.run_streamable_http_async()

server_task = asyncio.create_task(run_server())
await asyncio.sleep(0.5)

try:
async with httpx.AsyncClient(trust_env=False) as client:
# Request with arbitrary Host header should succeed
response = await client.post(
f"http://127.0.0.1:{free_port}/mcp",
json={
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test", "version": "1.0"},
},
},
headers={
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
"Host": "arbitrary-host.example.com:8080",
},
timeout=5.0,
)
# Should succeed (not 421)
assert response.status_code == 200
finally:
server_task.cancel()
try:
await server_task
except asyncio.CancelledError:
pass

@pytest.mark.asyncio
async def test_protection_enabled_rejects_invalid_host(
self, free_port, mock_db_connection
):
"""Test that enabled protection rejects requests with invalid Host header."""
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.server import TransportSecuritySettings

test_mcp = FastMCP("test_server", host="127.0.0.1", port=free_port)

# Enable protection with specific allowed hosts
test_mcp.settings.transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=True,
allowed_hosts=["localhost:*", "127.0.0.1:*"],
)

async def run_server():
await test_mcp.run_streamable_http_async()

server_task = asyncio.create_task(run_server())
await asyncio.sleep(0.5)

try:
async with httpx.AsyncClient(trust_env=False) as client:
# Request with disallowed Host header should be rejected
response = await client.post(
f"http://127.0.0.1:{free_port}/mcp",
json={
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test", "version": "1.0"},
},
},
headers={
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
"Host": "malicious-host.example.com:8080",
},
timeout=5.0,
)
# Should be rejected with 421
assert response.status_code == 421
finally:
server_task.cancel()
try:
await server_task
except asyncio.CancelledError:
pass
2 changes: 2 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def setup_state():
listen_host="0.0.0.0",
listen_port=8080,
audit_enabled=False,
allowed_hosts=[],
allowed_origins=[],
)
# Set global config for get_config() calls
server._config = config
Expand Down
Loading