diff --git a/circe/execution/LIMITATIONS.md b/circe/execution/LIMITATIONS.md new file mode 100644 index 0000000..e69de29 diff --git a/circe/execution/engine/custom_era.py b/circe/execution/engine/custom_era.py new file mode 100644 index 0000000..c7f08c3 --- /dev/null +++ b/circe/execution/engine/custom_era.py @@ -0,0 +1,314 @@ +"""Custom era implementation using SQLGlot for cross-dialect compatibility. + +This module implements custom era logic (gap-based event grouping with offsets) +using a reference SQL implementation that is transpiled to target dialects via SQLGlot. + +The custom era strategy groups events by person_id and creates "eras" where events +are grouped together if they occur within gap_days of each other. Each era can have +start and end offsets applied. + +Example: + Given events on days [1, 3, 10, 12] with gap_days=5: + - Era 1: days 1-3 (within 5 days) + - Era 2: days 10-12 (within 5 days) +""" + +from __future__ import annotations + +import contextlib +from typing import TYPE_CHECKING + +import sqlglot + +from ..errors import CompilationError, UnsupportedFeatureError +from ..plan.schema import END_DATE, PERSON_ID, START_DATE + +if TYPE_CHECKING: + from ..typing import IbisBackendLike + + +# Mapping of Ibis backend names to SQLGlot dialect names +BACKEND_DIALECT_MAP = { + "duckdb": "duckdb", + "postgres": "postgres", + "spark": "spark", + "databricks": "databricks", + "snowflake": "snowflake", + "bigquery": "bigquery", + "trino": "trino", + "mysql": "mysql", + "sqlite": "sqlite", +} + + +def get_backend_dialect(backend: IbisBackendLike) -> str: + """Get SQLGlot dialect name from Ibis backend. 
+ + Args: + backend: Ibis backend instance + + Returns: + SQLGlot dialect name + + Raises: + UnsupportedFeatureError: If backend is not supported for custom era + """ + backend_name = backend.name.lower() + + # Handle special cases + if "databricks" in backend_name or "spark" in backend_name: + return "databricks" + + dialect = BACKEND_DIALECT_MAP.get(backend_name) + if dialect is None: + raise UnsupportedFeatureError( + f"Custom era not supported for backend: {backend_name}. " + f"Supported backends: {', '.join(BACKEND_DIALECT_MAP.keys())}" + ) + + return dialect + + +def generate_custom_era_sql_reference( + events_table_name: str, + gap_days: int, + offset_start: int = 0, + offset_end: int = 0, +) -> str: + """Generate reference custom era SQL in PostgreSQL dialect. + + This is the "golden" implementation that gets transpiled to other dialects. + PostgreSQL is chosen as the reference because it has the most standard SQL + syntax for window functions and date arithmetic. + + The logic: + 1. Compute LAG(start_date) for each person's events + 2. Mark new era boundaries where gap > gap_days + 3. Assign era IDs using cumulative sum + 4. 
Group by person + era and compute era bounds with offsets + + Args: + events_table_name: Fully qualified events table (e.g., "schema.events") + gap_days: Maximum days between events in same era + offset_start: Days to subtract from era start (can be negative) + offset_end: Days to add to era end (can be negative) + + Returns: + PostgreSQL SQL query as string + """ + # Use triple-quoted f-string for readability + sql = f""" + WITH event_gaps AS ( + SELECT *, + LAG({START_DATE}) OVER ( + PARTITION BY {PERSON_ID} + ORDER BY {START_DATE} + ) AS prev_start_date + FROM {events_table_name} + ), + era_boundaries AS ( + SELECT *, + CASE + WHEN prev_start_date IS NULL THEN 1 + WHEN {START_DATE} - prev_start_date > INTERVAL '{gap_days} days' THEN 1 + ELSE 0 + END AS is_new_era + FROM event_gaps + ), + era_ids AS ( + SELECT *, + SUM(is_new_era) OVER ( + PARTITION BY {PERSON_ID} + ORDER BY {START_DATE} + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) AS era_id + FROM era_boundaries + ), + eras AS ( + SELECT + {PERSON_ID}, + era_id, + MIN({START_DATE}) - INTERVAL '{offset_start} days' AS era_start, + MAX({END_DATE}) + INTERVAL '{offset_end} days' AS era_end + FROM era_ids + GROUP BY {PERSON_ID}, era_id + ) + SELECT + {PERSON_ID}, + era_start AS {START_DATE}, + era_end AS {END_DATE} + FROM eras + ORDER BY {PERSON_ID}, {START_DATE} + """ + + return sql.strip() + + +def transpile_custom_era_sql( + reference_sql: str, + target_dialect: str, +) -> str: + """Transpile reference PostgreSQL SQL to target dialect using SQLGlot. 
def build_custom_era_sql(
    backend: IbisBackendLike,
    events_table_name: str,
    gap_days: int,
    offset_start: int = 0,
    offset_end: int = 0,
    debug: bool = False,
) -> str:
    """Build custom era SQL for a specific backend using SQLGlot transpilation.

    Args:
        backend: Ibis backend instance
        events_table_name: Fully qualified events table name
        gap_days: Maximum days between events in same era
        offset_start: Days to subtract from era start
        offset_end: Days to add to era end
        debug: If True, print reference and transpiled SQL

    Returns:
        Transpiled custom era SQL for the backend's dialect

    Raises:
        UnsupportedFeatureError: If backend doesn't support custom era
        CompilationError: If SQL generation or transpilation fails
    """
    # Reject nonsense input early with a clear error instead of emitting SQL
    # with a negative interval.
    if gap_days < 0:
        raise CompilationError(f"gap_days must be non-negative, got: {gap_days}")

    target_dialect = get_backend_dialect(backend)

    reference_sql = generate_custom_era_sql_reference(
        events_table_name=events_table_name,
        gap_days=gap_days,
        offset_start=offset_start,
        offset_end=offset_end,
    )

    if debug:
        print("=== Reference SQL (PostgreSQL) ===")
        print(reference_sql)
        print()

    transpiled_sql = transpile_custom_era_sql(reference_sql, target_dialect)

    if debug:
        print(f"=== Transpiled SQL ({target_dialect}) ===")
        print(transpiled_sql)
        print()

    return transpiled_sql


