diff --git a/README.md b/README.md index 20c3b11..60cd6e9 100644 --- a/README.md +++ b/README.md @@ -115,7 +115,32 @@ Claude will: 4. Test it with `dryrun_pipeline` tool ## Security -All queries pass through a security gate that: + +### Database User Configuration (Recommended) + +For production deployments, create a **read-only database user** for the MCP server. This provides defense-in-depth security at the database level. + +Configure a read-only user in GreptimeDB using [static user provider](https://docs.greptime.com/user-guide/deployments-administration/authentication/static/#permission-modes): + +``` +# User format: username:permission_mode=password +mcp_readonly:readonly=your_secure_password +``` + +Permission modes: +- `readonly` (or `ro`) - Can only query data (recommended for MCP server) +- `writeonly` (or `wo`) - Can only write data +- `readwrite` (or `rw`) - Full access (default) + +Then configure the MCP server to use this user: +```bash +GREPTIMEDB_USER=mcp_readonly +GREPTIMEDB_PASSWORD=your_secure_password +``` + +### Application-Level Security Gate + +All queries also pass through a security gate that: - Blocks DDL/DML operations: DROP, DELETE, TRUNCATE, UPDATE, INSERT, ALTER, CREATE, GRANT, REVOKE - Blocks dynamic SQL execution: EXEC, EXECUTE, CALL - Blocks data modification: REPLACE INTO @@ -178,6 +203,11 @@ GREPTIMEDB_TIMEZONE=UTC GREPTIMEDB_POOL_SIZE=5 # Optional: Connection pool size (defaults to 5) GREPTIMEDB_MASK_ENABLED=true # Optional: Enable data masking (defaults to true) GREPTIMEDB_MASK_PATTERNS= # Optional: Additional sensitive column patterns (comma-separated) + +# MCP Server Transport Options +GREPTIMEDB_TRANSPORT=stdio # Optional: Transport mode (stdio, sse, streamable-http, defaults to stdio) +GREPTIMEDB_LISTEN_HOST=0.0.0.0 # Optional: HTTP server bind host (defaults to 0.0.0.0) +GREPTIMEDB_LISTEN_PORT=8080 # Optional: HTTP server bind port (defaults to 8080) ``` Or via command-line args: @@ -192,7 +222,33 @@ Or via command-line args: * `--timezone` the session time zone, empty by default (using server default time zone), * `--pool-size` the connection pool size, `5` by default, * `--mask-enabled` enable data masking for sensitive columns, `true` by default, -* `--mask-patterns` additional sensitive column patterns (comma-separated), empty by default. +* `--mask-patterns` additional sensitive column patterns (comma-separated), empty by default, +* `--transport` MCP transport mode (`stdio`, `sse`, `streamable-http`), `stdio` by default, +* `--listen-host` HTTP server bind host (for sse/streamable-http), `0.0.0.0` by default, +* `--listen-port` HTTP server bind port (for sse/streamable-http), `8080` by default. + +## HTTP Server Mode + +For containerized or Kubernetes deployments, you can run the MCP server in HTTP mode instead of stdio: + +```bash +# Streamable HTTP mode (recommended for production) +greptimedb-mcp-server --transport streamable-http --listen-port 8080 + +# SSE mode (legacy, for older clients) +greptimedb-mcp-server --transport sse --listen-host 0.0.0.0 --listen-port 3000 + +# Via environment variables (for Docker/K8s) +GREPTIMEDB_TRANSPORT=streamable-http \ +GREPTIMEDB_LISTEN_HOST=0.0.0.0 \ +GREPTIMEDB_LISTEN_PORT=8080 \ +greptimedb-mcp-server +``` + +**Transport modes:** +- `stdio` (default): Standard input/output, for local CLI integration (e.g., Claude Desktop) +- `streamable-http`: HTTP-based transport with SSE streaming, recommended for remote/production deployments +- `sse`: Server-Sent Events transport (legacy, being deprecated in MCP spec) # Usage diff --git a/src/greptimedb_mcp_server/config.py b/src/greptimedb_mcp_server/config.py index b98c2cc..7935b87 100644 --- a/src/greptimedb_mcp_server/config.py +++ b/src/greptimedb_mcp_server/config.py @@ -64,6 +64,21 @@ class Config: Additional sensitive column patterns (comma-separated) """ + transport: str + """ + MCP transport mode: stdio, sse, or streamable-http + """ + + listen_host: str + """ + MCP HTTP server bind host (for sse/streamable-http transports) + """ + + listen_port: int + """ + MCP HTTP server bind port (for sse/streamable-http transports) + """ + @staticmethod def from_env_arguments() -> "Config": """ @@ -149,6 +164,28 @@ def from_env_arguments() -> "Config": default=os.getenv("GREPTIMEDB_MASK_PATTERNS", ""), ) + parser.add_argument( + "--transport", + type=str, + choices=["stdio", "sse", "streamable-http"], + help="MCP transport mode (default: stdio)", + default=os.getenv("GREPTIMEDB_TRANSPORT", "stdio"), + ) + + parser.add_argument( + "--listen-host", + type=str, + help="MCP HTTP server bind host (default: 0.0.0.0)", + default=os.getenv("GREPTIMEDB_LISTEN_HOST", "0.0.0.0"), + ) + + parser.add_argument( + "--listen-port", + type=int, + help="MCP HTTP server bind port (default: 8080)", + default=int(os.getenv("GREPTIMEDB_LISTEN_PORT", "8080")), + ) + args = parser.parse_args() return Config( host=args.host, @@ -162,4 +199,7 @@ def from_env_arguments() -> "Config": http_protocol=args.http_protocol, mask_enabled=args.mask_enabled, mask_patterns=args.mask_patterns, + transport=args.transport, + listen_host=args.listen_host, + listen_port=args.listen_port, ) diff --git a/src/greptimedb_mcp_server/server.py b/src/greptimedb_mcp_server/server.py index 3042e31..c57247f 100644 --- a/src/greptimedb_mcp_server/server.py +++ b/src/greptimedb_mcp_server/server.py @@ -10,6 +10,8 @@ validate_query_component, validate_duration, validate_fill, + validate_time_expression, + format_tql_time_param, ) import asyncio @@ -77,10 +79,26 @@ def get_http_auth(self) -> aiohttp.BasicAuth | None: return None +# Global config (set by main() before run()) +_config: Config | None = None + # Global state (initialized in lifespan) _state: AppState | None = None +def get_config() -> Config: + """Get the parsed configuration. + + Falls back to parsing from env/args if not pre-initialized by main(). + This preserves compatibility with alternative entry points like + `mcp dev greptimedb_mcp_server.server:mcp` or programmatic imports. + """ + global _config + if _config is None: + _config = Config.from_env_arguments() + return _config + + def get_state() -> AppState: """Get the application state.""" if _state is None: @@ -93,7 +111,7 @@ async def lifespan(mcp: FastMCP): """Initialize application state on startup.""" global _state - config = Config.from_env_arguments() + config = get_config() db_config = { "host": config.host, "port": config.port, @@ -337,9 +355,14 @@ async def execute_tql( "Example: rate(http_requests_total[5m])", ], start: Annotated[ - str, "Start time (RFC3339, Unix timestamp, or relative like 'now-1h')" + str, + "Start time: SQL expression (e.g., \"now() - interval '5' minute\"), " + "RFC3339 (e.g., '2024-01-01T00:00:00Z'), or Unix timestamp", + ], + end: Annotated[ + str, + "End time: SQL expression (e.g., 'now()'), " "RFC3339, or Unix timestamp", ], - end: Annotated[str, "End time (RFC3339, Unix timestamp, or relative like 'now')"], step: Annotated[str, "Query resolution step, e.g., '1m', '5m', '1h'"], lookback: Annotated[str | None, "Lookback delta for range queries"] = None, format: Annotated[ @@ -354,8 +377,8 @@ async def execute_tql( if format not in VALID_FORMATS: raise ValueError(f"Invalid format: {format}. Must be one of: {VALID_FORMATS}") - validate_tql_param(start, "start") - validate_tql_param(end, "end") + validate_time_expression(start, "start") + validate_time_expression(end, "end") validate_tql_param(step, "step") if lookback: validate_tql_param(lookback, "lookback") @@ -364,10 +387,12 @@ async def execute_tql( if is_dangerous: return f"Error: Dangerous operation blocked: {reason}" + start_param = format_tql_time_param(start) + end_param = format_tql_time_param(end) if lookback: - tql = f"TQL EVAL ('{start}', '{end}', '{step}', '{lookback}') {query}" + tql = f"TQL EVAL ({start_param}, {end_param}, '{step}', '{lookback}') {query}" else: - tql = f"TQL EVAL ('{start}', '{end}', '{step}') {query}" + tql = f"TQL EVAL ({start_param}, {end_param}, '{step}') {query}" start_time = time.time() @@ -834,7 +859,22 @@ def prompt_fn({arg_params}) -> str: def main(): """Main entry point.""" - mcp.run() + global _config + _config = Config.from_env_arguments() + + # Only configure HTTP server settings for non-stdio transports + # to avoid overriding user's programmatic configuration + if _config.transport != "stdio": + mcp.settings.host = _config.listen_host + mcp.settings.port = _config.listen_port + logger.info( + f"Starting MCP server with transport: {_config.transport} " + f"on {_config.listen_host}:{_config.listen_port}" + ) + else: + logger.info("Starting MCP server with transport: stdio") + + mcp.run(transport=_config.transport) if __name__ == "__main__": diff --git a/src/greptimedb_mcp_server/templates/promql_analysis/config.yaml b/src/greptimedb_mcp_server/templates/promql_analysis/config.yaml index ade6ca2..ea0bc16 100644 --- a/src/greptimedb_mcp_server/templates/promql_analysis/config.yaml +++ b/src/greptimedb_mcp_server/templates/promql_analysis/config.yaml @@ -5,10 +5,10 @@ arguments: description: "The metric name to query (required)." required: true - name: "start_time" - description: "Query start time (e.g., 'now-1h' or ISO timestamp)." + description: "Start time: SQL expression (now() - interval '5' minute), RFC3339, or Unix timestamp." required: true - name: "end_time" - description: "Query end time (e.g., 'now' or ISO timestamp)." + description: "End time: SQL expression (now()), RFC3339, or Unix timestamp." required: true metadata: tags: diff --git a/src/greptimedb_mcp_server/templates/promql_analysis/template.md b/src/greptimedb_mcp_server/templates/promql_analysis/template.md index 7a73b09..f39b299 100644 --- a/src/greptimedb_mcp_server/templates/promql_analysis/template.md +++ b/src/greptimedb_mcp_server/templates/promql_analysis/template.md @@ -57,6 +57,7 @@ histogram_quantile(0.99, rate({{ metric }}_bucket[5m])) > 0.5 ## Notes - Use `execute_tql` tool with: query, start, end, step (required), lookback (optional) +- Time formats: SQL expression (now(), now() - interval '5' minute), RFC3339, or Unix timestamp - Label matchers: `=`, `!=`, `=~` (regex), `!~` - Time durations: s, m, h, d, w diff --git a/src/greptimedb_mcp_server/utils.py b/src/greptimedb_mcp_server/utils.py index a705ce3..685266e 100644 --- a/src/greptimedb_mcp_server/utils.py +++ b/src/greptimedb_mcp_server/utils.py @@ -135,3 +135,32 @@ def validate_fill(value: str) -> str: if not FILL_PATTERN.match(value): raise ValueError("Invalid fill: must be NULL, PREV, LINEAR, or a number") return value + + +def is_sql_time_expression(value: str) -> bool: + """Check if value is a SQL time expression (contains function call).""" + return "(" in value + + +def format_tql_time_param(value: str) -> str: + """Format time parameter for TQL: quote literals, leave SQL expressions as-is.""" + if is_sql_time_expression(value): + return value + # Escape single quotes in literal values to avoid breaking the TQL statement + safe_value = value.replace("'", "''") + return f"'{safe_value}'" + + +def validate_time_expression(value: str, name: str) -> str: + """Validate time expression for TQL start/end parameters.""" + if not value: + raise ValueError(f"{name} is required") + if ";" in value or "--" in value: + raise ValueError(f"Invalid characters in {name}") + # Guard against malformed or injected strings with unbalanced quotes + if value.count("'") % 2 != 0: + raise ValueError(f"Unbalanced quotes in {name}") + is_dangerous, reason = security_gate(value) + if is_dangerous: + raise ValueError(f"Dangerous pattern in {name}: {reason}") + return value diff --git a/tests/test_config.py b/tests/test_config.py index d4a0c07..cdf2f86 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -20,6 +20,9 @@ def test_config_default_values(): assert config.http_protocol == "http" assert config.mask_enabled is True assert config.mask_patterns == "" + assert config.transport == "stdio" + assert config.listen_host == "0.0.0.0" + assert config.listen_port == 8080 def test_config_env_variables(): @@ -36,6 +39,9 @@ def test_config_env_variables(): "GREPTIMEDB_HTTP_PROTOCOL": "https", "GREPTIMEDB_MASK_ENABLED": "false", "GREPTIMEDB_MASK_PATTERNS": "phone,address", + "GREPTIMEDB_TRANSPORT": "streamable-http", + "GREPTIMEDB_LISTEN_HOST": "127.0.0.1", + "GREPTIMEDB_LISTEN_PORT": "3000", } with patch.dict(os.environ, env_vars): @@ -51,6 +57,9 @@ def test_config_env_variables(): assert config.http_protocol == "https" assert config.mask_enabled is False assert config.mask_patterns == "phone,address" + assert config.transport == "streamable-http" + assert config.listen_host == "127.0.0.1" + assert config.listen_port == 3000 def test_config_cli_arguments(): @@ -77,6 +86,12 @@ def test_config_cli_arguments(): "false", "--mask-patterns", "custom1,custom2", + "--transport", + "sse", + "--listen-host", + "192.168.1.1", + "--listen-port", + "9090", ] with patch.dict(os.environ, {}, clear=True): @@ -92,6 +107,9 @@ def test_config_cli_arguments(): assert config.http_protocol == "https" assert config.mask_enabled is False assert config.mask_patterns == "custom1,custom2" + assert config.transport == "sse" + assert config.listen_host == "192.168.1.1" + assert config.listen_port == 9090 def test_config_precedence(): @@ -108,6 +126,9 @@ def test_config_precedence(): "GREPTIMEDB_HTTP_PROTOCOL": "http", "GREPTIMEDB_MASK_ENABLED": "true", "GREPTIMEDB_MASK_PATTERNS": "env_pattern", + "GREPTIMEDB_TRANSPORT": "stdio", + "GREPTIMEDB_LISTEN_HOST": "env-listen-host", + "GREPTIMEDB_LISTEN_PORT": "1111", } cli_args = [ @@ -130,6 +151,12 @@ def test_config_precedence(): "false", "--mask-patterns", "cli_pattern", + "--transport", + "streamable-http", + "--listen-host", + "cli-listen-host", + "--listen-port", + "2222", ] with patch.dict(os.environ, env_vars): @@ -145,3 +172,6 @@ def test_config_precedence(): assert config.http_protocol == "https" assert config.mask_enabled is False assert config.mask_patterns == "cli_pattern" + assert config.transport == "streamable-http" + assert config.listen_host == "cli-listen-host" + assert config.listen_port == 2222 diff --git a/tests/test_http_transport.py b/tests/test_http_transport.py new file mode 100644 index 0000000..42f015a --- /dev/null +++ b/tests/test_http_transport.py @@ -0,0 +1,250 @@ +"""Blackbox tests for HTTP transport modes (streamable-http and sse).""" + +import asyncio +import json +import pytest +import socket +from contextlib import closing +from unittest.mock import patch + +import httpx + + +def find_free_port() -> int: + """Find a free port on localhost.""" + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +@pytest.fixture +def free_port(): + """Get a free port for testing.""" + return find_free_port() + + +@pytest.fixture +def mock_db_connection(): + """Mock database connection for testing.""" + with patch("greptimedb_mcp_server.server.connect") as mock_connect: + mock_conn = mock_connect.return_value.__enter__.return_value + mock_cursor = mock_conn.cursor.return_value.__enter__.return_value + mock_cursor.fetchone.return_value = ("GreptimeDB 0.9.0",) + yield mock_connect + + +class TestStreamableHttpTransport: + """Tests for streamable-http transport mode.""" + + @pytest.mark.asyncio + async def test_initialize_returns_valid_mcp_response( + self, free_port, mock_db_connection + ): + """Test that initialize request returns valid MCP protocol response.""" + from mcp.server.fastmcp import FastMCP + + test_mcp = FastMCP("test_server", host="127.0.0.1", port=free_port) + + @test_mcp.tool() + def ping() -> str: + return "pong" + + async def run_server(): + await test_mcp.run_streamable_http_async() + + server_task = asyncio.create_task(run_server()) + await asyncio.sleep(0.5) + + try: + async with httpx.AsyncClient(trust_env=False) as client: + response = await client.post( + f"http://127.0.0.1:{free_port}/mcp", + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"}, + }, + }, + headers={ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + }, + timeout=5.0, + ) + assert response.status_code == 200 + + # Response is SSE format: "event: message\r\ndata: {...}\r\n" + # Extract JSON from the data line + content = response.text + data = None + for line in content.split("\n"): + if line.startswith("data:"): + data = json.loads(line[5:].strip()) + break + + assert data is not None, "No data line found in SSE response" + assert data.get("jsonrpc") == "2.0" + assert data.get("id") == 1 + assert "result" in data + + # Verify MCP initialize response fields + result = data["result"] + assert "protocolVersion" in result + assert "capabilities" in result + assert "serverInfo" in result + assert result["serverInfo"]["name"] == "test_server" + finally: + server_task.cancel() + try: + await server_task + except asyncio.CancelledError: + pass + + @pytest.mark.asyncio + async def test_mcp_endpoint_rejects_invalid_json( + self, free_port, mock_db_connection + ): + """Test that /mcp endpoint returns error for invalid JSON.""" + from mcp.server.fastmcp import FastMCP + + test_mcp = FastMCP("test_server", host="127.0.0.1", port=free_port) + + async def run_server(): + await test_mcp.run_streamable_http_async() + + server_task = asyncio.create_task(run_server()) + await asyncio.sleep(0.5) + + try: + async with httpx.AsyncClient(trust_env=False) as client: + response = await client.post( + f"http://127.0.0.1:{free_port}/mcp", + content=b"not valid json", + headers={ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + }, + timeout=5.0, + ) + # Should return error status for invalid request + assert response.status_code in [400, 422, 500] + finally: + server_task.cancel() + try: + await server_task + except asyncio.CancelledError: + pass + + +class TestSseTransport: + """Tests for SSE transport mode.""" + + @pytest.mark.asyncio + async def test_sse_endpoint_returns_endpoint_event( + self, free_port, mock_db_connection + ): + """Test that /sse endpoint returns SSE event with messages endpoint.""" + from mcp.server.fastmcp import FastMCP + + test_mcp = FastMCP("test_server", host="127.0.0.1", port=free_port) + + async def run_server(): + await test_mcp.run_sse_async() + + server_task = asyncio.create_task(run_server()) + await asyncio.sleep(0.5) + + try: + async with httpx.AsyncClient(trust_env=False) as client: + async with client.stream( + "GET", f"http://127.0.0.1:{free_port}/sse", timeout=2.0 + ) as response: + assert response.status_code == 200 + assert "text/event-stream" in response.headers.get( + "content-type", "" + ) + + # Read first SSE event (endpoint announcement) + event_data = "" + async for line in response.aiter_lines(): + if line.startswith("data:"): + event_data = line[5:].strip() + break + + # Verify endpoint URL is provided + assert "/messages/" in event_data + except httpx.ReadTimeout: + pass # SSE stream stays open, timeout is expected + finally: + server_task.cancel() + try: + await server_task + except asyncio.CancelledError: + pass + + @pytest.mark.asyncio + async def test_messages_endpoint_rejects_invalid_session( + self, free_port, mock_db_connection + ): + """Test that /messages/ endpoint rejects requests without valid session.""" + from mcp.server.fastmcp import FastMCP + + test_mcp = FastMCP("test_server", host="127.0.0.1", port=free_port) + + async def run_server(): + await test_mcp.run_sse_async() + + server_task = asyncio.create_task(run_server()) + await asyncio.sleep(0.5) + + try: + async with httpx.AsyncClient(trust_env=False) as client: + response = await client.post( + f"http://127.0.0.1:{free_port}/messages/", + json={"jsonrpc": "2.0", "id": 1, "method": "ping"}, + timeout=5.0, + ) + # Without valid session ID, should return error + assert response.status_code in [400, 404, 500] + finally: + server_task.cancel() + try: + await server_task + except asyncio.CancelledError: + pass + + +class TestTransportConfig: + """Tests for transport configuration.""" + + def test_config_transport_choices(self): + """Test that transport config only accepts valid choices.""" + from greptimedb_mcp_server.config import Config + import sys + + # Valid transports should work + for transport in ["stdio", "sse", "streamable-http"]: + with patch.dict("os.environ", {}, clear=True): + with patch.object( + sys, + "argv", + ["test", "--transport", transport], + ): + config = Config.from_env_arguments() + assert config.transport == transport + + def test_config_invalid_transport_rejected(self): + """Test that invalid transport is rejected.""" + import sys + + with patch.dict("os.environ", {}, clear=True): + with patch.object(sys, "argv", ["test", "--transport", "invalid"]): + with pytest.raises(SystemExit): + from greptimedb_mcp_server.config import Config + + Config.from_env_arguments() diff --git a/tests/test_server.py b/tests/test_server.py index e0d4d4c..56981b2 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -36,7 +36,12 @@ def setup_state(): http_protocol="http", mask_enabled=False, mask_patterns="", + transport="stdio", + listen_host="0.0.0.0", + listen_port=8080, ) + # Set global config for get_config() calls + server._config = config db_config = { "host": config.host, "port": config.port, @@ -251,6 +256,21 @@ async def test_execute_tql_injection_blocked(): assert "Invalid characters" in str(excinfo.value) +@pytest.mark.asyncio +async def test_execute_tql_escapes_literal_quotes(): + """Literal start/end values with quotes should be escaped, not injected""" + result = await execute_tql( + query="rate(http_requests_total[5m])", + start="2024-01-01T00:00:00Z'quoted'", + end="2024-01-01T01:00:00Z", + step="1m", + ) + + data = json.loads(result) + # Quotes should be doubled inside the TQL string to keep the literal safe + assert "2024-01-01T00:00:00Z''quoted''" in data["tql"] + + @pytest.mark.asyncio async def test_execute_tql_dangerous_query_blocked(): """Test execute_tql blocks dangerous patterns in query""" diff --git a/tests/test_utils.py b/tests/test_utils.py index 7b3dc12..52903df 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,6 +5,9 @@ templates_loader, security_gate, validate_table_name, + is_sql_time_expression, + format_tql_time_param, + validate_time_expression, ) from greptimedb_mcp_server.formatter import format_results @@ -418,3 +421,105 @@ def test_validate_table_name_invalid(): with pytest.raises(ValueError) as excinfo: validate_table_name("") assert "required" in str(excinfo.value) + + +# Tests for time expression functions + + +def test_is_sql_time_expression_with_function(): + """Test is_sql_time_expression detects SQL function calls""" + assert is_sql_time_expression("now()") is True + assert is_sql_time_expression("now() - interval '5' minute") is True + assert is_sql_time_expression("date_trunc('day', now())") is True + assert is_sql_time_expression("NOW()") is True + + +def test_is_sql_time_expression_without_function(): + """Test is_sql_time_expression returns False for non-SQL expressions""" + assert is_sql_time_expression("2024-01-01T00:00:00Z") is False + assert is_sql_time_expression("1704067200") is False + assert is_sql_time_expression("1704067200.123") is False + + +def test_format_tql_time_param_sql_expression(): + """Test format_tql_time_param leaves SQL expressions unquoted""" + assert format_tql_time_param("now()") == "now()" + assert format_tql_time_param("now() - interval '5' minute") == ( + "now() - interval '5' minute" + ) + assert format_tql_time_param("date_trunc('day', now())") == ( + "date_trunc('day', now())" + ) + + +def test_format_tql_time_param_literal(): + """Test format_tql_time_param quotes literal values""" + assert format_tql_time_param("2024-01-01T00:00:00Z") == "'2024-01-01T00:00:00Z'" + assert format_tql_time_param("1704067200") == "'1704067200'" + assert format_tql_time_param("1704067200.123") == "'1704067200.123'" + + +def test_validate_time_expression_valid(): + """Test validate_time_expression accepts valid expressions""" + assert validate_time_expression("now()", "start") == "now()" + assert validate_time_expression("now() - interval '5' minute", "start") == ( + "now() - interval '5' minute" + ) + assert validate_time_expression("2024-01-01T00:00:00Z", "end") == ( + "2024-01-01T00:00:00Z" + ) + assert validate_time_expression("1704067200", "start") == "1704067200" + + +def test_validate_time_expression_empty(): + """Test validate_time_expression rejects empty values""" + with pytest.raises(ValueError) as excinfo: + validate_time_expression("", "start") + assert "start is required" in str(excinfo.value) + + with pytest.raises(ValueError) as excinfo: + validate_time_expression("", "end") + assert "end is required" in str(excinfo.value) + + +def test_validate_time_expression_injection(): + """Test validate_time_expression blocks injection attempts""" + with pytest.raises(ValueError) as excinfo: + validate_time_expression("now(); DROP TABLE users", "start") + assert "Invalid characters" in str(excinfo.value) + + with pytest.raises(ValueError) as excinfo: + validate_time_expression("now() -- comment", "end") + assert "Invalid characters" in str(excinfo.value) + + +def test_validate_time_expression_unbalanced_quotes(): + """Test validate_time_expression blocks unbalanced quotes""" + # Odd number of quotes should be rejected + with pytest.raises(ValueError) as excinfo: + validate_time_expression("2024-01-01T00:00:00Z' OR 1=1", "start") + assert "Unbalanced quotes" in str(excinfo.value) + + with pytest.raises(ValueError) as excinfo: + validate_time_expression("now() - interval '5 minute", "end") + assert "Unbalanced quotes" in str(excinfo.value) + + # Balanced quotes are allowed + result = validate_time_expression("now() - interval '5' minute", "start") + assert result == "now() - interval '5' minute" + + +def test_format_tql_time_param_escapes_quotes(): + """Test format_tql_time_param escapes quotes in literals""" + # Quotes in literals should be escaped + assert format_tql_time_param("test'value") == "'test''value'" + assert format_tql_time_param("a''b") == "'a''''b'" + # SQL expressions are not escaped + assert format_tql_time_param("now()") == "now()" + + +def test_validate_time_expression_dangerous(): + """Test validate_time_expression blocks dangerous patterns""" + with pytest.raises(ValueError) as excinfo: + validate_time_expression("DELETE FROM users", "start") + assert "Dangerous pattern" in str(excinfo.value)