diff --git a/README.md b/README.md index 3a50bd9..11f0809 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,8 @@ GREPTIMEDB_AUDIT_ENABLED=true # Enable audit logging GREPTIMEDB_TRANSPORT=stdio # stdio, sse, or streamable-http GREPTIMEDB_LISTEN_HOST=0.0.0.0 # HTTP server bind host GREPTIMEDB_LISTEN_PORT=8080 # HTTP server bind port +GREPTIMEDB_ALLOWED_HOSTS= # DNS rebinding protection (comma-separated) +GREPTIMEDB_ALLOWED_ORIGINS= # CORS allowed origins (comma-separated) ``` ### CLI Arguments @@ -113,6 +115,28 @@ greptimedb-mcp-server --transport streamable-http --listen-port 8080 greptimedb-mcp-server --transport sse --listen-port 3000 ``` +#### DNS Rebinding Protection + +By default, DNS rebinding protection is **disabled** for compatibility with proxies, gateways, and Kubernetes services. To enable it, use `--allowed-hosts`: + +```bash +# Enable DNS rebinding protection with allowed hosts +greptimedb-mcp-server --transport streamable-http \ + --allowed-hosts "localhost:*,127.0.0.1:*,my-service.namespace:*" + +# With custom allowed origins for CORS +greptimedb-mcp-server --transport streamable-http \ + --allowed-hosts "my-service.namespace:*" \ + --allowed-origins "http://localhost:*,https://my-app.example.com" + +# Or via environment variables +GREPTIMEDB_ALLOWED_HOSTS="localhost:*,my-service.namespace:*" \ +GREPTIMEDB_ALLOWED_ORIGINS="http://localhost:*" \ + greptimedb-mcp-server --transport streamable-http +``` + +If you encounter `421 Invalid Host Header` errors, either disable protection (default) or add your host to the allowed list. + ## Security ### Read-Only Database User (Recommended) diff --git a/pyproject.toml b/pyproject.toml index e12528d..851e994 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "greptimedb-mcp-server" -version = "0.4.1" +version = "0.4.2" description = "MCP server for GreptimeDB with SQL/TQL/PromQL support, sensitive data masking, and prompt templates for observability data analysis." readme = "README.md" license = {text = "MIT"} diff --git a/src/greptimedb_mcp_server/config.py b/src/greptimedb_mcp_server/config.py index 036ca70..e569823 100644 --- a/src/greptimedb_mcp_server/config.py +++ b/src/greptimedb_mcp_server/config.py @@ -84,6 +84,18 @@ class Config: Enable audit logging for all tool calls """ + allowed_hosts: list[str] + """ + Allowed hosts for DNS rebinding protection (for sse/streamable-http). + If empty, DNS rebinding protection is disabled. + """ + + allowed_origins: list[str] + """ + Allowed origins for CORS (for sse/streamable-http). + Only used when DNS rebinding protection is enabled. + """ + @staticmethod def from_env_arguments() -> "Config": """ @@ -198,7 +210,30 @@ def from_env_arguments() -> "Config": default=os.getenv("GREPTIMEDB_AUDIT_ENABLED", "true"), ) + parser.add_argument( + "--allowed-hosts", + type=str, + help=( + "Allowed hosts for DNS rebinding protection (comma-separated). " + "If not set, DNS rebinding protection is disabled. " + "Example: localhost:*,127.0.0.1:*,my-service.namespace:*" + ), + default=os.getenv("GREPTIMEDB_ALLOWED_HOSTS", ""), + ) + + parser.add_argument( + "--allowed-origins", + type=str, + help=( + "Allowed origins for CORS (comma-separated). " + "Only used when allowed-hosts is set. " + "Example: http://localhost:*,https://my-app.example.com" + ), + default=os.getenv("GREPTIMEDB_ALLOWED_ORIGINS", ""), + ) + args = parser.parse_args() + return Config( host=args.host, port=args.port, @@ -215,4 +250,14 @@ def from_env_arguments() -> "Config": listen_host=args.listen_host, listen_port=args.listen_port, audit_enabled=args.audit_enabled, + allowed_hosts=_parse_comma_separated(args.allowed_hosts), + allowed_origins=_parse_comma_separated(args.allowed_origins), ) + + +def _parse_comma_separated(value: str) -> list[str]: + """Parse a comma-separated string into a list of trimmed non-empty strings.""" + value = value.strip() + if not value: + return [] + return [item.strip() for item in value.split(",") if item.strip()] diff --git a/src/greptimedb_mcp_server/server.py b/src/greptimedb_mcp_server/server.py index e909ade..fe1d9c5 100644 --- a/src/greptimedb_mcp_server/server.py +++ b/src/greptimedb_mcp_server/server.py @@ -28,6 +28,7 @@ import aiohttp from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp.server import TransportSecuritySettings from mysql.connector import connect, Error from mysql.connector.pooling import MySQLConnectionPool @@ -896,6 +897,31 @@ def main(): if _config.transport != "stdio": mcp.settings.host = _config.listen_host mcp.settings.port = _config.listen_port + + # Configure DNS rebinding protection + # If allowed_hosts is empty, disable protection for compatibility + # with proxies, gateways, and Kubernetes services + if _config.allowed_hosts: + security_kwargs = { + "enable_dns_rebinding_protection": True, + "allowed_hosts": _config.allowed_hosts, + } + if _config.allowed_origins: + security_kwargs["allowed_origins"] = _config.allowed_origins + mcp.settings.transport_security = TransportSecuritySettings( + **security_kwargs + ) + logger.info( + f"DNS rebinding protection: enabled " + f"(allowed_hosts: {_config.allowed_hosts}, " + f"allowed_origins: {_config.allowed_origins or 'default'})" + ) + else: + mcp.settings.transport_security = TransportSecuritySettings( + enable_dns_rebinding_protection=False, + ) + logger.info("DNS rebinding protection: disabled") + logger.info( f"Starting MCP server with transport: {_config.transport} " f"on {_config.listen_host}:{_config.listen_port}" diff --git a/tests/test_config.py b/tests/test_config.py index cdf2f86..9eb7ffc 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,6 @@ import os from unittest.mock import patch -from greptimedb_mcp_server.config import Config +from greptimedb_mcp_server.config import Config, _parse_comma_separated def test_config_default_values(): @@ -23,6 +23,8 @@ def test_config_default_values(): assert config.transport == "stdio" assert config.listen_host == "0.0.0.0" assert config.listen_port == 8080 + assert config.allowed_hosts == [] + assert config.allowed_origins == [] def test_config_env_variables(): @@ -42,6 +44,8 @@ def test_config_env_variables(): "GREPTIMEDB_TRANSPORT": "streamable-http", "GREPTIMEDB_LISTEN_HOST": "127.0.0.1", "GREPTIMEDB_LISTEN_PORT": "3000", + "GREPTIMEDB_ALLOWED_HOSTS": "localhost:*,127.0.0.1:*", + "GREPTIMEDB_ALLOWED_ORIGINS": "http://localhost:*,https://example.com", } with patch.dict(os.environ, env_vars): @@ -60,6 +64,11 @@ def test_config_env_variables(): assert config.transport == "streamable-http" assert config.listen_host == "127.0.0.1" assert config.listen_port == 3000 + assert config.allowed_hosts == ["localhost:*", "127.0.0.1:*"] + assert config.allowed_origins == [ + "http://localhost:*", + "https://example.com", + ] def test_config_cli_arguments(): @@ -92,6 +101,10 @@ def test_config_cli_arguments(): "192.168.1.1", "--listen-port", "9090", + "--allowed-hosts", + "my-service.namespace:*", + "--allowed-origins", + "http://my-app.example.com", ] with patch.dict(os.environ, {}, clear=True): @@ -110,6 +123,8 @@ def test_config_cli_arguments(): assert config.transport == "sse" assert config.listen_host == "192.168.1.1" assert config.listen_port == 9090 + assert config.allowed_hosts == ["my-service.namespace:*"] + assert config.allowed_origins == ["http://my-app.example.com"] def test_config_precedence(): @@ -175,3 +190,27 @@ def test_config_precedence(): assert config.transport == "streamable-http" assert config.listen_host == "cli-listen-host" assert config.listen_port == 2222 + + +class TestParseCommaSeparated: + """Tests for _parse_comma_separated helper function.""" + + def test_empty_and_whitespace(self): + assert _parse_comma_separated("") == [] + assert _parse_comma_separated(" ") == [] + + def test_single_value(self): + assert _parse_comma_separated("localhost:*") == ["localhost:*"] + assert _parse_comma_separated(" localhost:* ") == ["localhost:*"] + + def test_multiple_values(self): + assert _parse_comma_separated(" localhost:* , 127.0.0.1:* ") == [ + "localhost:*", + "127.0.0.1:*", + ] + + def test_empty_items_filtered(self): + assert _parse_comma_separated("localhost:*,, ,127.0.0.1:*") == [ + "localhost:*", + "127.0.0.1:*", + ] diff --git a/tests/test_http_transport.py b/tests/test_http_transport.py index 42f015a..dbf58d8 100644 --- a/tests/test_http_transport.py +++ b/tests/test_http_transport.py @@ -248,3 +248,110 @@ def test_config_invalid_transport_rejected(self): from greptimedb_mcp_server.config import Config Config.from_env_arguments() + + +class TestDnsRebindingProtection: + """Tests for DNS rebinding protection configuration.""" + + @pytest.mark.asyncio + async def test_protection_disabled_by_default(self, free_port, mock_db_connection): + """Test that DNS rebinding protection is disabled when allowed_hosts is empty.""" + from mcp.server.fastmcp import FastMCP + from mcp.server.fastmcp.server import TransportSecuritySettings + + test_mcp = FastMCP("test_server", host="127.0.0.1", port=free_port) + + # Simulate our server.py logic: empty allowed_hosts = disabled + test_mcp.settings.transport_security = TransportSecuritySettings( + enable_dns_rebinding_protection=False, + ) + + async def run_server(): + await test_mcp.run_streamable_http_async() + + server_task = asyncio.create_task(run_server()) + await asyncio.sleep(0.5) + + try: + async with httpx.AsyncClient(trust_env=False) as client: + # Request with arbitrary Host header should succeed + response = await client.post( + f"http://127.0.0.1:{free_port}/mcp", + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"}, + }, + }, + headers={ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "Host": "arbitrary-host.example.com:8080", + }, + timeout=5.0, + ) + # Should succeed (not 421) + assert response.status_code == 200 + finally: + server_task.cancel() + try: + await server_task + except asyncio.CancelledError: + pass + + @pytest.mark.asyncio + async def test_protection_enabled_rejects_invalid_host( + self, free_port, mock_db_connection + ): + """Test that enabled protection rejects requests with invalid Host header.""" + from mcp.server.fastmcp import FastMCP + from mcp.server.fastmcp.server import TransportSecuritySettings + + test_mcp = FastMCP("test_server", host="127.0.0.1", port=free_port) + + # Enable protection with specific allowed hosts + test_mcp.settings.transport_security = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["localhost:*", "127.0.0.1:*"], + ) + + async def run_server(): + await test_mcp.run_streamable_http_async() + + server_task = asyncio.create_task(run_server()) + await asyncio.sleep(0.5) + + try: + async with httpx.AsyncClient(trust_env=False) as client: + # Request with disallowed Host header should be rejected + response = await client.post( + f"http://127.0.0.1:{free_port}/mcp", + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"}, + }, + }, + headers={ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "Host": "malicious-host.example.com:8080", + }, + timeout=5.0, + ) + # Should be rejected with 421 + assert response.status_code == 421 + finally: + server_task.cancel() + try: + await server_task + except asyncio.CancelledError: + pass diff --git a/tests/test_server.py b/tests/test_server.py index 1ec7684..5a6b71f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -40,6 +40,8 @@ def setup_state(): listen_host="0.0.0.0", listen_port=8080, audit_enabled=False, + allowed_hosts=[], + allowed_origins=[], ) # Set global config for get_config() calls server._config = config