def apply_custom_era(
    backend: IbisBackendLike,
    events,
    gap_days: int,
    offset_start: int = 0,
    offset_end: int = 0,
    schema: str | None = None,
    debug: bool = False,
):
    """Apply custom era strategy to events using SQLGlot-transpiled SQL.

    This function:
    1. Materializes events to a temporary table
    2. Generates custom era SQL via SQLGlot transpilation
    3. Executes the SQL to produce era-grouped events
    4. Returns the result as an Ibis table expression

    Args:
        backend: Ibis backend instance
        events: Ibis table expression of events (must have person_id, start_date, end_date)
        gap_days: Maximum days between events in same era
        offset_start: Days to subtract from era start
        offset_end: Days to add to era end
        schema: Schema for temporary table (optional)
        debug: If True, print generated SQL

    Returns:
        Ibis table expression with custom eras applied

    Raises:
        UnsupportedFeatureError: If backend doesn't support custom era
        CompilationError: If SQL generation fails
    """
    import uuid

    # uuid4 gives a collision-free name. The previous scheme used id(events),
    # which CPython recycles after garbage collection and which repeats across
    # processes — two concurrent runs could clobber each other's table.
    temp_table_name = f"_custom_era_events_{uuid.uuid4().hex}"
    # NOTE(review): some backends place temp tables outside user schemas, so
    # the schema-qualified name may not resolve there — verify per backend.
    full_table_name = f"{schema}.{temp_table_name}" if schema else temp_table_name

    try:
        backend.create_table(temp_table_name, events, schema=schema, temp=True)
    except Exception:
        # Some backends do not accept temp=True; fall back to a regular table,
        # which the finally block below always drops.
        backend.create_table(temp_table_name, events, schema=schema, overwrite=True)

    try:
        era_sql = build_custom_era_sql(
            backend=backend,
            events_table_name=full_table_name,
            gap_days=gap_days,
            offset_start=offset_start,
            offset_end=offset_end,
            debug=debug,
        )

        # Hand the raw SQL back to Ibis so downstream code can keep composing.
        return backend.sql(era_sql)

    finally:
        # Best-effort cleanup: never let a failed DROP mask the real error.
        with contextlib.suppress(Exception):
            backend.drop_table(temp_table_name, schema=schema, force=True)


def validate_custom_era_support(backend: IbisBackendLike) -> bool:
    """Check if backend supports custom era implementation.

    Delegates to get_backend_dialect so there is a single source of truth for
    the supported-backend predicate (the old copy duplicated the map lookup
    and the spark/databricks substring checks and could drift out of sync).

    Args:
        backend: Ibis backend instance

    Returns:
        True if custom era is supported for this backend
    """
    try:
        get_backend_dialect(backend)
    except UnsupportedFeatureError:
        return False
    return True
" + "Supported backends: duckdb, postgres, spark, databricks, snowflake" + ) + + # Extract custom era parameters from strategy + gap_days = int(strategy.payload.get("gap_days", 0)) + offset = int(strategy.payload.get("offset", 0)) + + # Apply custom era using SQLGlot transpilation + # Note: Custom era replaces end_date with era end, so we use the events directly + eras = apply_custom_era( + backend=ctx.backend, + events=events, + gap_days=gap_days, + offset_start=0, # Custom era typically doesn't offset start + offset_end=offset, + schema=ctx.results_schema, + debug=False, # Set to True for SQL debugging + ) + + return eras # Fallback: preserve default semantics of op_end_date clipping. return _replace_end_date(events, with_bounds, with_bounds.op_end_date) diff --git a/pyproject.toml b/pyproject.toml index 401d3dd..447262e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,15 +60,19 @@ docs = [ ] ibis = [ "ibis-framework>=11.0.0; python_version >= '3.9'", + "sqlglot>=23.0.0", ] ibis-duckdb = [ "ibis-framework[duckdb]>=11.0.0; python_version >= '3.9'", + "sqlglot>=23.0.0", ] ibis-postgres = [ "ibis-framework[postgres]>=11.0.0; python_version >= '3.9'", + "sqlglot>=23.0.0", ] ibis-databricks = [ "ibis-framework[databricks]>=11.0.0; python_version >= '3.9'", + "sqlglot>=23.0.0", ] waveform = [ "pydantic>=2.0.0", diff --git a/tests/test_custom_era.py b/tests/test_custom_era.py new file mode 100644 index 0000000..8878e7c --- /dev/null +++ b/tests/test_custom_era.py @@ -0,0 +1,336 @@ +"""Tests for custom era implementation using SQLGlot transpilation.""" + +import pytest + +pytest.importorskip("ibis") +pytest.importorskip("sqlglot") + + +import ibis + +from circe.execution.engine.custom_era import ( + build_custom_era_sql, + generate_custom_era_sql_reference, + get_backend_dialect, + transpile_custom_era_sql, + validate_custom_era_support, +) +from circe.execution.errors import CompilationError, UnsupportedFeatureError + + +class TestDialectMapping: + """Test 
backend to dialect mapping.""" + + def test_get_backend_dialect_duckdb(self): + """DuckDB backend should map to duckdb dialect.""" + + # Create a mock backend with name attribute + class MockBackend: + name = "duckdb" + + backend = MockBackend() + assert get_backend_dialect(backend) == "duckdb" + + def test_get_backend_dialect_postgres(self): + """PostgreSQL backend should map to postgres dialect.""" + + class MockBackend: + name = "postgres" + + backend = MockBackend() + assert get_backend_dialect(backend) == "postgres" + + def test_get_backend_dialect_databricks(self): + """Databricks backend should map to databricks dialect.""" + + class MockBackend: + name = "databricks" + + backend = MockBackend() + assert get_backend_dialect(backend) == "databricks" + + def test_get_backend_dialect_spark(self): + """Spark backend should map to databricks dialect.""" + + class MockBackend: + name = "spark" + + backend = MockBackend() + assert get_backend_dialect(backend) == "databricks" + + def test_get_backend_dialect_unsupported(self): + """Unsupported backend should raise UnsupportedFeatureError.""" + + class MockBackend: + name = "unsupported_db" + + backend = MockBackend() + with pytest.raises(UnsupportedFeatureError, match="Custom era not supported"): + get_backend_dialect(backend) + + +class TestReferenceSQLGeneration: + """Test reference PostgreSQL SQL generation.""" + + def test_generate_reference_sql_basic(self): + """Reference SQL should contain expected components.""" + sql = generate_custom_era_sql_reference( + events_table_name="test_schema.events", + gap_days=30, + offset_start=0, + offset_end=0, + ) + + # Check for key SQL components + assert "WITH event_gaps AS" in sql + assert "era_boundaries AS" in sql + assert "era_ids AS" in sql + assert "LAG(start_date)" in sql + assert "PARTITION BY person_id" in sql + assert "SUM(is_new_era)" in sql + assert "INTERVAL '30 days'" in sql + assert "test_schema.events" in sql + + def 
test_generate_reference_sql_with_offsets(self): + """Reference SQL should include offsets.""" + sql = generate_custom_era_sql_reference( + events_table_name="events", + gap_days=7, + offset_start=10, + offset_end=5, + ) + + assert "INTERVAL '7 days'" in sql + assert "INTERVAL '10 days'" in sql # offset_start + assert "INTERVAL '5 days'" in sql # offset_end + + def test_generate_reference_sql_negative_offset(self): + """Reference SQL should handle negative offsets.""" + sql = generate_custom_era_sql_reference( + events_table_name="events", + gap_days=30, + offset_start=-5, + offset_end=-10, + ) + + # Negative offsets are still syntactically valid + assert "INTERVAL '-5 days'" in sql + assert "INTERVAL '-10 days'" in sql + + +class TestSQLTranspilation: + """Test SQLGlot transpilation to different dialects.""" + + def test_transpile_to_duckdb(self): + """Transpilation to DuckDB should succeed.""" + reference = generate_custom_era_sql_reference("events", gap_days=30) + transpiled = transpile_custom_era_sql(reference, "duckdb") + + assert len(transpiled) > 0 + assert "WITH" in transpiled + # DuckDB should use INTERVAL syntax + assert "INTERVAL" in transpiled + + def test_transpile_to_spark(self): + """Transpilation to Spark should succeed.""" + reference = generate_custom_era_sql_reference("events", gap_days=30) + transpiled = transpile_custom_era_sql(reference, "spark") + + assert len(transpiled) > 0 + assert "WITH" in transpiled + + def test_transpile_to_snowflake(self): + """Transpilation to Snowflake should succeed.""" + reference = generate_custom_era_sql_reference("events", gap_days=30) + transpiled = transpile_custom_era_sql(reference, "snowflake") + + assert len(transpiled) > 0 + assert "WITH" in transpiled + + def test_transpile_preserves_logic(self): + """Transpiled SQL should preserve key logic components.""" + reference = generate_custom_era_sql_reference("events", gap_days=7) + + for dialect in ["duckdb", "postgres", "spark", "snowflake"]: + transpiled = 
class TestBuildCustomEraSQLIntegration:
    """Full SQL building with backend integration."""

    def test_build_custom_era_sql_duckdb(self):
        """Build custom era SQL for DuckDB backend."""
        fake = type("FakeBackend", (), {"name": "duckdb"})()
        sql = build_custom_era_sql(
            backend=fake,
            events_table_name="test.events",
            gap_days=30,
            offset_start=0,
            offset_end=0,
        )

        assert sql
        # Table name should survive transpilation in some form.
        assert "test.events" in sql or "test" in sql

    def test_build_custom_era_sql_invalid_gap_days(self):
        """Negative gap_days should raise error."""
        fake = type("FakeBackend", (), {"name": "duckdb"})()
        with pytest.raises(CompilationError, match="gap_days must be non-negative"):
            build_custom_era_sql(
                backend=fake,
                events_table_name="events",
                gap_days=-1,
            )

    def test_build_custom_era_sql_debug_mode(self, capsys):
        """Debug mode should print SQL."""
        fake = type("FakeBackend", (), {"name": "postgres"})()
        build_custom_era_sql(
            backend=fake,
            events_table_name="events",
            gap_days=30,
            debug=True,
        )

        out = capsys.readouterr().out
        assert "Reference SQL" in out
        assert "Transpiled SQL" in out


class TestValidateCustomEraSupport:
    """Validation of custom era support."""

    def test_validate_support_duckdb(self):
        """DuckDB should be supported."""
        fake = type("FakeBackend", (), {"name": "duckdb"})()
        assert validate_custom_era_support(fake)

    def test_validate_support_postgres(self):
        """PostgreSQL should be supported."""
        fake = type("FakeBackend", (), {"name": "postgres"})()
        assert validate_custom_era_support(fake)

    def test_validate_support_databricks(self):
        """Databricks should be supported."""
        fake = type("FakeBackend", (), {"name": "databricks"})()
        assert validate_custom_era_support(fake)

    def test_validate_support_unsupported(self):
        """Unsupported backends should return False."""
        fake = type("FakeBackend", (), {"name": "unsupported_db"})()
        assert not validate_custom_era_support(fake)


@pytest.mark.integration
class TestCustomEraExecution:
    """Integration tests with real Ibis backends."""

    @pytest.fixture
    def duckdb_backend(self):
        """Create DuckDB backend with test data."""
        con = ibis.duckdb.connect(":memory:")

        raw = ibis.memtable(
            {
                "person_id": [1, 1, 1, 2, 2],
                "start_date": ["2020-01-01", "2020-01-05", "2020-01-15", "2020-02-01", "2020-02-10"],
                "end_date": ["2020-01-01", "2020-01-05", "2020-01-15", "2020-02-01", "2020-02-10"],
            }
        )

        # Dates arrive as strings; cast before persisting.
        typed = raw.mutate(
            start_date=raw.start_date.cast("date"),
            end_date=raw.end_date.cast("date"),
        )

        con.create_table("test_events", typed, overwrite=True)
        return con

    def test_custom_era_basic_execution(self, duckdb_backend):
        """Test basic custom era execution."""
        # gap_days=7: events on the 1st and 5th group together; the 15th
        # starts a new era.
        sql = build_custom_era_sql(
            backend=duckdb_backend,
            events_table_name="test_events",
            gap_days=7,
            offset_start=0,
            offset_end=0,
        )

        df = duckdb_backend.sql(sql).execute()

        assert len(df) > 0
        for column in ("person_id", "start_date", "end_date"):
            assert column in df.columns

    def test_generate_sql_is_valid_postgres(self):
        """Generated PostgreSQL SQL should be syntactically valid."""
        sql = generate_custom_era_sql_reference("test_events", gap_days=30)

        # SQLGlot's parser doubles as a syntax validator.
        from sqlglot import parse

        statements = parse(sql, dialect="postgres")
        assert len(statements) > 0
        assert statements[0] is not None


class TestEdgeCases:
    """Edge cases and error handling."""

    def test_empty_events_table(self):
        """Custom era should handle empty events table."""
        # Documentation-level test — actual behavior depends on SQL execution.
        sql = generate_custom_era_sql_reference("empty_events", gap_days=30)
        assert "empty_events" in sql

    def test_single_event_per_person(self):
        """Custom era should handle single event per person."""
        sql = generate_custom_era_sql_reference("events", gap_days=30)
        # LAG returns NULL for the first row; the CASE branch handles it.
        assert "LAG" in sql

    def test_zero_gap_days(self):
        """Zero gap days should be valid (each event is own era)."""
        sql = generate_custom_era_sql_reference("events", gap_days=0)
        assert "INTERVAL '0 days'" in sql

    def test_large_gap_days(self):
        """Large gap days should be handled."""
        sql = generate_custom_era_sql_reference("events", gap_days=365)
        assert "INTERVAL '365 days'" in sql