diff --git a/README.md b/README.md index e1af7030..b1d076e6 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # CIRCE Python Implementation [![Python](https://img.shields.io/badge/python-3.9%2B-blue)](https://www.python.org/downloads/) -[![Tests](https://img.shields.io/badge/tests-3400%2B%20passed-brightgreen)](tests/) +[![Tests](https://img.shields.io/badge/tests-passing-brightgreen)](tests/) [![codecov](https://codecov.io/gh/OHDSI/Circepy/graph/badge.svg?token=CODECOV_TOKEN)](https://codecov.io/gh/OHDSI/Circepy) [![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE) [![PyPI](https://img.shields.io/badge/PyPI-ohdsi--circe--python--alpha-blue)](https://pypi.org/project/ohdsi-circe-python-alpha/) @@ -27,17 +27,15 @@ CIRCE Python provides a comprehensive toolkit for working with OMOP CDM cohort d > [!IMPORTANT] > This package is currently in **Alpha** status and undergoing rigorous parity testing against the Java implementation. -- **Version**: 0.1.0 (Alpha) -- **Tests**: 3,400+ passing -- **Coverage**: 34% (Core logic focus) +- **Version**: 0.2.0 (Alpha) +- **Tests**: Passing in CI - **Python**: 3.9+ - **License**: Apache 2.0 ## Installation > [!NOTE] -> This package is currently in private development. Install from source using Git. -> The recommended workflow uses `uv` and the checked-in `uv.lock` for a reproducible environment. +> The recommended source workflow uses `uv` and the checked-in `uv.lock` for a reproducible environment. ### From Source (Current Method) @@ -148,18 +146,17 @@ An experimental backend-native execution API is available under `circe.execution`. ```python -from circe.execution import ExecutionOptions, IbisExecutor +from circe.execution import build_cohort # Requires optional extras, e.g. 
`pip install ohdsi-circe-python-alpha[ibis-duckdb]` -executor = IbisExecutor(conn, ExecutionOptions(cdm_schema="main")) -events = executor.build(cohort) # lazy ibis relation +events = build_cohort(cohort, backend=conn, cdm_schema="main") # lazy ibis relation ``` ## What's Included This package provides a complete Python implementation of CIRCE-BE with: -- **3,400+ passing tests** with focused coverage on core logic +- **Passing test suite** with focused coverage on core logic - **18+ SQL builders** for all OMOP CDM domains: - Condition Occurrence/Era - Drug Exposure/Era @@ -227,7 +224,7 @@ circe/ - [x] Java interoperability with camelCase/snake_case field support - [x] Cohort expression validation with 40+ checker implementations - [x] Markdown rendering for print-friendly descriptions -- [x] Full test suite (3,400+ tests) +- [x] Full test suite - [x] Type hints throughout with py.typed marker - [x] Concept set expression handling - [x] Window criteria and correlated criteria support @@ -373,7 +370,7 @@ uv run circe --help uv run pytest ``` -All 3,400+ tests should pass. +The full test suite should pass. 
# circe/api.py — execution-engine convenience wrappers.
#
# NOTE(review): the patch imported `IbisBackendLike` and `Table` eagerly via
# `from .execution.typing import ...`.  Importing any submodule of
# `circe.execution` first executes `circe/execution/__init__.py`, which (per
# this same patch) imports `.api` and the Ibis engine modules — so a plain
# `import circe` would drag in the optional execution stack, defeating the
# deliberate lazy imports inside these functions.  The import is therefore
# guarded with TYPE_CHECKING and the annotations are written as strings.
from typing import TYPE_CHECKING, Literal, Optional

if TYPE_CHECKING:
    # Type-checking only: keeps the core API importable without the
    # optional ibis extras installed.
    from .execution.typing import IbisBackendLike, Table


def build_cohort(
    expression: CohortExpression,
    *,
    backend: "IbisBackendLike",
    cdm_schema: str,
    vocabulary_schema: Optional[str] = None,
    results_schema: Optional[str] = None,
) -> "Table":
    """Build a cohort as a relational table expression.

    Thin wrapper over the experimental Ibis execution engine.  The heavy
    ``circe.execution`` import happens lazily at call time, so the optional
    dependency is only required when this function is actually used.

    Args:
        expression: CohortExpression instance
        backend: Ibis backend used to compile the cohort relation
        cdm_schema: Schema containing the OMOP CDM tables
        vocabulary_schema: Optional schema for vocabulary tables. Defaults to
            ``cdm_schema`` when omitted.
        results_schema: Optional schema used for result-side table resolution

    Returns:
        Ibis table expression representing the cohort result

    Raises:
        ExecutionError: If the cohort cannot be normalized, lowered, or
            compiled into a relational expression

    Example:
        >>> import ibis
        >>> backend = ibis.duckdb.connect()
        >>> expression = cohort_expression_from_json(json_str)
        >>> relation = build_cohort(
        ...     expression,
        ...     backend=backend,
        ...     cdm_schema="cdm",
        ...     vocabulary_schema="vocab",
        ... )
    """
    # Lazy import: pulls in the optional ibis stack only when called.
    from .execution import build_cohort as _build_cohort

    return _build_cohort(
        expression,
        backend=backend,
        cdm_schema=cdm_schema,
        vocabulary_schema=vocabulary_schema,
        results_schema=results_schema,
    )


def write_cohort(
    expression: CohortExpression,
    *,
    backend: "IbisBackendLike",
    cdm_schema: str,
    cohort_table: str,
    cohort_id: int,
    vocabulary_schema: Optional[str] = None,
    results_schema: Optional[str] = None,
    if_exists: Literal["fail", "replace"] = "fail",
) -> None:
    """Build and write an OHDSI cohort table.

    Wraps :func:`build_cohort`, projects the resulting relation into the
    standard OHDSI cohort-table shape, and materializes it to a backend
    table.  Existing rows for other cohort IDs are preserved.

    Args:
        expression: CohortExpression instance
        backend: Ibis backend used to compile and write the cohort relation
        cdm_schema: Schema containing the OMOP CDM tables
        cohort_table: Name of the OHDSI cohort table to create or update
        cohort_id: Cohort definition identifier written to
            ``cohort_definition_id``
        vocabulary_schema: Optional schema for vocabulary tables. Defaults to
            ``cdm_schema`` when omitted.
        results_schema: Optional schema for the target table
        if_exists: Cohort-row policy, either ``"fail"`` or ``"replace"``.
            ``"fail"`` raises if rows for ``cohort_id`` already exist.
            ``"replace"`` replaces only rows for ``cohort_id``.

    Returns:
        None

    Raises:
        ExecutionError: If the cohort cannot be built or the target table
            cannot be written

    Example:
        >>> import ibis
        >>> backend = ibis.duckdb.connect()
        >>> expression = cohort_expression_from_json(json_str)
        >>> write_cohort(
        ...     expression,
        ...     backend=backend,
        ...     cdm_schema="cdm",
        ...     cohort_table="cohort",
        ...     cohort_id=1,
        ...     results_schema="results",
        ...     if_exists="replace",
        ... )
    """
    # Lazy import, same rationale as build_cohort.
    from .execution import write_cohort as _write_cohort

    _write_cohort(
        expression,
        backend=backend,
        cdm_schema=cdm_schema,
        cohort_table=cohort_table,
        cohort_id=cohort_id,
        vocabulary_schema=vocabulary_schema,
        results_schema=results_schema,
        if_exists=if_exists,
    )
"""Compatibility helper for frozen + slotted dataclasses.

``slots=True`` (Python 3.10+) gives per-instance memory savings and faster
attribute access, but older runtimes reject the keyword.  This module exposes
one decorator that applies ``slots=True`` where supported, omits it elsewhere,
and always enforces ``frozen=True``.
"""

import sys
from dataclasses import dataclass
from typing import Any, Callable, Optional, TypeVar, Union, cast, overload

# `typing.dataclass_transform` exists from 3.11, but the `frozen_default`
# parameter we rely on was only added to the stdlib version in 3.12; fall back
# to typing_extensions on older runtimes instead of requiring it everywhere.
if sys.version_info >= (3, 12):
    from typing import dataclass_transform
else:  # pragma: no cover - exercised only on Python < 3.12
    from typing_extensions import dataclass_transform

T = TypeVar("T")

# dataclass(slots=...) was introduced in Python 3.10.
_HAS_DATACLASS_SLOTS = sys.version_info >= (3, 10)


@overload
def frozen_slots_dataclass(_cls: "type[T]", **kwargs: Any) -> "type[T]": ...


@overload
def frozen_slots_dataclass(
    _cls: None = None, **kwargs: Any
) -> "Callable[[type[T]], type[T]]": ...


@dataclass_transform(frozen_default=True)
def frozen_slots_dataclass(
    _cls: Optional["type[T]"] = None,
    **kwargs: Any,
) -> Union["type[T]", "Callable[[type[T]], type[T]]"]:
    """Create a frozen dataclass, using ``slots=True`` when the runtime allows.

    Usable both as a bare decorator (``@frozen_slots_dataclass``) and as a
    decorator factory (``@frozen_slots_dataclass(eq=False)``).

    ``frozen=True`` is always enforced (a caller-supplied ``frozen`` kwarg is
    overridden rather than raising a duplicate-keyword TypeError, which the
    previous implementation did).  On 3.10+ an explicit ``slots=...`` kwarg is
    honored, otherwise ``slots=True`` is applied.

    Args:
        _cls: The class being decorated when used without parentheses.
        **kwargs: Extra keyword arguments forwarded to ``dataclasses.dataclass``.

    Returns:
        The decorated class, or a decorator when called with kwargs only.
    """

    def wrap(cls: "type[T]") -> "type[T]":
        # Merge caller kwargs first so our frozen=True always wins.
        params: dict = {**kwargs, "frozen": True}
        if _HAS_DATACLASS_SLOTS:
            # Default to slots, but let callers opt out with slots=False.
            params.setdefault("slots", True)
        factory = cast(Any, dataclass)
        return cast("type[T]", factory(**params)(cls))

    if _cls is None:
        return wrap
    return wrap(_cls)
def write_cohort(
    expression: CohortExpression,
    *,
    backend: IbisBackendLike,
    cdm_schema: str,
    cohort_table: str,
    cohort_id: int,
    results_schema: str | None = None,
    vocabulary_schema: str | None = None,
    if_exists: Literal["fail", "replace"] = "fail",
) -> None:
    """Build cohort rows and materialize them with cohort-scoped semantics.

    Compiles the cohort via :func:`build_cohort`, projects it into the OHDSI
    cohort-table shape, then writes it using one of three paths:

    1. Target table absent: create it fresh.
    2. ``if_exists="fail"``: error if rows for ``cohort_id`` already exist,
       otherwise append the new rows.
    3. ``if_exists="replace"``: replace only this cohort's rows, preferring a
       transactional backend operation when available.

    Args:
        expression: CohortExpression instance
        backend: Ibis backend used to compile and write the cohort relation
        cdm_schema: Schema containing the OMOP CDM tables
        cohort_table: Name of the OHDSI cohort table to create or update
        cohort_id: Identifier written to ``cohort_definition_id``
        results_schema: Optional schema for the target table
        vocabulary_schema: Optional schema for vocabulary tables
        if_exists: Cohort-row policy, ``"fail"`` or ``"replace"``

    Raises:
        ValueError: If ``if_exists`` is not ``"fail"`` or ``"replace"``
        ExecutionError: If ``if_exists="fail"`` and rows for ``cohort_id``
            already exist in the target table
    """
    # Validate eagerly, before doing any (potentially expensive) compilation.
    if if_exists not in {"fail", "replace"}:
        raise ValueError("if_exists must be one of {'fail', 'replace'} for write_cohort.")

    new_rows = build_cohort(
        expression,
        backend=backend,
        cdm_schema=cdm_schema,
        results_schema=results_schema,
        vocabulary_schema=vocabulary_schema,
    )
    # Project into the standard OHDSI shape, stamping cohort_definition_id.
    new_rows = project_to_ohdsi_cohort_table(new_rows, cohort_id=cohort_id)

    # Path 1: target table does not exist yet — create it from scratch.
    # if_exists="fail" here guards against a concurrent writer creating the
    # table between the existence check and the write.
    if not table_exists(backend, table_name=cohort_table, schema=results_schema):
        write_relation(
            new_rows,
            backend=backend,
            target_table=cohort_table,
            target_schema=results_schema,
            if_exists="fail",
        )
        return

    # Path 2: fail-fast policy — refuse to overwrite this cohort's rows,
    # but append when none exist (rows for other cohort IDs are untouched).
    if if_exists == "fail":
        if cohort_rows_exist(
            backend,
            cohort_table=cohort_table,
            results_schema=results_schema,
            cohort_id=cohort_id,
        ):
            raise ExecutionError(
                "Ibis executor write error: cohort table "
                f"'{cohort_table}' already contains rows for cohort_id={cohort_id}."
            )
        insert_relation(
            new_rows,
            backend=backend,
            target_table=cohort_table,
            target_schema=results_schema,
        )
        return

    # Path 3a: replace policy — prefer an atomic backend-side delete+insert
    # when the backend supports it.
    if supports_transactional_replace(backend):
        replace_cohort_rows_transactionally(
            new_rows,
            backend=backend,
            cohort_table=cohort_table,
            results_schema=results_schema,
            cohort_id=cohort_id,
        )
        return

    # Path 3b: non-transactional fallback — read existing rows, drop this
    # cohort's rows, union in the new rows, and rewrite the whole table.
    # NOTE(review): this rewrite is not atomic; concurrent readers may observe
    # the table mid-replacement — confirm acceptable for target backends.
    existing = read_table(
        backend,
        table_name=cohort_table,
        schema=results_schema,
    )
    filtered = exclude_cohort_rows(existing, cohort_id=cohort_id)
    # distinct=False: UNION ALL semantics, duplicates in source are preserved.
    relation = filtered.union(new_rows, distinct=False)
    write_relation(
        relation,
        backend=backend,
        target_table=cohort_table,
        target_schema=results_schema,
        if_exists="replace",
    )
return - if backend in ("postgres", "duckdb"): - conn.raw_sql(f"ANALYZE {qualified_name}") - return - if backend == "databricks": - conn.raw_sql(f"ANALYZE TABLE {qualified_name} COMPUTE STATISTICS") - - -def _drop_table_safely( - conn: ibis.BaseBackend, - *, - name: str, - database: Database | None = None, - warning_label: str, -) -> None: - try: - conn.drop_table(name, database=database, force=True) - except Exception as exc: - _warn(f"could not drop {warning_label}: {exc}") - - -@dataclass(frozen=True) -class CohortBuildOptions: - cdm_schema: str | None = None - vocabulary_schema: str | None = None - result_schema: str | None = None - target_table: str | None = None - cohort_id: int | None = None - generate_stats: bool = False - temp_emulation_schema: str | None = None - profile_dir: str | None = None - capture_sql: bool = False - backend: str | None = None - materialize_stages: bool = True - materialize_codesets: bool = True - - -@dataclass -class CodesetResource: - table: ir.Table - _dropper: Callable[[], None] | None = None - - def cleanup(self): - if self._dropper: - try: - self._dropper() - finally: - self._dropper = None - - -class BuildContext: - """Holds shared state (connection, schemas, compiled codesets) used across builders.""" - - def __init__( - self, - conn: ibis.BaseBackend, - options: CohortBuildOptions, - codeset_resource: CodesetResource | ir.Table, - ): - self._conn = conn - self._options = options - if isinstance(codeset_resource, CodesetResource): - self._codeset_resource = codeset_resource - else: - self._codeset_resource = CodesetResource(table=codeset_resource) - self._codesets = self._codeset_resource.table - self._cleanup_callbacks: list[Callable[[], None]] = [] - self._correlated_cache: dict[str, ir.Table] = {} - self._profile_dir = None - if options.profile_dir: - path = Path(options.profile_dir).resolve() - path.mkdir(parents=True, exist_ok=True) - self._profile_dir = path - self._captured_sql: list[tuple[str, str]] = [] - 
self._slice_cache: dict[str, ir.Table] = {} - weakref.finalize(self, self.close) - - def _table(self, database: str | None, name: str) -> ir.Table: - try: - return _table(self._conn, database, name) - except ( - ibis_exc.IbisError, - TypeError, - ValueError, - AttributeError, - NotImplementedError, - ): - return self._conn.sql(f"SELECT * FROM {_qualify(database, name)}") - - def table(self, name: str) -> ir.Table: - """Return a CDM table.""" - return self._table(self._options.cdm_schema, name) - - def vocabulary_table(self, name: str) -> ir.Table: - """Return a vocabulary table (concept, concept_ancestor, etc.).""" - schema = self._options.vocabulary_schema or self._options.cdm_schema - return self._table(schema, name) - - def codeset(self, codeset_id: int, *, is_exclusion: bool = False) -> ir.Table: - """Return concepts for the requested codeset. `is_exclusion` is provided for parity with Circe.""" - _ = is_exclusion # placeholder for future differentiated handling - return self._codesets.filter(self._codesets.codeset_id == codeset_id) - - def get_cached_correlated(self, key: str) -> ir.Table | None: - return self._correlated_cache.get(key) - - def cache_correlated(self, key: str, table: ir.Table) -> None: - self._correlated_cache[key] = table - - def materialize( - self, - expr: ir.Table, - *, - label: str, - temp: bool = True, - analyze: bool = True, - ) -> ir.Table: - """ - Materialize an Ibis expression, capturing a unique DuckDB profiling - artifact for this step. - """ - step_id = uuid.uuid4().hex[:8] - table_name = f"_stage_{label}_{step_id}" - backend = self._options.backend - - # "temp emulation" means: create a *real* table in a chosen database/schema. 
- use_temp_emulation = temp and self._options.temp_emulation_schema is not None - database: Database | None = self._options.temp_emulation_schema if use_temp_emulation else None - temp_flag = False if use_temp_emulation else temp - - # duckdb profiling setup for local dev - profile_filename: Path | None = None - profiling_enabled = False - if backend == "duckdb" and self._profile_dir is not None: - profile_filename = (self._profile_dir / f"ibis_profile_{label}_{step_id}.json").resolve() - try: - escaped = str(profile_filename).replace("'", "''") - self._conn.raw_sql(f"SET profiling_output='{escaped}'") - self._conn.raw_sql("SET enable_profiling='json'") - self._conn.raw_sql("SET profiling_coverage='ALL'") - profiling_enabled = True - except Exception as exc: - _warn(f"could not enable DuckDB profiling for {label}: {exc}") - - try: - self._conn.create_table( - table_name, - obj=expr, - database=database, - temp=temp_flag, - overwrite=True, - ) - if self._options.capture_sql: - self._captured_sql.append((table_name, self._conn.compile(expr))) - finally: - if profiling_enabled: - try: - self._conn.raw_sql("PRAGMA disable_profiling") - except Exception as exc: - _warn(f"could not disable DuckDB profiling for {label}: {exc}") - - if profiling_enabled and profile_filename is not None: - print(f"[Profile Captured]: {profile_filename} (Table: {table_name})") - - if analyze: - qualified = _qualify(database, table_name) - try: - _analyze_table(self._conn, backend=backend, qualified_name=qualified) - except Exception as exc: - _warn(f"could not analyze table {qualified}: {exc}") - - def _drop(): - _drop_table_safely( - self._conn, - name=table_name, - database=database, - warning_label=f"table {table_name} in {database}", - ) - - self.register_cleanup(_drop) - return _table(self._conn, database, table_name) - - def should_materialize_stages(self) -> bool: - return bool(self._options.materialize_stages) - - def maybe_materialize( - self, - expr: ir.Table, - *, - label: str, - 
temp: bool = True, - analyze: bool = True, - ) -> ir.Table: - if not self.should_materialize_stages(): - return expr - return self.materialize(expr, label=label, temp=temp, analyze=analyze) - - def write_cohort_table( - self, - events: ir.Table, - *, - table_name: str | None = None, - database: Database | None = None, - overwrite: bool = True, - append: bool = False, - ) -> ir.Table: - """ - Persist cohort rows to a results table. - - Output schema matches OHDSI cohort tables: - (cohort_definition_id, subject_id, cohort_start_date, cohort_end_date) - """ - if append and overwrite: - raise ValueError("`append=True` and `overwrite=True` cannot be used together.") - target_table = table_name or self._options.target_table - if not target_table: - raise ValueError("target_table must be set (argument or CohortBuildOptions.target_table)") - target_db = database if database is not None else self._options.result_schema - if target_db is None: - raise ValueError("result_schema must be set (argument or CohortBuildOptions.result_schema)") - - cohort_id = self._options.cohort_id - cohort_id_expr = ( - ibis.literal(int(cohort_id), type="int64") if cohort_id is not None else ibis.null().cast("int64") - ) - - result = events.select( - cohort_id_expr.name("cohort_definition_id"), - events.person_id.cast("int64").name("subject_id"), - events.start_date.cast("date").name("cohort_start_date"), - events.end_date.cast("date").name("cohort_end_date"), - ) - - obj = result - if append: - try: - existing = _table(self._conn, target_db, target_table) - obj = existing.union(result, distinct=False) - except ( - ibis_exc.IbisError, - TypeError, - ValueError, - AttributeError, - NotImplementedError, - ): - obj = result - - self._conn.create_table( - target_table, - obj=obj, - database=target_db, - temp=False, - overwrite=overwrite, - ) - return _table(self._conn, target_db, target_table) - - @property - def codesets(self) -> ir.Table: - return self._codesets - - @property - def conn(self) -> 
ibis.BaseBackend: - return self._conn - - def options(self) -> CohortBuildOptions: - return self._options - - def captured_sql(self) -> list[tuple[str, str]]: - return list(self._captured_sql) - - def register_cleanup(self, callback: Callable[[], None]): - self._cleanup_callbacks.append(callback) - - def get_or_materialize_slice( - self, - cache_key: str, - expr: ir.Table, - *, - label: str | None = None, - ) -> ir.Table: - """Materialize an expression once and reuse the resulting temp table for later lookups.""" - if not self.should_materialize_stages(): - return expr.view() - cached = self._slice_cache.get(cache_key) - if cached is not None: - return cached - label_hint = label or "slice" - table = self.materialize(expr, label=label_hint, temp=True, analyze=True) - self._slice_cache[cache_key] = table - return table - - def close(self): - if self._codeset_resource is not None: - self._codeset_resource.cleanup() - self._codeset_resource = None # type: ignore[assignment] - while self._cleanup_callbacks: - callback = self._cleanup_callbacks.pop() - try: - callback() - except Exception as exc: - _warn(f"cleanup callback failed: {exc}") - self._captured_sql.clear() - self._slice_cache.clear() - - -def compile_codesets( - conn: ibis.BaseBackend, - concept_sets: list[ConceptSet], - options: CohortBuildOptions, -) -> CodesetResource: - """Rebuild Circe concept set logic as an ibis expression.""" - - vocab_schema = options.vocabulary_schema or options.cdm_schema - concept = _table(conn, vocab_schema, "concept") - concept_ancestor = _table(conn, vocab_schema, "concept_ancestor") - concept_relationship = _table(conn, vocab_schema, "concept_relationship") - - compiled = [] - for concept_set in concept_sets or []: - compiled_expr = _compile_single_codeset(concept, concept_ancestor, concept_relationship, concept_set) - if compiled_expr is not None: - compiled.append(compiled_expr) - - compiled_expr = _empty_codeset_table() if not compiled else _union_all(compiled).distinct() - 
- if not options.materialize_codesets: - return CodesetResource(table=compiled_expr) - - return _materialize_codesets(conn, compiled_expr, options) - - -def _compile_single_codeset( - concept: ir.Table, - concept_ancestor: ir.Table, - concept_relationship: ir.Table, - concept_set: ConceptSet, -) -> ir.Table | None: - expression = concept_set.expression - if expression is None or not expression.items: - return None - - include_ids: list[int] = [] - include_descendant_ids: list[int] = [] - include_mapped_ids: list[int] = [] - include_mapped_descendant_ids: list[int] = [] - - exclude_ids: list[int] = [] - exclude_descendant_ids: list[int] = [] - exclude_mapped_ids: list[int] = [] - exclude_mapped_descendant_ids: list[int] = [] - - for item in expression.items: - if item.concept is None or item.concept.concept_id is None: - continue - target_include = not bool(item.is_excluded) - include_descendants = bool(item.include_descendants) - include_mapped = bool(item.include_mapped) - concept_id = int(item.concept.concept_id) - - if target_include: - include_ids.append(concept_id) - if include_descendants: - include_descendant_ids.append(concept_id) - if include_mapped: - include_mapped_ids.append(concept_id) - if include_descendants: - include_mapped_descendant_ids.append(concept_id) - else: - exclude_ids.append(concept_id) - if include_descendants: - exclude_descendant_ids.append(concept_id) - if include_mapped: - exclude_mapped_ids.append(concept_id) - if include_descendants: - exclude_mapped_descendant_ids.append(concept_id) - - include_expr = _union_distinct( - [ - _ids_memtable(include_ids), - _descendants(concept, concept_ancestor, include_descendant_ids), - _mapped_concepts( - concept, - concept_ancestor, - concept_relationship, - include_mapped_ids, - include_mapped_descendant_ids, - ), - ] - ) - - if include_expr is None: - return None - - exclude_expr = _union_distinct( - [ - _ids_memtable(exclude_ids), - _descendants(concept, concept_ancestor, 
exclude_descendant_ids), - _mapped_concepts( - concept, - concept_ancestor, - concept_relationship, - exclude_mapped_ids, - exclude_mapped_descendant_ids, - ), - ] - ) - - if exclude_expr is not None: - include_expr = include_expr.anti_join(exclude_expr, ["concept_id"]) - - codeset_literal = ibis.literal(int(concept_set.id), type="int64") - return include_expr.mutate(codeset_id=codeset_literal)[["codeset_id", "concept_id"]] - - -def _ids_memtable(ids: list[int]) -> ir.Table | None: - if not ids: - return None - return table_from_literal_list(ids, column_name="concept_id", element_type="int64").distinct() - - -def _descendants(concept: ir.Table, concept_ancestor: ir.Table, ancestor_ids: list[int]) -> ir.Table | None: - if not ancestor_ids: - return None - return ( - concept_ancestor.filter(concept_ancestor.ancestor_concept_id.isin(ancestor_ids)) - .join(concept, concept_ancestor.descendant_concept_id == concept.concept_id) - .filter(concept.invalid_reason.isnull()) - .select(concept.concept_id.cast("int64").name("concept_id")) - .distinct() - ) - - -def _mapped_concepts( - concept: ir.Table, - concept_ancestor: ir.Table, - concept_relationship: ir.Table, - concepts_to_map: list[int], - concepts_with_descendants_to_map: list[int], -) -> ir.Table | None: - sources = _union_distinct( - [ - _ids_memtable(concepts_to_map), - _descendants(concept, concept_ancestor, concepts_with_descendants_to_map), - ] - ) - - if sources is None: - return None - - valid_relationships = concept_relationship.filter( - [ - concept_relationship.relationship_id == "Maps to", - concept_relationship.invalid_reason.isnull(), - ] - ) - - return ( - sources.join(valid_relationships, sources.concept_id == valid_relationships.concept_id_2) - .select(valid_relationships.concept_id_1.cast("int64").name("concept_id")) - .distinct() - ) - - -def _empty_codeset_table() -> ir.Table: - empty_concepts = table_from_literal_list([], column_name="concept_id", element_type="int64") - empty_codesets = 
empty_concepts.mutate( - codeset_id=ibis.null().cast("int64"), - ) - return empty_codesets.select("codeset_id", "concept_id") - - -def _materialize_codesets( - conn: ibis.BaseBackend, - expr: ir.Table, - options: CohortBuildOptions, -) -> CodesetResource: - name = f"_codesets_{uuid.uuid4().hex}" - if options.temp_emulation_schema: - database: Database = options.temp_emulation_schema - conn.create_table( - name, - obj=expr, - database=database, - temp=False, - overwrite=True, - ) - table = _table(conn, database, name) - qualified = _qualify(database, name) - - def _drop(): - _drop_table_safely( - conn, - name=name, - database=database, - warning_label=f"codeset table {name} in {database}", - ) - - else: - conn.create_table( - name, - obj=expr, - temp=True, - overwrite=True, - ) - table = _table(conn, None, name) - qualified = _qualify(None, name) - - def _drop(): - _drop_table_safely( - conn, - name=name, - warning_label=f"codeset temp table {name}", - ) - - backend = options.backend - if backend: - try: - _analyze_table(conn, backend=backend, qualified_name=qualified) - except Exception as exc: - _warn(f"could not analyze codeset table {qualified}: {exc}") - - resource = CodesetResource(table=table, _dropper=_drop) - weakref.finalize(resource, resource.cleanup) - return resource - - -def _union_distinct(tables: Iterable[ir.Table | None]) -> ir.Table | None: - valid_tables = [t for t in tables if t is not None] - if not valid_tables: - return None - - return reduce( - lambda left, right: left.union(right, distinct=True), - valid_tables[1:], - valid_tables[0], - ) - - -def _union_all(tables: list[ir.Table]) -> ir.Table: - return reduce(lambda left, right: left.union(right), tables[1:], tables[0]) diff --git a/circe/execution/builders/__init__.py b/circe/execution/builders/__init__.py deleted file mode 100644 index b625e93a..00000000 --- a/circe/execution/builders/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -from . 
import ( - condition_era, - condition_occurrence, - death, - device_exposure, - dose_era, - drug_era, - drug_exposure, - measurement, - observation, - observation_period, - payer_plan_period, - procedure_occurrence, - specimen, - visit_detail, - visit_occurrence, -) -from .pipeline import build_primary_events -from .registry import build_events, register - -__all__ = [ - "condition_era", - "condition_occurrence", - "death", - "device_exposure", - "dose_era", - "drug_era", - "drug_exposure", - "measurement", - "observation", - "observation_period", - "payer_plan_period", - "procedure_occurrence", - "specimen", - "visit_detail", - "visit_occurrence", - "build_primary_events", - "build_events", - "register", -] diff --git a/circe/execution/builders/common.py b/circe/execution/builders/common.py deleted file mode 100644 index 59b739e8..00000000 --- a/circe/execution/builders/common.py +++ /dev/null @@ -1,757 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from typing import Any, Callable, cast - -import ibis -import ibis.expr.types as ir -from ibis.expr.api import row_number - -from ...cohortdefinition.core import ( - CollapseSettings, - CollapseType, - ConceptSetSelection, - CustomEraStrategy, - DateOffsetStrategy, - DateRange, - EndStrategy, - NumericRange, - TextFilter, -) -from ...vocabulary.concept import Concept -from ..build_context import BuildContext - -OutputFormatter = Callable[[ir.Table], ir.Table] - - -def _person_subset(ctx: BuildContext, columns: list[str]) -> ir.Table: - person = ctx.table("person") - missing = [col for col in columns if col not in person.columns] - if missing: - raise ValueError(f"Person table missing required columns: {missing}") - return person.select(columns) - - -def standardize_output( - table: ir.Table, - *, - primary_key: str, - start_column: str, - end_column: str, -) -> ir.Table: - """Project and rename columns to the strict builder output contract.""" - start_expr = 
table[start_column].cast("timestamp") - same_column = end_column == start_column - if same_column: - end_expr = start_expr - needs_offset = ibis.literal(True) - elif end_column in table.columns: - end_raw = table[end_column].cast("timestamp") - end_expr = ibis.coalesce(end_raw, start_expr).cast("timestamp") - needs_offset = end_raw.isnull() - else: - end_expr = start_expr - needs_offset = ibis.literal(True) - one_day = ibis.interval(days=1) - end_expr = ibis.ifelse(needs_offset, cast(Any, end_expr) + one_day, end_expr).cast("timestamp") - visit_expr = ( - table.visit_occurrence_id.cast("int64") - if "visit_occurrence_id" in table.columns - else ibis.null().cast("int64") - ).name("visit_occurrence_id") - return table.select( - table.person_id.cast("int64").name("person_id"), - table[primary_key].cast("int64").name("event_id"), - start_expr.name("start_date"), - end_expr.name("end_date"), - visit_expr, - ) - - -def project_event_columns( - table: ir.Table, - *, - primary_key: str, - start_column: str, - end_column: str, - include_visit_occurrence: bool = False, -) -> ir.Table: - keep = ["person_id", primary_key, start_column] - if end_column in table.columns or include_visit_occurrence and start_column != end_column: - keep.append(end_column) - if include_visit_occurrence and "visit_occurrence_id" in table.columns: - keep.append("visit_occurrence_id") - unique_keep = [col for i, col in enumerate(keep) if col in table.columns and col not in keep[:i]] - return table.select(*(table[col] for col in unique_keep)) - - -def apply_codeset_filter( - table: ir.Table, - concept_column: str, - codeset_id: int | None, - ctx: BuildContext, -) -> ir.Table: - if codeset_id is None: - return table - base_columns = table.columns - left = table.view() - concepts = ctx.codesets.filter(ctx.codesets["codeset_id"] == ibis.literal(codeset_id)).view() - joined = left.join(concepts, [left[concept_column] == concepts["concept_id"]]) - return _project_columns(joined, base_columns) - - -def 
apply_concept_set_selection( - table: ir.Table, - column: str, - selection: ConceptSetSelection | None, - ctx: BuildContext, -) -> ir.Table: - if selection is None or selection.codeset_id is None: - return table - base_columns = table.columns - left = table.view() - codeset_table = ctx.codesets.filter( - ctx.codesets["codeset_id"] == ibis.literal(selection.codeset_id) - ).view() - if selection.is_exclusion: - return left.anti_join(codeset_table, [left[column] == codeset_table.concept_id]) - joined = left.join(codeset_table, [left[column] == codeset_table.concept_id]) - return _project_columns(joined, base_columns) - - -def coerce_concept_set_selection( - value: object | None, -) -> ConceptSetSelection | None: - if value is None: - return None - if isinstance(value, ConceptSetSelection): - return value - if hasattr(value, "codeset_id"): - return cast(ConceptSetSelection, value) - try: - return ConceptSetSelection(CodesetId=int(cast(Any, value))) - except (TypeError, ValueError) as exc: - raise ValueError(f"Unsupported concept set selection value: {value!r}") from exc - - -def apply_concept_criteria( - table: ir.Table, - *, - column: str, - concepts: Sequence[Concept] | None, - selection: ConceptSetSelection | None, - ctx: BuildContext, - exclude: bool = False, -) -> ir.Table: - table = apply_concept_filters(table, column, concepts, exclude=exclude) - return apply_concept_set_selection(table, column, selection, ctx) - - -def apply_date_range(table: ir.Table, column: str, date_range: DateRange | None) -> ir.Table: - if not date_range: - return table - expr = table[column] - if date_range.op.endswith("bt"): - lower = ibis.literal(date_range.value) - upper = ibis.literal(date_range.extent) - predicate = expr.between(lower, upper) - if date_range.op.startswith("!"): - predicate = ~predicate - else: - comparator = _map_operator(date_range.op) - operand = ibis.literal(date_range.value) - predicate = comparator(expr, operand) - return table.filter(predicate) - - -def 
apply_numeric_range(table: ir.Table, column, numeric_range: NumericRange | None) -> ir.Table: - if not numeric_range or numeric_range.value is None: - return table - op = numeric_range.op or "eq" - - expr = table[column] if isinstance(column, str) else column - if op.endswith("bt"): - lower = ibis.literal(numeric_range.value) - upper = ibis.literal(numeric_range.extent) - predicate = expr.between(lower, upper) - if op.startswith("!"): - predicate = ~predicate - else: - comparator = _map_operator(op) - operand = ibis.literal(numeric_range.value) - predicate = comparator(expr, operand) - return table.filter(predicate) - - -def apply_text_filter(table: ir.Table, column: str, text_filter: TextFilter | None) -> ir.Table: - if not text_filter or not text_filter.text: - return table - op = text_filter.op or "contains" - negate = op.startswith("!") - core = op[1:] if negate else op - core = core.lower() - prefix = "%" if core in {"endswith", "contains"} else "" - suffix = "%" if core in {"startswith", "contains"} else "" - pattern = f"{prefix}{text_filter.text}{suffix}" - col_expr = cast(ir.StringValue, table[column]) - predicate = col_expr.like(pattern) - if negate: - predicate = ~predicate - return table.filter(predicate) - - -def apply_interval_range( - table: ir.Table, - start_column: str, - end_column: str, - interval_range: NumericRange | None, -) -> ir.Table: - if not interval_range or interval_range.value is None: - return table - - op = (interval_range.op or "gte").lower() - value = int(interval_range.value) - start = cast(Any, table[start_column]) - end = table[end_column] - - def _interval(days: int): - return ibis.interval(days=int(days)) - - if op.endswith("bt"): - if interval_range.extent is None: - raise ValueError("Between operator for interval range requires an extent") - lower = _interval(value) - upper = _interval(int(interval_range.extent)) - predicate = (end >= start + lower) & (end <= start + upper) - if op.startswith("!"): - predicate = ~predicate - 
return table.filter(predicate) - - target = _interval(value) - if op == "lt": - predicate = end < start + target - elif op == "lte": - predicate = end <= start + target - elif op == "gt": - predicate = end > start + target - elif op == "gte": - predicate = end >= start + target - elif op == "eq": - predicate = (end >= start + target) & (end < start + _interval(value + 1)) - elif op == "!eq": - predicate = ~((end >= start + target) & (end < start + _interval(value + 1))) - else: - raise ValueError(f"Unsupported operator for interval range: {op}") - - return table.filter(predicate) - - -def _map_operator(op: str): - mapping = { - "lt": lambda a, b: a < b, - "lte": lambda a, b: a <= b, - "eq": lambda a, b: a == b, - "!eq": lambda a, b: a != b, - "gt": lambda a, b: a > b, - "gte": lambda a, b: a >= b, - } - if op not in mapping: - raise ValueError(f"Operator {op} not supported") - return mapping[op] - - -def apply_concept_filters( - table: ir.Table, - column: str, - include_concepts: Sequence[Concept] | None, - exclude: bool = False, -) -> ir.Table: - if not include_concepts: - return table - concept_ids = [c.concept_id for c in include_concepts if c.concept_id is not None] - if not concept_ids: - return table - predicate = table[column].isin(cast(Any, concept_ids)) - if exclude: - predicate = ~predicate - return table.filter(predicate) - - -def apply_age_filter( - table: ir.Table, - age_range: NumericRange | None, - ctx: BuildContext, - start_column: str, -) -> ir.Table: - if not age_range: - return table - base_columns = table.columns - person = _person_subset(ctx, ["person_id", "year_of_birth"]) - joined = table.join(person, ["person_id"]) - start_expr = cast(ir.TimestampValue, _ensure_timestamp(joined[start_column])) - age_expr = start_expr.year() - cast(Any, joined.year_of_birth) - joined = joined.mutate(_criteria_age=age_expr) - filtered = apply_numeric_range(joined, "_criteria_age", age_range) - filtered = filtered.drop("_criteria_age") - return 
_project_columns(filtered, base_columns) - - -def apply_gender_filter( - table: ir.Table, - genders: list[Concept] | None, - gender_selection: ConceptSetSelection | None, - ctx: BuildContext, -) -> ir.Table: - return _apply_person_concept_filter( - table, - person_column="gender_concept_id", - concepts=genders, - selection=gender_selection, - ctx=ctx, - ) - - -def apply_race_filter( - table: ir.Table, - races: list[Concept] | None, - race_selection: ConceptSetSelection | None, - ctx: BuildContext, -) -> ir.Table: - return _apply_person_concept_filter( - table, - person_column="race_concept_id", - concepts=races, - selection=race_selection, - ctx=ctx, - ) - - -def apply_ethnicity_filter( - table: ir.Table, - ethnicities: list[Concept] | None, - ethnicity_selection: ConceptSetSelection | None, - ctx: BuildContext, -) -> ir.Table: - return _apply_person_concept_filter( - table, - person_column="ethnicity_concept_id", - concepts=ethnicities, - selection=ethnicity_selection, - ctx=ctx, - ) - - -def _apply_person_concept_filter( - table: ir.Table, - *, - person_column: str, - concepts: Sequence[Concept] | None, - selection: ConceptSetSelection | None, - ctx: BuildContext, -) -> ir.Table: - if not concepts and not selection: - return table - base_columns = table.columns - person = _person_subset(ctx, ["person_id", person_column]) - joined = table.join(person, ["person_id"]) - joined = apply_concept_criteria( - joined, - column=person_column, - concepts=concepts, - selection=selection, - ctx=ctx, - ) - return _project_columns(joined, base_columns) - - -def apply_observation_window( - events: ir.Table, - observation_window, - ctx: BuildContext, -) -> ir.Table: - if observation_window is None: - return events - observation = ctx.table("observation_period").select( - "person_id", "observation_period_start_date", "observation_period_end_date" - ) - # Use a view to ensure subsequent joins don't mix incompatible relations. 
- left = events.view() - joined = left.join(observation, ["person_id"]) - prior_days = ibis.interval(days=int(observation_window.prior_days or 0)) - post_days = ibis.interval(days=int(observation_window.post_days or 0)) - start_col = _ensure_timestamp(joined.observation_period_start_date) - end_col = _ensure_timestamp(joined.observation_period_end_date) - start_bound = start_col + cast(Any, prior_days) - end_bound = end_col - cast(Any, post_days) - filtered = joined.filter((joined.start_date >= start_bound) & (joined.start_date <= end_bound)) - base_projection = [filtered[col] for col in events.columns] - base_projection.extend( - filtered[col] - for col in ("observation_period_start_date", "observation_period_end_date") - if col in filtered.columns - ) - return filtered.select(*base_projection) - - -def apply_first_event(table: ir.Table, start_column: str, primary_key: str) -> ir.Table: - window = ibis.window( - group_by=table.person_id, - order_by=[table[start_column], table[primary_key]], - ) - - ranked = table.mutate(_row_num=row_number().over(window)) - filtered = ranked.filter(ranked["_row_num"] == ibis.literal(0)) - keep_columns = [col for col in table.columns if col != "_row_num"] - if keep_columns: - return filtered.select(*(filtered[col] for col in keep_columns)) - return filtered.drop("_row_num") - - -def apply_visit_concept_filters( - table: ir.Table, - visit_types: list[Concept] | None, - visit_selection: ConceptSetSelection | None, - ctx: BuildContext, -) -> ir.Table: - return apply_concept_criteria( - table, - column="visit_concept_id", - concepts=visit_types, - selection=visit_selection, - ctx=ctx, - ) - - -def apply_provider_specialty_filter( - table: ir.Table, - provider_specialties: list[Concept] | None, - provider_specialty_selection: ConceptSetSelection | None, - ctx: BuildContext, - provider_column: str = "provider_id", -) -> ir.Table: - if not provider_specialties and not provider_specialty_selection: - return table - provider = 
ctx.table("provider") - provider = apply_concept_criteria( - provider, - column="specialty_concept_id", - concepts=provider_specialties, - selection=provider_specialty_selection, - ctx=ctx, - ) - filtered = provider.select(provider.provider_id) - return table.semi_join(filtered, [table[provider_column] == filtered.provider_id]) - - -def apply_care_site_filter( - table: ir.Table, - place_of_service_selection: ConceptSetSelection | None, - ctx: BuildContext, - care_site_column: str = "care_site_id", -) -> ir.Table: - if not place_of_service_selection: - return table - care_site = ctx.table("care_site") - filtered = apply_concept_set_selection( - care_site, "place_of_service_concept_id", place_of_service_selection, ctx - ) - filtered = filtered.select(filtered.care_site_id) - return table.semi_join(filtered, [table[care_site_column] == filtered.care_site_id]) - - -def apply_location_region_filter( - table: ir.Table, - *, - care_site_column: str, - location_codeset_id: int | None, - start_column: str, - end_column: str, - ctx: BuildContext, -) -> ir.Table: - if not location_codeset_id: - return table - base_columns = table.columns - care_site = ctx.table("care_site") - location_history = ctx.table("location_history") - location = ctx.table("location") - joined = table.join(care_site, [table[care_site_column] == care_site.care_site_id]) - start_expr = _ensure_timestamp(joined[start_column]) - end_expr = _ensure_timestamp(joined[end_column]) - lh = location_history - lh_condition = ( - (joined[care_site_column] == lh.entity_id) - & (lh.domain_id == ibis.literal("CARE_SITE")) - & (start_expr >= lh.start_date) - & (end_expr <= ibis.coalesce(lh.end_date, ibis.literal("2099-12-31").cast("date"))) - ) - joined = joined.join(lh, [lh_condition]) - joined = joined.join(location, [joined.location_id == location.location_id]) - codeset = ctx.codesets.filter(ctx.codesets.codeset_id == ibis.literal(location_codeset_id)) - filtered = joined.join(codeset, [location.region_concept_id 
== codeset.concept_id]) - return _project_columns(filtered, base_columns) - - -def apply_user_defined_period( - table: ir.Table, - start_column: str, - end_column: str, - period, -) -> tuple[ir.Table, str, str]: - if not period: - return table, start_column, end_column - - base_start = table[start_column] - base_end = table[end_column] - additions = {} - new_start = start_column - new_end = end_column - - if getattr(period, "start_date", None): - literal = _literal_like(period.start_date, base_start) - additions["_user_defined_start"] = literal - table = table.filter((base_start <= literal) & (base_end >= literal)) - new_start = "_user_defined_start" - - if getattr(period, "end_date", None): - literal = _literal_like(period.end_date, base_end) - additions["_user_defined_end"] = literal - table = table.filter((base_start <= literal) & (base_end >= literal)) - new_end = "_user_defined_end" - - if additions: - table = table.mutate(**additions) - - return table, new_start, new_end - - -def _literal_like(value, reference): - literal = ibis.literal(value) - dtype = reference.type() - if dtype.is_timestamp(): - return literal.cast("timestamp") - if dtype.is_date(): - return literal.cast("date") - return literal - - -def _ensure_timestamp(expr: ir.Value) -> ir.Value: - dtype = expr.type() - if dtype.is_timestamp(): - return expr - if dtype.is_date(): - return expr.cast("timestamp") - if dtype.is_string(): - return ibis.to_timestamp(expr) - raise ValueError(f"Cannot convert expression of type {dtype} to timestamp") - - -def _cast_like(expr: ir.Value, reference: ir.Value) -> ir.Value: - target_type = reference.type() - if expr.type() == target_type: - return expr - return expr.cast(cast(Any, target_type)) - - -def _project_columns(table: ir.Table, column_names: Sequence[str]) -> ir.Table: - available = [name for name in column_names if name in table.columns] - if not available: - return table - return table.select(*[table[name] for name in available]) - - -def 
apply_end_strategy( - events: ir.Table, - strategy: EndStrategy | DateOffsetStrategy | CustomEraStrategy | None, - ctx: BuildContext, -) -> ir.Table: - date_offset, custom_era = _resolve_end_strategy_parts(strategy) - if not date_offset and not custom_era: - if "observation_period_end_date" in events.columns: - op_end = _cast_like(_ensure_timestamp(events.observation_period_end_date), events.end_date) - return events.mutate(end_date=op_end) - return events - result = events - if custom_era: - result = _apply_custom_era_strategy(result, custom_era, ctx) - if date_offset: - interval = ibis.interval(days=int(date_offset.offset)) - date_field = str(date_offset.date_field or "StartDate").lower() - anchor = ( - _ensure_timestamp(result.start_date) - if date_field == "startdate" - else _ensure_timestamp(result.end_date) - ) - shifted = anchor + cast(Any, interval) - if "observation_period_end_date" in result.columns: - shifted = ibis.least( - shifted, - _ensure_timestamp(result.observation_period_end_date), - ) - result = result.mutate(end_date=_cast_like(shifted, result.end_date)) - return result - - -def has_end_strategy( - strategy: EndStrategy | DateOffsetStrategy | CustomEraStrategy | None, -) -> bool: - date_offset, custom_era = _resolve_end_strategy_parts(strategy) - return bool(date_offset or custom_era) - - -def _resolve_end_strategy_parts( - strategy: EndStrategy | DateOffsetStrategy | CustomEraStrategy | None, -) -> tuple[DateOffsetStrategy | None, CustomEraStrategy | None]: - if strategy is None: - return None, None - - if isinstance(strategy, DateOffsetStrategy): - return strategy, None - - if isinstance(strategy, CustomEraStrategy): - return None, strategy - - date_offset = getattr(strategy, "date_offset", None) - custom_era = getattr(strategy, "custom_era", None) - - if isinstance(date_offset, dict): - date_offset = DateOffsetStrategy.model_validate(date_offset, strict=False) - if isinstance(custom_era, dict): - custom_era = 
CustomEraStrategy.model_validate(custom_era, strict=False) - - return date_offset, custom_era - - -def collapse_events(events: ir.Table, settings: CollapseSettings | None) -> ir.Table: - if not settings or settings.collapse_type != CollapseType.ERA: - return events - pad_interval = ibis.interval(days=int(settings.era_pad or 0)) - order_by = [events.start_date, events.end_date, events.event_id] - prev_window = ibis.window( - group_by=events.person_id, - order_by=order_by, - preceding=(None, 1), - ) - extended_end = events.end_date + cast(Any, pad_interval) - prev_max = extended_end.max().over(prev_window) - is_start = ibis.ifelse( - prev_max.notnull() & (prev_max >= events.start_date), - 0, - 1, - ) - annotated = events.mutate( - extended_end=extended_end, - is_start=is_start, - ) - is_start_col = cast(ir.IntegerColumn, annotated["is_start"]) - era_window = ibis.window( - group_by=annotated.person_id, - order_by=[ - annotated.start_date, - ibis.desc(is_start_col), - annotated.end_date, - annotated.event_id, - ], - ) - era_id = is_start_col.cumsum().over(era_window) - grouped = annotated.mutate(_era_id=era_id) - max_end = cast(ir.IntervalScalar, grouped.extended_end.max()) - collapsed = grouped.group_by(grouped.person_id, grouped._era_id).aggregate( - start_date=grouped.start_date.min(), - end_date=(max_end - pad_interval), - visit_occurrence_id=grouped.visit_occurrence_id.max(), - ) - final_window = ibis.window(order_by=[collapsed.person_id, collapsed.start_date, collapsed.end_date]) - collapsed = collapsed.mutate(event_id=(ibis.row_number().over(final_window) + 1)).select( - "person_id", "event_id", "start_date", "end_date", "visit_occurrence_id" - ) - return collapsed - - -def _apply_custom_era_strategy(events: ir.Table, strategy: CustomEraStrategy, ctx: BuildContext) -> ir.Table: - if strategy.drug_codeset_id is None: - raise ValueError("Custom era strategy requires a drug codeset id.") - - persons = events.select(events.person_id).distinct() - codeset = 
ctx.codesets.filter(ctx.codesets.codeset_id == strategy.drug_codeset_id) - drug_exposure = ctx.table("drug_exposure") - - def _exposure_query(concept_column: str) -> ir.Table: - return ( - drug_exposure.join(persons, ["person_id"]) - .join(codeset, drug_exposure[concept_column] == codeset.concept_id) - .select( - drug_exposure.person_id, - drug_exposure.drug_exposure_start_date.name("drug_exposure_start_date"), - _drug_exposure_end(drug_exposure, strategy).name("drug_exposure_end_date"), - ) - ) - - exposures = _exposure_query("drug_concept_id").union( - _exposure_query("drug_source_concept_id"), distinct=False - ) - - gap = int(strategy.gap_days or 0) - offset = int(strategy.offset or 0) - extend_interval = ibis.interval(days=gap + offset) - - dt = exposures.select( - exposures.person_id, - exposures.drug_exposure_start_date.name("start_date"), - (exposures.drug_exposure_end_date + extend_interval).name("extended_end"), - ).distinct() - - prev_max_window = ibis.window( - group_by=dt.person_id, - order_by=[dt.start_date, dt.extended_end], - preceding=(None, 1), - ) - prev_running_max = dt.extended_end.max().over(prev_max_window) - is_start = ibis.ifelse(prev_running_max.notnull() & (prev_running_max >= dt.start_date), 0, 1) - staged = dt.mutate(is_start=is_start).view() - cumsum_window = ibis.window(group_by=staged.person_id, order_by=[staged.start_date, staged.extended_end]) - group_idx = staged.is_start.cumsum().over(cumsum_window) - annotated = staged.mutate(group_idx=group_idx) - - eras = annotated.group_by(annotated.person_id, annotated.group_idx).aggregate( - era_start=annotated.start_date.min(), - era_end=(annotated.extended_end.max() - ibis.interval(days=gap)), - ) - - join_condition = ( - (events.person_id == eras.person_id) - & (events.start_date >= eras.era_start) - & (events.start_date <= eras.era_end) - ) - joined = events.join(eras, join_condition, how="inner") - if not joined.columns: - return events.limit(0) - supplemental = [ - joined[column] - for 
column in ("observation_period_start_date", "observation_period_end_date") - if column in joined.columns - ] - return joined.select( - joined.person_id, - joined.event_id, - joined.start_date, - joined.era_end.name("end_date"), - joined.visit_occurrence_id, - *supplemental, - ) - - -def _drug_exposure_end(drug_exposure: ir.Table, strategy: CustomEraStrategy) -> ir.Value: - start = drug_exposure.drug_exposure_start_date - if strategy.days_supply_override is not None: - return start + ibis.interval(days=int(strategy.days_supply_override)) - - end_candidates = [ - drug_exposure.drug_exposure_end_date, - ibis.ifelse( - drug_exposure.days_supply.notnull(), - start + (ibis.interval(days=1) * drug_exposure.days_supply.cast("int64")), - ibis.null(), - ), - start + ibis.interval(days=1), - ] - return ibis.coalesce(*end_candidates) diff --git a/circe/execution/builders/condition_era.py b/circe/execution/builders/condition_era.py deleted file mode 100644 index 359232ce..00000000 --- a/circe/execution/builders/condition_era.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -from ...cohortdefinition.criteria import ConditionEra -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_codeset_filter, - apply_date_range, - apply_first_event, - apply_gender_filter, - apply_interval_range, - apply_numeric_range, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("ConditionEra") -def build_condition_era(criteria: ConditionEra, ctx: BuildContext): - table = ctx.table("condition_era") - - table = apply_codeset_filter(table, "condition_concept_id", criteria.codeset_id, ctx) - table = apply_date_range(table, "condition_era_start_date", criteria.era_start_date) - table = apply_date_range(table, "condition_era_end_date", criteria.era_end_date) - table = apply_numeric_range(table, "condition_occurrence_count", criteria.occurrence_count) - table = apply_interval_range( - 
table, "condition_era_start_date", "condition_era_end_date", criteria.era_length - ) - - if criteria.age_at_start: - table = apply_age_filter(table, criteria.age_at_start, ctx, "condition_era_start_date") - if criteria.age_at_end: - table = apply_age_filter(table, criteria.age_at_end, ctx, "condition_era_end_date") - - table = apply_gender_filter(table, criteria.gender, criteria.gender_cs, ctx) - - if criteria.first: - table = apply_first_event(table, "condition_era_start_date", "condition_era_id") - - events = standardize_output( - table, - primary_key="condition_era_id", - start_column="condition_era_start_date", - end_column="condition_era_end_date", - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) diff --git a/circe/execution/builders/condition_occurrence.py b/circe/execution/builders/condition_occurrence.py deleted file mode 100644 index 73129f26..00000000 --- a/circe/execution/builders/condition_occurrence.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -from ...cohortdefinition.criteria import ConditionOccurrence -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_codeset_filter, - apply_concept_criteria, - apply_date_range, - apply_first_event, - apply_gender_filter, - apply_visit_concept_filters, - coerce_concept_set_selection, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("ConditionOccurrence") -def build_condition_occurrence(criteria: ConditionOccurrence, ctx: BuildContext): - table = ctx.table("condition_occurrence") - - concept_column = criteria.get_concept_id_column() - table = apply_codeset_filter(table, concept_column, criteria.codeset_id, ctx) - if criteria.first: - table = apply_first_event(table, criteria.get_start_date_column(), criteria.get_primary_key_column()) - - table = apply_date_range(table, criteria.get_start_date_column(), criteria.occurrence_start_date) - table = 
apply_date_range(table, criteria.get_end_date_column(), criteria.occurrence_end_date) - - table = apply_concept_criteria( - table, - column="condition_type_concept_id", - concepts=criteria.condition_type, - selection=criteria.condition_type_cs, - ctx=ctx, - exclude=bool(criteria.condition_type_exclude), - ) - - table = apply_concept_criteria( - table, - column="condition_status_concept_id", - concepts=getattr(criteria, "condition_status", None), - selection=None, - ctx=ctx, - ) - - if criteria.age: - table = apply_age_filter(table, criteria.age, ctx, criteria.get_start_date_column()) - table = apply_gender_filter(table, criteria.gender, criteria.gender_cs, ctx) - - source_filter = getattr(criteria, "condition_source_concept", None) - selection = coerce_concept_set_selection(source_filter) - if selection is not None: - table = apply_concept_criteria( - table, - column="condition_source_concept_id", - concepts=None, - selection=selection, - ctx=ctx, - ) - - visit_source = getattr(criteria, "visit_source_concept", None) - needs_visit_filters = bool(criteria.visit_type or criteria.visit_type_cs or visit_source is not None) - if needs_visit_filters: - visit = ctx.table("visit_occurrence").select( - "person_id", - "visit_occurrence_id", - "visit_concept_id", - "visit_source_concept_id", - ) - table = table.join( - visit, - (table.visit_occurrence_id == visit.visit_occurrence_id) & (table.person_id == visit.person_id), - ) - table = apply_visit_concept_filters(table, criteria.visit_type, criteria.visit_type_cs, ctx) - if visit_source is not None: - table = table.filter(table.visit_source_concept_id == int(visit_source)) - - events = standardize_output( - table, - primary_key=criteria.get_primary_key_column(), - start_column=criteria.get_start_date_column(), - end_column=criteria.get_end_date_column(), - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) diff --git a/circe/execution/builders/death.py b/circe/execution/builders/death.py deleted file 
mode 100644 index 847398db..00000000 --- a/circe/execution/builders/death.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations - -import ibis - -from ...cohortdefinition.criteria import Death -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_codeset_filter, - apply_concept_criteria, - apply_date_range, - apply_gender_filter, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("Death") -def build_death(criteria: Death, ctx: BuildContext): - table = ctx.table("death") - - table = apply_codeset_filter(table, "cause_concept_id", criteria.codeset_id, ctx) - - table = apply_date_range(table, "death_date", getattr(criteria, "occurrence_start_date", None)) - - table = apply_concept_criteria( - table, - column="death_type_concept_id", - concepts=criteria.death_type, - selection=criteria.death_type_cs, - ctx=ctx, - exclude=bool(getattr(criteria, "death_type_exclude", False)), - ) - - if getattr(criteria, "death_source_concept", None) is not None: - table = apply_codeset_filter( - table, - "cause_source_concept_id", - int(criteria.death_source_concept), - ctx, - ) - - if criteria.age: - table = apply_age_filter(table, criteria.age, ctx, criteria.get_start_date_column()) - table = apply_gender_filter(table, criteria.gender, criteria.gender_cs, ctx) - - window = ibis.window(order_by=[table.person_id, table.death_date]) - table = table.mutate(death_event_id=ibis.row_number().over(window)) - - events = standardize_output( - table, - primary_key="death_event_id", - start_column="death_date", - end_column="death_date", - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) diff --git a/circe/execution/builders/device_exposure.py b/circe/execution/builders/device_exposure.py deleted file mode 100644 index ac34fdf0..00000000 --- a/circe/execution/builders/device_exposure.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - 
-from ...cohortdefinition.criteria import DeviceExposure -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_codeset_filter, - apply_concept_criteria, - apply_date_range, - apply_first_event, - apply_gender_filter, - apply_numeric_range, - apply_provider_specialty_filter, - apply_text_filter, - apply_visit_concept_filters, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("DeviceExposure") -def build_device_exposure(criteria: DeviceExposure, ctx: BuildContext): - table = ctx.table("device_exposure") - - concept_column = criteria.get_concept_id_column() - table = apply_codeset_filter(table, concept_column, criteria.codeset_id, ctx) - - table = apply_date_range(table, criteria.get_start_date_column(), criteria.occurrence_start_date) - table = apply_date_range(table, criteria.get_end_date_column(), criteria.occurrence_end_date) - - table = apply_concept_criteria( - table, - column="device_type_concept_id", - concepts=criteria.device_type, - selection=criteria.device_type_cs, - ctx=ctx, - exclude=bool(criteria.device_type_exclude), - ) - - table = apply_numeric_range(table, "quantity", criteria.quantity) - table = apply_text_filter(table, "unique_device_id", getattr(criteria, "unique_device_id", None)) - - if criteria.age: - table = apply_age_filter(table, criteria.age, ctx, criteria.get_start_date_column()) - table = apply_gender_filter(table, criteria.gender, criteria.gender_cs, ctx) - table = apply_provider_specialty_filter( - table, - getattr(criteria, "provider_specialty", None), - getattr(criteria, "provider_specialty_cs", None), - ctx, - provider_column="provider_id", - ) - table = apply_visit_concept_filters(table, criteria.visit_type, criteria.visit_type_cs, ctx) - if criteria.device_source_concept is not None: - table = apply_codeset_filter( - table, - "device_source_concept_id", - criteria.device_source_concept, - ctx, - ) - - if criteria.first: - table = 
apply_first_event(table, criteria.get_start_date_column(), criteria.get_primary_key_column()) - - events = standardize_output( - table, - primary_key=criteria.get_primary_key_column(), - start_column=criteria.get_start_date_column(), - end_column=criteria.get_end_date_column(), - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) diff --git a/circe/execution/builders/dose_era.py b/circe/execution/builders/dose_era.py deleted file mode 100644 index 6aa0f469..00000000 --- a/circe/execution/builders/dose_era.py +++ /dev/null @@ -1,54 +0,0 @@ -from __future__ import annotations - -from ...cohortdefinition.criteria import DoseEra -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_codeset_filter, - apply_concept_criteria, - apply_date_range, - apply_first_event, - apply_gender_filter, - apply_interval_range, - apply_numeric_range, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("DoseEra") -def build_dose_era(criteria: DoseEra, ctx: BuildContext): - table = ctx.table("dose_era") - - table = apply_codeset_filter(table, "drug_concept_id", criteria.codeset_id, ctx) - table = apply_date_range(table, "dose_era_start_date", criteria.era_start_date) - table = apply_date_range(table, "dose_era_end_date", criteria.era_end_date) - - table = apply_concept_criteria( - table, - column="unit_concept_id", - concepts=criteria.unit, - selection=criteria.unit_cs, - ctx=ctx, - ) - - table = apply_numeric_range(table, "dose_value", criteria.dose_value) - table = apply_interval_range(table, "dose_era_start_date", "dose_era_end_date", criteria.era_length) - - if criteria.age_at_start: - table = apply_age_filter(table, criteria.age_at_start, ctx, "dose_era_start_date") - if criteria.age_at_end: - table = apply_age_filter(table, criteria.age_at_end, ctx, "dose_era_end_date") - table = apply_gender_filter(table, criteria.gender, criteria.gender_cs, ctx) - - if 
criteria.first: - table = apply_first_event(table, "dose_era_start_date", "dose_era_id") - - events = standardize_output( - table, - primary_key="dose_era_id", - start_column="dose_era_start_date", - end_column="dose_era_end_date", - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) diff --git a/circe/execution/builders/drug_era.py b/circe/execution/builders/drug_era.py deleted file mode 100644 index f2e99a03..00000000 --- a/circe/execution/builders/drug_era.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -from ...cohortdefinition.criteria import DrugEra -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_codeset_filter, - apply_date_range, - apply_first_event, - apply_gender_filter, - apply_interval_range, - apply_numeric_range, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("DrugEra") -def build_drug_era(criteria: DrugEra, ctx: BuildContext): - table = ctx.table("drug_era") - - table = apply_codeset_filter(table, "drug_concept_id", criteria.codeset_id, ctx) - table = apply_date_range(table, "drug_era_start_date", criteria.era_start_date) - table = apply_date_range(table, "drug_era_end_date", criteria.era_end_date) - table = apply_numeric_range(table, "drug_exposure_count", criteria.occurrence_count) - table = apply_numeric_range(table, "gap_days", criteria.gap_days) - table = apply_interval_range(table, "drug_era_start_date", "drug_era_end_date", criteria.era_length) - - if criteria.age_at_start: - table = apply_age_filter(table, criteria.age_at_start, ctx, "drug_era_start_date") - if criteria.age_at_end: - table = apply_age_filter(table, criteria.age_at_end, ctx, "drug_era_end_date") - - table = apply_gender_filter(table, criteria.gender, criteria.gender_cs, ctx) - - if criteria.first: - table = apply_first_event(table, "drug_era_start_date", "drug_era_id") - - events = standardize_output( - table, - 
primary_key="drug_era_id", - start_column="drug_era_start_date", - end_column="drug_era_end_date", - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) diff --git a/circe/execution/builders/drug_exposure.py b/circe/execution/builders/drug_exposure.py deleted file mode 100644 index 665f9f4c..00000000 --- a/circe/execution/builders/drug_exposure.py +++ /dev/null @@ -1,93 +0,0 @@ -from __future__ import annotations - -from ...cohortdefinition.criteria import DrugExposure -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_codeset_filter, - apply_concept_criteria, - apply_date_range, - apply_first_event, - apply_gender_filter, - apply_numeric_range, - apply_provider_specialty_filter, - apply_text_filter, - apply_visit_concept_filters, - coerce_concept_set_selection, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("DrugExposure") -def build_drug_exposure(criteria: DrugExposure, ctx: BuildContext): - table = ctx.table("drug_exposure") - - concept_column = criteria.get_concept_id_column() - table = apply_codeset_filter(table, concept_column, criteria.codeset_id, ctx) - if criteria.first: - table = apply_first_event(table, criteria.get_start_date_column(), criteria.get_primary_key_column()) - - table = apply_date_range(table, criteria.get_start_date_column(), criteria.occurrence_start_date) - table = apply_date_range(table, criteria.get_end_date_column(), criteria.occurrence_end_date) - - table = apply_concept_criteria( - table, - column="drug_type_concept_id", - concepts=criteria.drug_type, - selection=criteria.drug_type_cs, - ctx=ctx, - exclude=bool(getattr(criteria, "drug_type_exclude", False)), - ) - table = apply_concept_criteria( - table, - column="route_concept_id", - concepts=criteria.route_concept, - selection=criteria.route_concept_cs, - ctx=ctx, - ) - table = apply_concept_criteria( - table, - column="dose_unit_concept_id", - 
concepts=getattr(criteria, "dose_unit", []), - selection=getattr(criteria, "dose_unit_cs", None), - ctx=ctx, - ) - - table = apply_numeric_range(table, "quantity", criteria.quantity) - table = apply_numeric_range(table, "days_supply", criteria.days_supply) - table = apply_numeric_range(table, "refills", criteria.refills) - table = apply_text_filter(table, "stop_reason", getattr(criteria, "stop_reason", None)) - table = apply_text_filter(table, "lot_number", getattr(criteria, "lot_number", None)) - - if criteria.age: - table = apply_age_filter(table, criteria.age, ctx, criteria.get_start_date_column()) - table = apply_gender_filter(table, criteria.gender, criteria.gender_cs, ctx) - table = apply_provider_specialty_filter( - table, - getattr(criteria, "provider_specialty", None), - getattr(criteria, "provider_specialty_cs", None), - ctx, - provider_column="provider_id", - ) - table = apply_visit_concept_filters(table, criteria.visit_type, criteria.visit_type_cs, ctx) - - source_filter = getattr(criteria, "drug_source_concept", None) - selection = coerce_concept_set_selection(source_filter) - if selection is not None: - table = apply_concept_criteria( - table, - column="drug_source_concept_id", - concepts=None, - selection=selection, - ctx=ctx, - ) - - events = standardize_output( - table, - primary_key=criteria.get_primary_key_column(), - start_column=criteria.get_start_date_column(), - end_column=criteria.get_end_date_column(), - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) diff --git a/circe/execution/builders/groups.py b/circe/execution/builders/groups.py deleted file mode 100644 index 3bcec09a..00000000 --- a/circe/execution/builders/groups.py +++ /dev/null @@ -1,441 +0,0 @@ -from __future__ import annotations - -from typing import Callable - -import ibis -import ibis.common.exceptions as ibis_exc -import ibis.expr.types as ir - -from ...cohortdefinition.core import ObservationFilter -from ...cohortdefinition.criteria import ( - 
Criteria, - CriteriaColumn, - CriteriaGroup, - VisitDetail, -) -from ..build_context import BuildContext -from ..criteria_compat import ( - CorrelatedCriteria, - DemoGraphicCriteria, - OccurrenceType, - parse_single_criteria, -) -from .common import ( - apply_age_filter, - apply_date_range, - apply_ethnicity_filter, - apply_gender_filter, - apply_observation_window, - apply_race_filter, -) -from .registry import build_events - - -def apply_criteria_group(events: ir.Table, group: CriteriaGroup | None, ctx: BuildContext) -> ir.Table: - mask = _group_mask(events, group, ctx) - if mask is None: - return events - return events.filter(mask) - - -def _correlated_mask(events: ir.Table, correlated: CorrelatedCriteria, ctx: BuildContext) -> ir.Value: - criteria_model = correlated.criteria - if criteria_model and not isinstance(criteria_model, ir.Expr): - criteria_model = parse_single_criteria(criteria_model) - if criteria_model is None: - return ibis.literal(True) - - count_column_name, count_column_enum = _resolve_count_column(correlated.occurrence) - - base_events = build_events(criteria_model, ctx) - base_events = _attach_count_columns( - base_events, - criteria_model, - ctx, - count_column_name=count_column_name, - count_column_enum=count_column_enum, - ) - requires_corr_end_alignment = _requires_observation_period_end_alignment(correlated) - zero_window: ObservationFilter | None = None - if not correlated.ignore_observation_period: - zero_window = ObservationFilter(prior_days=0, post_days=0) - base_events = apply_observation_window(base_events, zero_window, ctx) - - index_events = events - if not correlated.ignore_observation_period: - missing_observation_bounds = ( - "observation_period_start_date" not in index_events.columns - or "observation_period_end_date" not in index_events.columns - ) - if missing_observation_bounds: - zero_window = zero_window or ObservationFilter(prior_days=0, post_days=0) - index_events = apply_observation_window(index_events, zero_window, 
ctx) - - select_fields = [ - base_events.person_id, - base_events.event_id.name("_corr_event_id"), - base_events.start_date.name("_corr_start_date"), - base_events.end_date.name("_corr_end_date"), - ] - if "visit_occurrence_id" in base_events.columns: - select_fields.append(base_events.visit_occurrence_id.name("_corr_visit_occurrence_id")) - if count_column_name and count_column_name in base_events.columns: - select_fields.append(base_events[count_column_name]) - - criteria_events = base_events.select(*select_fields) - join_condition = index_events.person_id == criteria_events.person_id - if not correlated.ignore_observation_period: - if "observation_period_start_date" in index_events.columns: - join_condition &= criteria_events._corr_start_date >= index_events.observation_period_start_date - if "observation_period_end_date" in index_events.columns: - join_condition &= criteria_events._corr_start_date <= index_events.observation_period_end_date - if requires_corr_end_alignment: - join_condition &= criteria_events._corr_end_date <= index_events.observation_period_end_date - window_condition = _build_window_condition(index_events, criteria_events, correlated) - if window_condition is not None: - join_condition &= window_condition - - occurrence = correlated.occurrence - occ_type = getattr(occurrence, "type", None) - if isinstance(occ_type, int): - occ_type = OccurrenceType(occurrence.type) - - require_same_visit = bool(correlated.restrict_visit) - if correlated.restrict_visit is None and isinstance(criteria_model, VisitDetail): - require_same_visit = True - - if require_same_visit and ( - "visit_occurrence_id" in index_events.columns - and "_corr_visit_occurrence_id" in criteria_events.columns - ): - join_condition &= ( - index_events.visit_occurrence_id.notnull() - & criteria_events._corr_visit_occurrence_id.notnull() - & (index_events.visit_occurrence_id == criteria_events._corr_visit_occurrence_id) - ) - - joined = index_events.join(criteria_events, 
join_condition, how="left") - - corr_event_id = joined._corr_event_id - count_expr = corr_event_id - if count_column_name and count_column_name in joined.columns: - count_expr = joined[count_column_name] - match_expr = corr_event_id.notnull() - joined = joined.mutate( - _corr_match_value=ibis.ifelse(match_expr, count_expr, ibis.null()), - ) - - if correlated.occurrence and correlated.occurrence.is_distinct: - aggregator = joined._corr_match_value.nunique() - else: - aggregator = joined._corr_match_value.count() - - aggregated = joined.group_by(joined.person_id, joined.event_id).aggregate(match_count=aggregator) - predicate = _occurrence_predicate(aggregated.match_count, correlated.occurrence) - matching_ids = aggregated.filter(predicate).select("person_id", "event_id").distinct() - return _event_membership_mask(events, matching_ids) - - -def _group_mask(events: ir.Table, group: CriteriaGroup | None, ctx: BuildContext) -> ir.Value | None: - if not group or group.is_empty(): - return None - - masks: list[ir.Value] = [] - for correlated in group.criteria_list or []: - masks.append(_correlated_mask(events, correlated, ctx)) - - for demographic in group.demographic_criteria_list or []: - demo_mask = _demographic_mask(events, demographic, ctx) - if demo_mask is not None: - masks.append(demo_mask) - - for subgroup in group.groups or []: - sub_mask = _group_mask(events, subgroup, ctx) - if sub_mask is not None: - masks.append(sub_mask) - - if not masks: - return None - - group_type = (group.type or "ALL").upper() - if group_type == "ANY": - return _combine_any(masks) - if group_type.startswith("AT_"): - count = group.count - if group_type.endswith("LEAST"): - threshold = count if count is not None else 1 - return _combine_threshold(masks, threshold, at_least=True) - threshold = count if count is not None else 0 - return _combine_threshold(masks, threshold, at_least=False) - return _combine_all(masks) - - -def _combine_all(masks: list[ir.Value]) -> ir.Value: - combined = 
masks[0] - for mask in masks[1:]: - combined = combined & mask - return combined - - -def _combine_any(masks: list[ir.Value]) -> ir.Value: - combined = masks[0] - for mask in masks[1:]: - combined = combined | mask - return combined - - -def _combine_threshold(masks: list[ir.Value], threshold: int, *, at_least: bool) -> ir.Value: - def _to_int(mask: ir.Value) -> ir.Value: - return ibis.ifelse(mask, ibis.literal(1, type="int64"), ibis.literal(0, type="int64")) - - total = _to_int(masks[0]) - for mask in masks[1:]: - total = total + _to_int(mask) - return total >= threshold if at_least else total <= threshold - - -def _demographic_mask( - events: ir.Table, - demographic: DemoGraphicCriteria, - ctx: BuildContext, -) -> ir.Value | None: - if demographic is None: - return None - - filtered = events - applied = False - if demographic.age: - filtered = apply_age_filter(filtered, demographic.age, ctx, "start_date") - applied = True - if demographic.gender or demographic.gender_cs: - filtered = apply_gender_filter(filtered, demographic.gender, demographic.gender_cs, ctx) - applied = True - if demographic.race or demographic.race_cs: - filtered = apply_race_filter(filtered, demographic.race, demographic.race_cs, ctx) - applied = True - if demographic.ethnicity or demographic.ethnicity_cs: - filtered = apply_ethnicity_filter(filtered, demographic.ethnicity, demographic.ethnicity_cs, ctx) - applied = True - if demographic.occurrence_start_date: - filtered = apply_date_range(filtered, "start_date", demographic.occurrence_start_date) - applied = True - if demographic.occurrence_end_date: - filtered = apply_date_range(filtered, "end_date", demographic.occurrence_end_date) - applied = True - - if not applied: - return None - - filtered_ids = filtered.select(filtered.person_id, filtered.event_id).distinct() - return _event_membership_mask(events, filtered_ids) - - -def _event_membership_mask(events: ir.Table, ids: ir.Table) -> ir.Value: - keys = 
ids.mutate(_event_key=_event_key_expr(ids)).select("_event_key") - return _event_key_expr(events).isin(keys._event_key) - - -def _event_key_expr(table: ir.Table) -> ir.Value: - return table.person_id.cast("string") + ibis.literal(":") + table.event_id.cast("string") - - -def _occurrence_predicate(count_expr: ir.Value, occurrence) -> ir.Value: - if occurrence is None: - return count_expr > 0 - - occ_type = occurrence.type - if isinstance(occ_type, int): - occ_type = OccurrenceType(occurrence.type) - - if occ_type == OccurrenceType.EXACTLY: - return count_expr == occurrence.count - if occ_type == OccurrenceType.AT_LEAST: - return count_expr >= occurrence.count - if occ_type == OccurrenceType.AT_MOST: - return count_expr <= occurrence.count - return count_expr > 0 - - -def _build_window_condition( - index_events: ir.Table, - correlated_events: ir.Table, - correlated: CorrelatedCriteria, -) -> ir.Value: - cond = ibis.literal(True) - - if correlated.start_window: - correlated_start = _correlated_window_value( - correlated_events, - correlated.start_window.use_event_end, - default="start", - ) - lower = _apply_endpoint_anchor( - index_events, - correlated.start_window.start, - correlated.start_window.use_index_end, - ) - upper = _apply_endpoint_anchor( - index_events, - correlated.start_window.end, - correlated.start_window.use_index_end, - ) - if lower is not None: - cond &= correlated_start >= lower - if upper is not None: - cond &= correlated_start <= upper - - if correlated.end_window: - lower = _apply_endpoint_anchor( - index_events, - correlated.end_window.start, - correlated.end_window.use_index_end, - default_to_index_end=False, - ) - upper = _apply_endpoint_anchor( - index_events, - correlated.end_window.end, - correlated.end_window.use_index_end, - default_to_index_end=False, - ) - correlated_end = _correlated_window_value( - correlated_events, - correlated.end_window.use_event_end, - default="end", - ) - if lower is not None: - cond &= correlated_end >= lower 
- if upper is not None: - cond &= correlated_end <= upper - - return cond - - -def _apply_endpoint_anchor( - events: ir.Table, - endpoint, - use_index_end: bool | None, - *, - default_to_index_end: bool = False, -): - anchor = ( - events.end_date - if (use_index_end or (use_index_end is None and default_to_index_end)) - else events.start_date - ) - if not endpoint or endpoint.days is None: - return None - days = ibis.interval(days=int(endpoint.days)) - coeff = endpoint.coeff if endpoint.coeff is not None else 1 - return anchor + days * coeff - - -def _correlated_window_value( - correlated_events: ir.Table, - use_event_end: bool | None, - *, - default: str, -) -> ir.Value: - if use_event_end is True: - return correlated_events._corr_end_date - if use_event_end is False: - return correlated_events._corr_start_date - if default == "end": - return correlated_events._corr_end_date - return correlated_events._corr_start_date - - -_COUNT_COLUMN_MAPPING: dict[CriteriaColumn, str] = { - CriteriaColumn.START_DATE: "_corr_start_date", - CriteriaColumn.END_DATE: "_corr_end_date", - CriteriaColumn.VISIT_ID: "_corr_visit_occurrence_id", - CriteriaColumn.DOMAIN_CONCEPT: "_corr_domain_concept_id", - CriteriaColumn.DOMAIN_SOURCE_CONCEPT: "_corr_domain_source_concept_id", -} - - -_COUNT_COLUMN_SOURCES: dict[CriteriaColumn, Callable[[Criteria], str]] = { - CriteriaColumn.DOMAIN_CONCEPT: lambda criteria: criteria.get_concept_id_column(), - CriteriaColumn.DOMAIN_SOURCE_CONCEPT: lambda criteria: _source_concept_column(criteria), -} - - -def _resolve_count_column(occurrence): - if occurrence is None or occurrence.count_column is None: - return None, None - column = occurrence.count_column - enum_value: CriteriaColumn | None = None - if isinstance(column, CriteriaColumn): - enum_value = column - else: - value = str(column) - if value.upper() in CriteriaColumn.__members__: - enum_value = CriteriaColumn[value.upper()] - else: - lower = value.lower() - for member in CriteriaColumn: - if 
member.value == lower: - enum_value = member - break - if enum_value is None: - return None, None - return _COUNT_COLUMN_MAPPING.get(enum_value), enum_value - - -def _source_concept_column(criteria) -> str: - prefix = criteria.snake_case_class_name().split("_")[0] - return f"{prefix}_source_concept_id" - - -def _attach_count_columns( - events: ir.Table, - criteria_model, - ctx: BuildContext, - *, - count_column_name: str | None, - count_column_enum: CriteriaColumn | None, -) -> ir.Table: - if not count_column_name or not count_column_enum: - return events - source_getter = _COUNT_COLUMN_SOURCES.get(count_column_enum) - if source_getter is None: - return events - source_column = source_getter(criteria_model) - if source_column is None: - return events - table_name = criteria_model.snake_case_class_name() - try: - domain_table = ctx.table(table_name) - except ( - ibis_exc.IbisError, - TypeError, - ValueError, - AttributeError, - NotImplementedError, - ): - return events - if source_column not in domain_table.columns: - return events - primary_key = criteria_model.get_primary_key_column() - if primary_key not in domain_table.columns: - return events - lookup = domain_table.select( - domain_table[primary_key].name("_corr_join_key"), - domain_table[source_column].name(count_column_name), - ) - augmented = events.join(lookup, events.event_id == lookup._corr_join_key, how="left") - base_columns = events.columns - projection = [augmented[name] for name in base_columns if name in augmented.columns] - projection.append(augmented[count_column_name]) - return augmented.select(*projection) - - -def _requires_observation_period_end_alignment(correlated: CorrelatedCriteria) -> bool: - if correlated.start_window and correlated.start_window.use_event_end: - return True - if correlated.end_window and correlated.end_window.use_event_end: - return True - occurrence = correlated.occurrence - if occurrence and occurrence.count_column is not None: - resolved, _ = 
_resolve_count_column(occurrence) - return resolved == "_corr_end_date" - return False diff --git a/circe/execution/builders/measurement.py b/circe/execution/builders/measurement.py deleted file mode 100644 index 2659b519..00000000 --- a/circe/execution/builders/measurement.py +++ /dev/null @@ -1,200 +0,0 @@ -from __future__ import annotations - -import ibis - -from ...cohortdefinition.criteria import Measurement -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_codeset_filter, - apply_concept_criteria, - apply_date_range, - apply_first_event, - apply_gender_filter, - apply_numeric_range, - apply_provider_specialty_filter, - apply_visit_concept_filters, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("Measurement") -def build_measurement(criteria: Measurement, ctx: BuildContext): - table = ctx.table("measurement") - concept_column = criteria.get_concept_id_column() - table = apply_codeset_filter(table, concept_column, criteria.codeset_id, ctx) - if criteria.first: - table = apply_first_event(table, criteria.get_start_date_column(), criteria.get_primary_key_column()) - - table = apply_date_range(table, criteria.get_start_date_column(), criteria.occurrence_start_date) - table = apply_date_range(table, criteria.get_end_date_column(), criteria.occurrence_end_date) - - table = apply_concept_criteria( - table, - column="measurement_type_concept_id", - concepts=criteria.measurement_type, - selection=criteria.measurement_type_cs, - ctx=ctx, - exclude=bool(criteria.measurement_type_exclude), - ) - - table = apply_concept_criteria( - table, - column="operator_concept_id", - concepts=getattr(criteria, "operator_concept", None), - selection=getattr(criteria, "operator_concept_cs", None), - ctx=ctx, - ) - - value_column = "value_as_number" - if criteria.unit: - table = apply_concept_criteria( - table, - column="unit_concept_id", - concepts=criteria.unit, - 
selection=None, - ctx=ctx, - ) - table, value_column = _maybe_normalize_units(table, criteria.unit, criteria.value_as_number) - table = apply_concept_criteria( - table, - column="unit_concept_id", - concepts=None, - selection=criteria.unit_cs, - ctx=ctx, - ) - - table = apply_concept_criteria( - table, - column="value_as_concept_id", - concepts=criteria.value_as_concept, - selection=criteria.value_as_concept_cs, - ctx=ctx, - ) - - table = apply_numeric_range(table, value_column, criteria.value_as_number) - table = apply_numeric_range(table, "range_low", criteria.range_low) - table = apply_numeric_range(table, "range_high", criteria.range_high) - if getattr(criteria, "range_low_ratio", None): - denom = ibis.ifelse(table.range_low == 0, ibis.null(), table.range_low) - ratio = (table.value_as_number / denom).name("_range_low_ratio") - table = table.mutate(_range_low_ratio=ratio) - table = apply_numeric_range(table, "_range_low_ratio", criteria.range_low_ratio) - if getattr(criteria, "range_high_ratio", None): - denom = ibis.ifelse(table.range_high == 0, ibis.null(), table.range_high) - ratio = (table.value_as_number / denom).name("_range_high_ratio") - table = table.mutate(_range_high_ratio=ratio) - table = apply_numeric_range(table, "_range_high_ratio", criteria.range_high_ratio) - - if getattr(criteria, "abnormal", None): - abnormal_predicate = ( - (table.value_as_number < table.range_low) - | (table.value_as_number > table.range_high) - | table.value_as_concept_id.isin([4155142, 4155143]) - ) - table = table.filter(abnormal_predicate) - - if criteria.age: - table = apply_age_filter(table, criteria.age, ctx, criteria.get_start_date_column()) - table = apply_gender_filter(table, criteria.gender, criteria.gender_cs, ctx) - table = apply_provider_specialty_filter( - table, - getattr(criteria, "provider_specialty", None), - getattr(criteria, "provider_specialty_cs", None), - ctx, - provider_column="provider_id", - ) - table = apply_visit_concept_filters(table, 
criteria.visit_type, criteria.visit_type_cs, ctx) - if criteria.measurement_source_concept is not None: - table = apply_codeset_filter( - table, - "measurement_source_concept_id", - criteria.measurement_source_concept, - ctx, - ) - - events = standardize_output( - table, - primary_key=criteria.get_primary_key_column(), - start_column=criteria.get_start_date_column(), - end_column=criteria.get_end_date_column(), - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) - - -def _maybe_normalize_units(table, units, value_range): - """ - Best-effort unit normalization for numeric comparisons. - - Circe generally relies on unit-specific criteria rows (separate thresholds per unit scale). - Normalizing in that situation breaks parity (e.g. neutrophil counts expressed as 10..1500 cells/uL). - - Strategy: - - Always normalize mass to kilograms (pounds -> kg). - - For cell counts, only normalize when the numeric range appears to be in the canonical 10^9/L scale. - Heuristic: upper bound <= 100. 
- """ - unit_ids = [concept.concept_id for concept in units if concept.concept_id is not None] - if not unit_ids: - return table, "value_as_number" - if not all(unit_id in _UNIT_NORMALIZATION for unit_id in unit_ids): - return table, "value_as_number" - groups = {_UNIT_NORMALIZATION[unit_id][0] for unit_id in unit_ids} - if len(groups) != 1: - return table, "value_as_number" - - group = next(iter(groups)) - if group == "mass_kg": - should_normalize = True - elif group == "count_10e9_per_l": - should_normalize = _range_looks_like_canonical_cell_count(value_range) - else: - should_normalize = False - - if not should_normalize: - return table, "value_as_number" - - multiplier = _unit_multiplier_expr(table.unit_concept_id, unit_ids) - normalized = (table.value_as_number * multiplier).name("_normalized_value") - table = table.mutate(_normalized_value=normalized) - return table, "_normalized_value" - - -def _range_looks_like_canonical_cell_count(value_range) -> bool: - if value_range is None or value_range.value is None: - return False - op = (value_range.op or "eq").lower() - upper = float(value_range.value) - if op.endswith("bt") and value_range.extent is not None: - upper = max(upper, float(value_range.extent)) - # Canonical 10^9/L scale is typically << 100; high thresholds indicate raw unit ranges. 
- return upper <= 100.0 - - -def _unit_multiplier_expr(unit_column, unit_ids): - multiplier_expr = ibis.literal(1.0) - for unit_id in unit_ids: - multiplier = _UNIT_NORMALIZATION[unit_id][1] - multiplier_expr = ibis.ifelse( - unit_column == ibis.literal(unit_id), - ibis.literal(multiplier), - multiplier_expr, - ) - return multiplier_expr - - -_UNIT_NORMALIZATION = { - # Mass - 9529: ("mass_kg", 1.0), # kilogram - 3195625: ("mass_kg", 0.45359237), # pound - # Cell counts per liter (expressed in 10^9/L) - 9444: ("count_10e9_per_l", 1.0), # billion per liter - 44777588: ("count_10e9_per_l", 1.0), - 8848: ("count_10e9_per_l", 1.0), # thousand per microliter - 8816: ("count_10e9_per_l", 1.0), # million per milliliter - 8961: ("count_10e9_per_l", 1.0), # thousand per cubic millimeter - 8784: ("count_10e9_per_l", 0.001), # cells per microliter - 8647: ("count_10e9_per_l", 0.001), # per microliter -} diff --git a/circe/execution/builders/observation.py b/circe/execution/builders/observation.py deleted file mode 100644 index 93dff5bc..00000000 --- a/circe/execution/builders/observation.py +++ /dev/null @@ -1,94 +0,0 @@ -from __future__ import annotations - -from ...cohortdefinition.criteria import Observation -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_codeset_filter, - apply_concept_criteria, - apply_date_range, - apply_first_event, - apply_gender_filter, - apply_numeric_range, - apply_provider_specialty_filter, - apply_text_filter, - apply_visit_concept_filters, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("Observation") -def build_observation(criteria: Observation, ctx: BuildContext): - table = ctx.table("observation") - table = apply_codeset_filter(table, criteria.get_concept_id_column(), criteria.codeset_id, ctx) - - table = apply_date_range(table, criteria.get_start_date_column(), criteria.occurrence_start_date) - table = apply_date_range(table, 
criteria.get_end_date_column(), criteria.occurrence_end_date) - - table = apply_concept_criteria( - table, - column="observation_type_concept_id", - concepts=criteria.observation_type, - selection=criteria.observation_type_cs, - ctx=ctx, - exclude=bool(criteria.observation_type_exclude), - ) - - table = apply_concept_criteria( - table, - column="qualifier_concept_id", - concepts=criteria.qualifier, - selection=criteria.qualifier_cs, - ctx=ctx, - ) - - table = apply_concept_criteria( - table, - column="unit_concept_id", - concepts=criteria.unit, - selection=criteria.unit_cs, - ctx=ctx, - ) - - table = apply_concept_criteria( - table, - column="value_as_concept_id", - concepts=criteria.value_as_concept, - selection=criteria.value_as_concept_cs, - ctx=ctx, - ) - - table = apply_numeric_range(table, "value_as_number", criteria.value_as_number) - table = apply_text_filter(table, "value_as_string", criteria.value_as_string) - - if criteria.age: - table = apply_age_filter(table, criteria.age, ctx, criteria.get_start_date_column()) - table = apply_gender_filter(table, criteria.gender, criteria.gender_cs, ctx) - table = apply_provider_specialty_filter( - table, - getattr(criteria, "provider_specialty", None), - getattr(criteria, "provider_specialty_cs", None), - ctx, - provider_column="provider_id", - ) - table = apply_visit_concept_filters(table, criteria.visit_type, criteria.visit_type_cs, ctx) - if criteria.observation_source_concept is not None: - table = apply_codeset_filter( - table, - "observation_source_concept_id", - criteria.observation_source_concept, - ctx, - ) - - if criteria.first: - table = apply_first_event(table, criteria.get_start_date_column(), criteria.get_primary_key_column()) - - events = standardize_output( - table, - primary_key=criteria.get_primary_key_column(), - start_column=criteria.get_start_date_column(), - end_column=criteria.get_end_date_column(), - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) diff --git 
a/circe/execution/builders/observation_period.py b/circe/execution/builders/observation_period.py deleted file mode 100644 index e5e396f2..00000000 --- a/circe/execution/builders/observation_period.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -from ...cohortdefinition.criteria import ObservationPeriod -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_concept_criteria, - apply_date_range, - apply_first_event, - apply_interval_range, - apply_user_defined_period, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("ObservationPeriod") -def build_observation_period(criteria: ObservationPeriod, ctx: BuildContext): - table = ctx.table("observation_period") - - table = apply_date_range(table, "observation_period_start_date", criteria.period_start_date) - table = apply_date_range(table, "observation_period_end_date", criteria.period_end_date) - - table = apply_concept_criteria( - table, - column="period_type_concept_id", - concepts=criteria.period_type, - selection=criteria.period_type_cs, - ctx=ctx, - ) - - table = apply_interval_range( - table, - "observation_period_start_date", - "observation_period_end_date", - criteria.period_length, - ) - - if criteria.age_at_start: - table = apply_age_filter(table, criteria.age_at_start, ctx, "observation_period_start_date") - if criteria.age_at_end: - table = apply_age_filter(table, criteria.age_at_end, ctx, "observation_period_end_date") - - table, start_column, end_column = apply_user_defined_period( - table, - "observation_period_start_date", - "observation_period_end_date", - criteria.user_defined_period, - ) - - if criteria.first: - table = apply_first_event(table, start_column, "observation_period_id") - - events = standardize_output( - table, - primary_key="observation_period_id", - start_column=start_column, - end_column=end_column, - ) - return apply_criteria_group(events, 
criteria.correlated_criteria, ctx) diff --git a/circe/execution/builders/payer_plan_period.py b/circe/execution/builders/payer_plan_period.py deleted file mode 100644 index 766160fd..00000000 --- a/circe/execution/builders/payer_plan_period.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -from ...cohortdefinition.criteria import PayerPlanPeriod -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_codeset_filter, - apply_date_range, - apply_first_event, - apply_gender_filter, - apply_interval_range, - apply_user_defined_period, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("PayerPlanPeriod") -def build_payer_plan_period(criteria: PayerPlanPeriod, ctx: BuildContext): - table = ctx.table("payer_plan_period") - - table = apply_date_range(table, "payer_plan_period_start_date", criteria.period_start_date) - table = apply_date_range(table, "payer_plan_period_end_date", criteria.period_end_date) - - table = apply_interval_range( - table, - "payer_plan_period_start_date", - "payer_plan_period_end_date", - criteria.period_length, - ) - - if criteria.age_at_start: - table = apply_age_filter(table, criteria.age_at_start, ctx, "payer_plan_period_start_date") - if criteria.age_at_end: - table = apply_age_filter(table, criteria.age_at_end, ctx, "payer_plan_period_end_date") - - table = apply_gender_filter(table, criteria.gender, criteria.gender_cs, ctx) - - table = apply_codeset_filter(table, "payer_concept_id", criteria.payer_concept, ctx) - table = apply_codeset_filter(table, "plan_concept_id", criteria.plan_concept, ctx) - table = apply_codeset_filter(table, "sponsor_concept_id", criteria.sponsor_concept, ctx) - table = apply_codeset_filter(table, "stop_reason_concept_id", criteria.stop_reason_concept, ctx) - table = apply_codeset_filter(table, "payer_source_concept_id", criteria.payer_source_concept, ctx) - table = apply_codeset_filter(table, 
"plan_source_concept_id", criteria.plan_source_concept, ctx) - table = apply_codeset_filter(table, "sponsor_source_concept_id", criteria.sponsor_source_concept, ctx) - table = apply_codeset_filter( - table, "stop_reason_source_concept_id", criteria.stop_reason_source_concept, ctx - ) - - table, start_column, end_column = apply_user_defined_period( - table, - "payer_plan_period_start_date", - "payer_plan_period_end_date", - criteria.user_defined_period, - ) - - if criteria.first: - table = apply_first_event(table, start_column, "payer_plan_period_id") - - events = standardize_output( - table, - primary_key="payer_plan_period_id", - start_column=start_column, - end_column=end_column, - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) diff --git a/circe/execution/builders/pipeline.py b/circe/execution/builders/pipeline.py deleted file mode 100644 index 8b0b3b4d..00000000 --- a/circe/execution/builders/pipeline.py +++ /dev/null @@ -1,168 +0,0 @@ -from __future__ import annotations - -import ibis -import ibis.common.exceptions as ibis_exc -import ibis.expr.types as ir -import polars as pl - -from ...cohortdefinition import CohortExpression -from ..build_context import BuildContext -from .common import ( - apply_end_strategy, - apply_observation_window, - collapse_events, - has_end_strategy, -) -from .groups import apply_criteria_group -from .post_processing import apply_censor_window, apply_censoring, apply_inclusion_rules -from .registry import build_events - -OUTPUT_SCHEMA = { - "person_id": pl.Int64, - "event_id": pl.Int64, - "start_date": pl.Datetime, - "end_date": pl.Datetime, - "visit_occurrence_id": pl.Int64, -} - - -def build_primary_events(expression: CohortExpression, ctx: BuildContext): - def _maybe_materialize(table: ir.Table, label: str) -> ir.Table: - return ctx.maybe_materialize(table, label=label, analyze=True) - - primary = expression.primary_criteria - if primary is None or not primary.criteria_list: - return None - 
event_tables: list[ir.Table] = [] - for criteria in primary.criteria_list: - table = build_events(criteria, ctx) - if table is None: - continue - event_tables.append(table) - if not event_tables: - return None - if ctx.should_materialize_stages(): - materialized: list[ir.Table] = [] - for idx, table in enumerate(event_tables, start=1): - materialized.append(ctx.maybe_materialize(table, label=f"primary_src_{idx}", analyze=True)) - event_tables = materialized - events = event_tables[0] - for table in event_tables[1:]: - events = events.union(table, distinct=False) - events = events.mutate(_source_event_id=events.event_id) - events = apply_observation_window(events, primary.observation_window, ctx) - events = _assign_primary_event_ids(events) - if _should_limit(primary.primary_limit): - events = _apply_result_limit(events, primary.primary_limit) - - events = ctx.maybe_materialize(events, label="primary_events", analyze=True) - - # Short-circuit the remainder of the pipeline when no primary events exist. - if ctx.should_materialize_stages(): - try: - primary_count = events.count().execute() - except (ibis_exc.IbisError, RuntimeError, ValueError, TypeError): - primary_count = None - if primary_count == 0: - events = _drop_aux_columns(events) - return events.limit(0) - - events = apply_criteria_group(events, expression.additional_criteria, ctx) - if expression.additional_criteria: - events = ctx.maybe_materialize(events, label="additional_criteria", analyze=True) - - events = apply_inclusion_rules(events, expression.inclusion_rules, ctx) - if expression.inclusion_rules: - events = ctx.maybe_materialize(events, label="inclusion", analyze=True) - # Circe ignores QualifiedLimit, so we do the same to preserve parity. 
- if _should_limit(expression.expression_limit): - events = _apply_result_limit(events, expression.expression_limit) - events = apply_end_strategy(events, expression.end_strategy, ctx) - if has_end_strategy(expression.end_strategy): - events = _maybe_materialize(events, label="strategy_ends") - - # Censoring should cut the cohort end date, so apply it after end strategy. - events = apply_censoring(events, expression.censoring_criteria, ctx) - if expression.censoring_criteria: - events = ctx.maybe_materialize(events, label="censoring", analyze=True) - events = apply_censor_window(events, expression.censor_window, ctx) - events = _drop_aux_columns(events) - events = collapse_events(events, expression.collapse_settings) - if expression.collapse_settings and expression.collapse_settings.collapse_type: - events = _maybe_materialize(events, label="final_cohort") - return events - - -def build_primary_events_polars(expression: CohortExpression, ctx: BuildContext) -> pl.DataFrame: - events = build_primary_events(expression, ctx) - if events is None: - return pl.DataFrame(schema=OUTPUT_SCHEMA) - return events.to_polars() - - -def _assign_primary_event_ids(events): - if "_source_event_id" not in events.columns: - events = events.mutate(_source_event_id=events.event_id) - order = [events.person_id, events.start_date, events._source_event_id] - person_window = ibis.window(group_by=events.person_id, order_by=order[1:]) - person_rank = ibis.row_number().over(person_window) - events = events.mutate( - # Keep event ids unique *within* a person to avoid global sorts/shuffles. - # Most downstream logic keys by (person_id, event_id). 
- event_id=(person_rank + 1), - _person_ordinal=(person_rank + 1), - ) - supplemental = [ - events[column] - for column in ("observation_period_start_date", "observation_period_end_date") - if column in events.columns - ] - return events.select( - events.person_id, - events.event_id, - events.start_date, - events.end_date, - events.visit_occurrence_id, - events._source_event_id, - events._person_ordinal, - *supplemental, - ) - - -def _apply_result_limit(events: ir.Table, limit) -> ir.Table: - if not limit or (limit.type or "ALL").lower() == "all": - return events - - order_by = [events.start_date] - if "event_id" in events.columns: - order_by.append(events.event_id) - - w = ibis.window(group_by=events.person_id, order_by=order_by) - - helper = "__mitos_rn__" - - ranked = events.mutate(**{helper: ibis.row_number().over(w)}) - limited = ranked.filter(ranked[helper] == 0) - - return limited.select([limited[c] for c in events.columns]) - - -def _drop_aux_columns(events: ir.Table) -> ir.Table: - drop_cols = [ - col - for col in ( - "_source_event_id", - "_person_ordinal", - "observation_period_start_date", - "observation_period_end_date", - "_result_row", - ) - if col in events.columns - ] - if drop_cols: - events = events.drop(*drop_cols) - return events - - -def _should_limit(limit) -> bool: - return bool(limit and (limit.type or "all").lower() != "all") diff --git a/circe/execution/builders/post_processing.py b/circe/execution/builders/post_processing.py deleted file mode 100644 index d95bbc0f..00000000 --- a/circe/execution/builders/post_processing.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import annotations - -import ibis -import ibis.expr.types as ir - -from ...cohortdefinition.criteria import Criteria, InclusionRule -from ..build_context import BuildContext -from .groups import apply_criteria_group -from .registry import build_events - - -def apply_additional_criteria(events: ir.Table, group, ctx: BuildContext) -> ir.Table: - return 
apply_criteria_group(events, group, ctx) - - -def apply_inclusion_rules(events: ir.Table, rules: list[InclusionRule], ctx: BuildContext) -> ir.Table: - if not rules: - return events - - base_events = events.select(events.person_id, events.event_id) - bit_hits = [] - used_bits: list[int] = [] - for idx, rule in enumerate(rules): - rule_events = apply_criteria_group(events, rule.expression, ctx) - if rule_events is None: - continue - bit_value = 1 << idx - used_bits.append(bit_value) - bit_hits.append( - rule_events.select( - rule_events.person_id, - rule_events.event_id, - ibis.literal(bit_value, type="int64").name("_rule_bit"), - ).distinct() - ) - if not bit_hits: - return events - - union_hits = bit_hits[0] - for table in bit_hits[1:]: - union_hits = union_hits.union(table, distinct=False) - - union_hits = ctx.maybe_materialize(union_hits, label="inclusion_hits", analyze=True) - - mask = union_hits.group_by(union_hits.person_id, union_hits.event_id).aggregate( - # Postgres returns NUMERIC for SUM(BIGINT), which breaks bitwise ops. - # Ibis also infers SUM(int64) -> int64 and may optimize away an int64 cast, - # so we force an intermediate cast to keep the SQL-level cast. 
- _rule_mask=union_hits._rule_bit.sum().cast("decimal(38,0)").cast("int64") - ) - target_mask = sum(used_bits) - target_literal = ibis.literal(target_mask, type="int64") - mask = mask.filter((mask._rule_mask & target_literal) == target_literal) - - filtered_ids = base_events.inner_join(mask, ["person_id", "event_id"]) - return events.inner_join(filtered_ids, ["person_id", "event_id"]).select(events.columns) - - -def apply_censoring(events: ir.Table, criteria_list: list[Criteria], ctx: BuildContext) -> ir.Table: - if not criteria_list: - return events - censor_tables = [build_events(criteria, ctx) for criteria in criteria_list if criteria] - if not censor_tables: - return events - censor_events = censor_tables[0] - for table in censor_tables[1:]: - censor_events = censor_events.union(table) - - censor_events = censor_events.select( - censor_events.person_id, - censor_events.start_date.name("censor_start"), - ) - joined = events.join( - censor_events, - (events.person_id == censor_events.person_id) & (censor_events.censor_start >= events.start_date), - how="left", - ) - min_censor = joined.group_by(joined.person_id, joined.event_id).aggregate( - censor_date=joined.censor_start.min() - ) - event_columns = events.columns - events = events.left_join( - min_censor, - (events.person_id == min_censor.person_id) & (events.event_id == min_censor.event_id), - ) - events = events.select(*event_columns, min_censor.censor_date) - events = events.mutate( - end_date=ibis.ifelse( - events.censor_date.notnull() & (events.censor_date < events.end_date), - events.censor_date, - events.end_date, - ) - ).select(*event_columns) - return events - - -def apply_censor_window(events: ir.Table, window, ctx: BuildContext) -> ir.Table: - if not window: - return events - if window.start_date: - events = events.filter(events.start_date >= ibis.timestamp(window.start_date)) - if window.end_date: - events = events.filter(events.end_date <= ibis.timestamp(window.end_date)) - return events diff --git 
a/circe/execution/builders/procedure_occurrence.py b/circe/execution/builders/procedure_occurrence.py deleted file mode 100644 index f09c1d92..00000000 --- a/circe/execution/builders/procedure_occurrence.py +++ /dev/null @@ -1,75 +0,0 @@ -from __future__ import annotations - -from ...cohortdefinition.criteria import ProcedureOccurrence -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_codeset_filter, - apply_concept_criteria, - apply_date_range, - apply_first_event, - apply_gender_filter, - apply_numeric_range, - apply_provider_specialty_filter, - apply_visit_concept_filters, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("ProcedureOccurrence") -def build_procedure_occurrence(criteria: ProcedureOccurrence, ctx: BuildContext): - table = ctx.table("procedure_occurrence") - - concept_column = criteria.get_concept_id_column() - table = apply_codeset_filter(table, concept_column, criteria.codeset_id, ctx) - if criteria.first: - table = apply_first_event(table, criteria.get_start_date_column(), criteria.get_primary_key_column()) - - table = apply_date_range(table, criteria.get_start_date_column(), criteria.occurrence_start_date) - table = apply_date_range(table, criteria.get_end_date_column(), criteria.occurrence_end_date) - - table = apply_concept_criteria( - table, - column="procedure_type_concept_id", - concepts=criteria.procedure_type, - selection=criteria.procedure_type_cs, - ctx=ctx, - exclude=bool(criteria.procedure_type_exclude), - ) - - table = apply_concept_criteria( - table, - column="modifier_concept_id", - concepts=criteria.modifier, - selection=criteria.modifier_cs, - ctx=ctx, - ) - - table = apply_numeric_range(table, "quantity", criteria.quantity) - - if criteria.age: - table = apply_age_filter(table, criteria.age, ctx, criteria.get_start_date_column()) - table = apply_gender_filter(table, criteria.gender, criteria.gender_cs, ctx) - table = 
apply_provider_specialty_filter( - table, - getattr(criteria, "provider_specialty", None), - getattr(criteria, "provider_specialty_cs", None), - ctx, - provider_column="provider_id", - ) - table = apply_visit_concept_filters(table, criteria.visit_type, criteria.visit_type_cs, ctx) - - if criteria.procedure_source_concept is not None: - table = apply_codeset_filter( - table, "procedure_source_concept_id", criteria.procedure_source_concept, ctx - ) - - events = standardize_output( - table, - primary_key=criteria.get_primary_key_column(), - start_column=criteria.get_start_date_column(), - end_column=criteria.get_end_date_column(), - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) diff --git a/circe/execution/builders/registry.py b/circe/execution/builders/registry.py deleted file mode 100644 index 68a0a633..00000000 --- a/circe/execution/builders/registry.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -import hashlib -from collections.abc import Callable - -import ibis.expr.types as ir - -from ...cohortdefinition.criteria import Criteria -from ..build_context import BuildContext - -_REGISTRY: dict[str, Callable[[Criteria, BuildContext], ir.Table]] = {} - - -def register(criteria_name: str): - def decorator(func: Callable[[Criteria, BuildContext], ir.Table]): - _REGISTRY[criteria_name] = func - return func - - return decorator - - -def get_builder(criteria: Criteria): - name = criteria.__class__.__name__ - try: - return _REGISTRY[name] - except KeyError as exc: - raise ValueError(f"No builder registered for criteria {name}") from exc - - -def build_events(criteria: Criteria, ctx: BuildContext) -> ir.Table: - builder = get_builder(criteria) - table = builder(criteria, ctx) - cache_key, label = _criteria_cache_key(criteria) - return ctx.get_or_materialize_slice(cache_key, table, label=label) - - -def _criteria_cache_key(criteria: Criteria) -> tuple[str, str]: - payload = criteria.model_dump_json( - by_alias=True, - 
exclude_defaults=False, - exclude_none=False, - ) - raw_key = f"{criteria.__class__.__name__}:{payload}" - digest = hashlib.sha1(raw_key.encode("utf-8")).hexdigest()[:8] - label = f"{criteria.__class__.__name__.lower()}_{digest}" - return raw_key, label diff --git a/circe/execution/builders/specimen.py b/circe/execution/builders/specimen.py deleted file mode 100644 index 4bb7c9bc..00000000 --- a/circe/execution/builders/specimen.py +++ /dev/null @@ -1,84 +0,0 @@ -from __future__ import annotations - -from ...cohortdefinition.criteria import Specimen -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_codeset_filter, - apply_concept_criteria, - apply_date_range, - apply_first_event, - apply_gender_filter, - apply_numeric_range, - apply_text_filter, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("Specimen") -def build_specimen(criteria: Specimen, ctx: BuildContext): - table = ctx.table("specimen") - - table = apply_codeset_filter(table, "specimen_concept_id", criteria.codeset_id, ctx) - table = apply_date_range(table, "specimen_date", criteria.occurrence_start_date) - - table = apply_concept_criteria( - table, - column="specimen_type_concept_id", - concepts=criteria.specimen_type, - selection=criteria.specimen_type_cs, - ctx=ctx, - exclude=bool(criteria.specimen_type_exclude), - ) - - table = apply_numeric_range(table, "quantity", criteria.quantity) - - table = apply_concept_criteria( - table, - column="unit_concept_id", - concepts=criteria.unit, - selection=criteria.unit_cs, - ctx=ctx, - ) - - table = apply_concept_criteria( - table, - column="anatomic_site_concept_id", - concepts=criteria.anatomic_site, - selection=criteria.anatomic_site_cs, - ctx=ctx, - ) - - table = apply_concept_criteria( - table, - column="disease_status_concept_id", - concepts=criteria.disease_status, - selection=criteria.disease_status_cs, - ctx=ctx, - ) - - table = 
apply_text_filter(table, "specimen_source_id", criteria.source_id) - if criteria.specimen_source_concept is not None: - table = apply_codeset_filter( - table, - "specimen_source_concept_id", - criteria.specimen_source_concept, - ctx, - ) - - if criteria.age: - table = apply_age_filter(table, criteria.age, ctx, "specimen_date") - table = apply_gender_filter(table, criteria.gender, criteria.gender_cs, ctx) - - if criteria.first: - table = apply_first_event(table, "specimen_date", "specimen_id") - - events = standardize_output( - table, - primary_key="specimen_id", - start_column="specimen_date", - end_column="specimen_date", - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) diff --git a/circe/execution/builders/visit_detail.py b/circe/execution/builders/visit_detail.py deleted file mode 100644 index 5e8b075a..00000000 --- a/circe/execution/builders/visit_detail.py +++ /dev/null @@ -1,82 +0,0 @@ -from __future__ import annotations - -from ...cohortdefinition.criteria import VisitDetail -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_care_site_filter, - apply_codeset_filter, - apply_concept_set_selection, - apply_date_range, - apply_first_event, - apply_gender_filter, - apply_interval_range, - apply_location_region_filter, - apply_provider_specialty_filter, - project_event_columns, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("VisitDetail") -def build_visit_detail(criteria: VisitDetail, ctx: BuildContext): - table = ctx.table("visit_detail") - - table = apply_codeset_filter(table, "visit_detail_concept_id", criteria.codeset_id, ctx) - if criteria.first: - table = apply_first_event(table, "visit_detail_start_date", "visit_detail_id") - table = apply_date_range(table, "visit_detail_start_date", criteria.visit_detail_start_date) - table = apply_date_range(table, "visit_detail_end_date", criteria.visit_detail_end_date) - table = 
apply_concept_set_selection( - table, "visit_detail_type_concept_id", criteria.visit_detail_type_cs, ctx - ) - if criteria.visit_detail_source_concept is not None: - table = apply_codeset_filter( - table, - "visit_detail_source_concept_id", - criteria.visit_detail_source_concept, - ctx, - ) - table = apply_interval_range( - table, - "visit_detail_start_date", - "visit_detail_end_date", - criteria.visit_detail_length, - ) - - if criteria.age: - table = apply_age_filter(table, criteria.age, ctx, "visit_detail_end_date") - table = apply_gender_filter(table, [], criteria.gender_cs, ctx) - table = apply_provider_specialty_filter( - table, - None, - criteria.provider_specialty_cs, - ctx, - ) - table = apply_care_site_filter(table, criteria.place_of_service_cs, ctx) - table = apply_location_region_filter( - table, - care_site_column="care_site_id", - location_codeset_id=criteria.place_of_service_location, - start_column="visit_detail_start_date", - end_column="visit_detail_end_date", - ctx=ctx, - ) - - table = project_event_columns( - table, - primary_key="visit_detail_id", - start_column="visit_detail_start_date", - end_column="visit_detail_end_date", - include_visit_occurrence=True, - ) - - events = standardize_output( - table, - primary_key="visit_detail_id", - start_column="visit_detail_start_date", - end_column="visit_detail_end_date", - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) diff --git a/circe/execution/builders/visit_occurrence.py b/circe/execution/builders/visit_occurrence.py deleted file mode 100644 index 1f2e8eff..00000000 --- a/circe/execution/builders/visit_occurrence.py +++ /dev/null @@ -1,80 +0,0 @@ -from __future__ import annotations - -from ...cohortdefinition.criteria import VisitOccurrence -from ..build_context import BuildContext -from .common import ( - apply_age_filter, - apply_codeset_filter, - apply_concept_criteria, - apply_date_range, - apply_first_event, - apply_gender_filter, - apply_numeric_range, - 
apply_provider_specialty_filter, - project_event_columns, - standardize_output, -) -from .groups import apply_criteria_group -from .registry import register - - -@register("VisitOccurrence") -def build_visit_occurrence(criteria: VisitOccurrence, ctx: BuildContext): - table = ctx.table("visit_occurrence") - - concept_column = criteria.get_concept_id_column() - table = apply_codeset_filter(table, concept_column, criteria.codeset_id, ctx) - - table = apply_date_range(table, criteria.get_start_date_column(), criteria.occurrence_start_date) - table = apply_date_range(table, criteria.get_end_date_column(), criteria.occurrence_end_date) - - table = apply_concept_criteria( - table, - column="visit_type_concept_id", - concepts=criteria.visit_type, - selection=criteria.visit_type_cs, - ctx=ctx, - exclude=bool(criteria.visit_type_exclude), - ) - - table = apply_provider_specialty_filter( - table, - criteria.provider_specialty, - criteria.provider_specialty_cs, - ctx, - ) - table = apply_concept_criteria( - table, - column="place_of_service_concept_id", - concepts=criteria.place_of_service, - selection=criteria.place_of_service_cs, - ctx=ctx, - ) - if criteria.visit_length: - table = apply_numeric_range(table, "visit_length", criteria.visit_length) - - if criteria.age: - table = apply_age_filter(table, criteria.age, ctx, criteria.get_start_date_column()) - table = apply_gender_filter(table, criteria.gender, criteria.gender_cs, ctx) - - if criteria.visit_source_concept is not None: - table = apply_codeset_filter(table, "visit_source_concept_id", criteria.visit_source_concept, ctx) - - if criteria.first: - table = apply_first_event(table, criteria.get_start_date_column(), criteria.get_primary_key_column()) - - table = project_event_columns( - table, - primary_key=criteria.get_primary_key_column(), - start_column=criteria.get_start_date_column(), - end_column=criteria.get_end_date_column(), - include_visit_occurrence=True, - ) - - events = standardize_output( - table, - 
primary_key=criteria.get_primary_key_column(), - start_column=criteria.get_start_date_column(), - end_column=criteria.get_end_date_column(), - ) - return apply_criteria_group(events, criteria.correlated_criteria, ctx) diff --git a/circe/execution/criteria_compat.py b/circe/execution/criteria_compat.py deleted file mode 100644 index fb52f2b6..00000000 --- a/circe/execution/criteria_compat.py +++ /dev/null @@ -1,203 +0,0 @@ -from __future__ import annotations - -from enum import IntEnum -from typing import Any - -from ..cohortdefinition.criteria import ( - ConditionEra, - ConditionOccurrence, - CorelatedCriteria, - Criteria, - Death, - DemographicCriteria, - DeviceExposure, - DoseEra, - DrugEra, - DrugExposure, - Measurement, - Observation, - ObservationPeriod, - PayerPlanPeriod, - ProcedureOccurrence, - Specimen, - VisitDetail, - VisitOccurrence, -) - -CorrelatedCriteria = CorelatedCriteria -DemoGraphicCriteria = DemographicCriteria - - -class OccurrenceType(IntEnum): - EXACTLY = 0 - AT_MOST = 1 - AT_LEAST = 2 - - -_CONCEPT_ID_OVERRIDES: dict[str, str] = { - "Death": "cause_concept_id", - "DoseEra": "drug_concept_id", - "VisitDetail": "visit_detail_concept_id", -} - -_PRIMARY_KEY_OVERRIDES: dict[str, str] = { - "Death": "person_id", -} - -_START_DATE_OVERRIDES: dict[str, str] = { - "ConditionEra": "condition_era_start_date", - "DrugExposure": "drug_exposure_start_date", - "Measurement": "measurement_date", - "Observation": "observation_date", - "DeviceExposure": "device_exposure_start_date", - "ProcedureOccurrence": "procedure_date", - "DrugEra": "drug_era_start_date", - "DoseEra": "dose_era_start_date", - "ObservationPeriod": "observation_period_start_date", - "Specimen": "specimen_date", - "Death": "death_date", - "VisitDetail": "visit_detail_start_date", - "PayerPlanPeriod": "payer_plan_period_start_date", -} - -_END_DATE_OVERRIDES: dict[str, str] = { - "ConditionEra": "condition_era_end_date", - "DrugExposure": "drug_exposure_end_date", - "Measurement": 
"measurement_date", - "Observation": "observation_date", - "DeviceExposure": "device_exposure_end_date", - "ProcedureOccurrence": "procedure_date", - "DrugEra": "drug_era_end_date", - "DoseEra": "dose_era_end_date", - "ObservationPeriod": "observation_period_end_date", - "Specimen": "specimen_date", - "Death": "death_date", - "VisitDetail": "visit_detail_end_date", - "PayerPlanPeriod": "payer_plan_period_end_date", -} - - -def _to_snake_case(name: str) -> str: - output: list[str] = [] - for idx, char in enumerate(name): - if char.isupper() and idx > 0: - output.append("_") - output.append(char.lower()) - return "".join(output) - - -def _snake_case_class_name(cls: type[Criteria]) -> str: - return _to_snake_case(cls.__name__) - - -def _get_concept_id_column(self: Criteria) -> str: - cls_name = self.__class__.__name__ - overridden = _CONCEPT_ID_OVERRIDES.get(cls_name) - if overridden: - return overridden - table_name = self.snake_case_class_name() - return f"{table_name.split('_')[0]}_concept_id" - - -def _get_primary_key_column(self: Criteria) -> str: - cls_name = self.__class__.__name__ - overridden = _PRIMARY_KEY_OVERRIDES.get(cls_name) - if overridden: - return overridden - return f"{self.snake_case_class_name()}_id" - - -def _get_start_date_column(self: Criteria) -> str: - cls_name = self.__class__.__name__ - overridden = _START_DATE_OVERRIDES.get(cls_name) - if overridden: - return overridden - return f"{self.snake_case_class_name().split('_')[0]}_start_date" - - -def _get_end_date_column(self: Criteria) -> str: - cls_name = self.__class__.__name__ - overridden = _END_DATE_OVERRIDES.get(cls_name) - if overridden: - return overridden - return f"{self.snake_case_class_name().split('_')[0]}_end_date" - - -def ensure_criteria_compat() -> None: - if getattr(Criteria, "_execution_compat_patched", False): - return - - Criteria.snake_case_class_name = classmethod(_snake_case_class_name) - Criteria.get_concept_id_column = _get_concept_id_column - 
Criteria.get_primary_key_column = _get_primary_key_column - Criteria.get_start_date_column = _get_start_date_column - Criteria.get_end_date_column = _get_end_date_column - Criteria._execution_compat_patched = True - - -CRITERIA_TYPE_MAP: dict[str, type[Criteria]] = { - "ConditionOccurrence": ConditionOccurrence, - "ConditionEra": ConditionEra, - "VisitOccurrence": VisitOccurrence, - "DrugExposure": DrugExposure, - "DrugEra": DrugEra, - "DoseEra": DoseEra, - "ObservationPeriod": ObservationPeriod, - "Measurement": Measurement, - "Observation": Observation, - "Specimen": Specimen, - "DeviceExposure": DeviceExposure, - "ProcedureOccurrence": ProcedureOccurrence, - "Death": Death, - "VisitDetail": VisitDetail, - "PayerPlanPeriod": PayerPlanPeriod, -} -CRITERIA_TYPE_MAP_CASEFOLD: dict[str, type[Criteria]] = { - name.casefold(): model for name, model in CRITERIA_TYPE_MAP.items() -} - - -def parse_single_criteria(criteria_dict: Any) -> Criteria: - if isinstance(criteria_dict, Criteria): - return criteria_dict - - if not isinstance(criteria_dict, dict): - raise ValueError("Criteria wrapper must be an object.") - - if len(criteria_dict) != 1: - raise ValueError("Criteria wrapper must contain exactly one criteria type key.") - - criteria_type, criteria_data = next(iter(criteria_dict.items())) - model_cls = CRITERIA_TYPE_MAP.get(criteria_type) - if model_cls is None and isinstance(criteria_type, str): - model_cls = CRITERIA_TYPE_MAP_CASEFOLD.get(criteria_type.casefold()) - if model_cls is None: - raise ValueError(f"Unsupported criteria type: {criteria_type}") - - if criteria_data is None: - criteria_data = {} - - if not isinstance(criteria_data, dict): - raise ValueError(f"Criteria payload for {criteria_type} must be an object.") - - return model_cls.model_validate(criteria_data, strict=False) - - -def parse_criteria_list(criteria_list_data: Any) -> list[Criteria]: - if criteria_list_data is None: - return [] - - if not isinstance(criteria_list_data, list): - raise 
ValueError("Criteria list must be a list.") - - criteria_instances: list[Criteria] = [] - for idx, criteria_dict in enumerate(criteria_list_data): - try: - parsed = parse_single_criteria(criteria_dict) - except ValueError as exc: - raise ValueError(f"Invalid criteria wrapper at index {idx}: {exc}") from exc - criteria_instances.append(parsed) - return criteria_instances - - -ensure_criteria_compat() diff --git a/circe/execution/databricks_compat.py b/circe/execution/databricks_compat.py new file mode 100644 index 00000000..fa8375f7 --- /dev/null +++ b/circe/execution/databricks_compat.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import functools +import inspect +from collections.abc import Callable +from typing import Any + +ISSUE_REFERENCE = "https://github.com/ibis-project/ibis/issues/11598" +_PATCH_FLAG = "_circe_databricks_post_connect_patched" + + +def _databricks_backend_class() -> type[Any] | None: + try: + import ibis.backends.databricks as databricks_backend + except Exception: + return None + return getattr(databricks_backend, "Backend", None) + + +def _post_connect_needs_workaround(post_connect: Callable[..., Any]) -> bool: + try: + source = inspect.getsource(post_connect).lower() + except (OSError, TypeError): + return True + return "create volume if not exists" in source and "memtable" in source + + +def _is_memtable_volume_error(exc: Exception) -> bool: + message = str(exc).lower() + if "create volume if not exists" in message: + return True + return bool("memtable" in message and "volume" in message) + + +def _backend_looks_like_databricks(backend: object) -> bool: + backend_name = getattr(backend, "name", None) + if isinstance(backend_name, str) and backend_name.lower() == "databricks": + return True + class_name = backend.__class__.__name__.lower() + return "databricks" in class_name + + +def apply_databricks_post_connect_workaround( + *, + backend_cls: type[Any] | None = None, +) -> bool: + """ + Patch Databricks backend 
`_post_connect` for Ibis issue #11598. + + Some Ibis Databricks versions call `CREATE VOLUME IF NOT EXISTS ...` during + `_post_connect` for memtable support and can fail in read-only/locked-down + schemas. This workaround suppresses only that known failure mode and should + be removed once upstream behavior is fixed. + + Activation note: + This helper should be applied lazily by the execution path when a + Databricks backend is actually used. + """ + backend_cls = _databricks_backend_class() if backend_cls is None else backend_cls + if backend_cls is None: + return False + + post_connect = getattr(backend_cls, "_post_connect", None) + if not callable(post_connect): + return False + + if getattr(backend_cls, _PATCH_FLAG, False): + return True + + if not _post_connect_needs_workaround(post_connect): + return False + + @functools.wraps(post_connect) + def _patched_post_connect(self: Any, *args: Any, **kwargs: Any) -> Any: + try: + return post_connect(self, *args, **kwargs) + except Exception as exc: + if _is_memtable_volume_error(exc): + return None + raise + + backend_cls._post_connect = _patched_post_connect + setattr(backend_cls, _PATCH_FLAG, True) + return True + + +def maybe_apply_databricks_post_connect_workaround(backend: object) -> bool: + """Apply the workaround only for Databricks-like backends.""" + if not _backend_looks_like_databricks(backend): + return False + return apply_databricks_post_connect_workaround(backend_cls=backend.__class__) + + +__all__ = [ + "ISSUE_REFERENCE", + "apply_databricks_post_connect_workaround", + "maybe_apply_databricks_post_connect_workaround", +] diff --git a/circe/execution/engine/__init__.py b/circe/execution/engine/__init__.py new file mode 100644 index 00000000..dd3411cc --- /dev/null +++ b/circe/execution/engine/__init__.py @@ -0,0 +1,19 @@ +from .censoring import apply_censoring +from .cohort import build_cohort_table +from .collapse import collapse_events +from .end_strategy import apply_end_strategy +from .groups 
import apply_additional_criteria +from .inclusion import apply_inclusion_rules +from .limits import apply_result_limit +from .primary import build_primary_events + +__all__ = [ + "build_cohort_table", + "build_primary_events", + "apply_additional_criteria", + "apply_inclusion_rules", + "apply_end_strategy", + "apply_censoring", + "collapse_events", + "apply_result_limit", +] diff --git a/circe/execution/engine/censoring.py b/circe/execution/engine/censoring.py new file mode 100644 index 00000000..ce21f626 --- /dev/null +++ b/circe/execution/engine/censoring.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import ibis + +from ..ibis.compiler import compile_event_plan +from ..lower.criteria import lower_criterion +from ..plan.schema import END_DATE, PERSON_ID +from .end_strategy import attach_observation_bounds + + +def _union_all(tables): + current = tables[0] + for table in tables[1:]: + current = current.union(table, distinct=False) + return current + + +def _compile_censor_events(criteria, ctx): + compiled = [] + for index, criterion in enumerate(criteria): + plan = lower_criterion(criterion, criterion_index=10_000 + index) + table = compile_event_plan(plan, ctx) + compiled.append( + table.select( + table.person_id.cast("int64").name(PERSON_ID), + table.start_date.cast("date").name("censor_start_date"), + ) + ) + if not compiled: + return None + return _union_all(compiled) + + +def apply_censoring(events, criteria, window, ctx): + del window # Censor-window clipping is applied in collapse/finalization stage. 
+ + if not criteria: + return events + + censor_events = _compile_censor_events(criteria, ctx) + if censor_events is None: + return events + + with_bounds = attach_observation_bounds(events, ctx) + + joined = with_bounds.join( + censor_events, + predicates=[with_bounds.person_id == censor_events.person_id], + ) + valid = joined.filter( + (joined.censor_start_date >= joined.start_date) & (joined.censor_start_date <= joined.op_end_date) + ) + censor_min = valid.group_by(valid.person_id, valid.event_id).aggregate( + censor_end_date=valid.censor_start_date.min() + ) + + merged = with_bounds.left_join( + censor_min, + predicates=[ + (with_bounds.person_id == censor_min.person_id) & (with_bounds.event_id == censor_min.event_id) + ], + ) + + new_end = ibis.coalesce( + ibis.least(merged.end_date, merged.censor_end_date), + merged.end_date, + ) + projected = merged.mutate(_new_end_date=new_end) + + return projected.select( + *[ + projected[c] if c != END_DATE else projected._new_end_date.cast("date").name(END_DATE) + for c in events.columns + ] + ) diff --git a/circe/execution/engine/cohort.py b/circe/execution/engine/cohort.py new file mode 100644 index 00000000..7f34652a --- /dev/null +++ b/circe/execution/engine/cohort.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from ..ibis.context import ExecutionContext +from ..lower.criteria import lower_criterion +from ..normalize.cohort import NormalizedCohort +from ..plan.cohort import CohortPlan, PrimaryEventInput +from ..typing import Table +from .censoring import apply_censoring +from .collapse import collapse_events +from .end_strategy import apply_end_strategy +from .groups import apply_additional_criteria +from .inclusion import apply_inclusion_rules +from .limits import apply_result_limit +from .primary import build_primary_events + + +def build_cohort_table(normalized: NormalizedCohort, ctx: ExecutionContext) -> Table: + primary_plans = tuple( + PrimaryEventInput( + event_plan=lower_criterion(criterion, 
criterion_index=index), + correlated_criteria=criterion.correlated_criteria, + ) + for index, criterion in enumerate(normalized.primary.criteria) + ) + cohort_plan = CohortPlan( + primary_event_plans=primary_plans, + observation_window=normalized.primary.observation_window, + primary_limit_type=normalized.primary.primary_limit_type, + qualified_limit_type=normalized.result_limits.qualified_limit_type, + expression_limit_type=normalized.result_limits.expression_limit_type, + ) + primary_events = build_primary_events(cohort_plan, ctx) + qualified_events = apply_additional_criteria(primary_events, normalized.additional_criteria, ctx) + if normalized.additional_criteria is not None and not normalized.additional_criteria.is_empty(): + qualified_events = apply_result_limit( + qualified_events, + cohort_plan.qualified_limit_type, + ) + included_events = apply_inclusion_rules(qualified_events, normalized.inclusion_rules, ctx) + included_events = apply_result_limit( + included_events, + cohort_plan.expression_limit_type, + ) + ended_events = apply_end_strategy(included_events, normalized.end_strategy, ctx) + censored_events = apply_censoring( + ended_events, + normalized.censoring_criteria, + normalized.censor_window, + ctx, + ) + return collapse_events( + censored_events, + normalized.collapse_settings, + normalized.censor_window, + ) diff --git a/circe/execution/engine/collapse.py b/circe/execution/engine/collapse.py new file mode 100644 index 00000000..c98619d2 --- /dev/null +++ b/circe/execution/engine/collapse.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import ibis + +from ..plan.schema import END_DATE, PERSON_ID, START_DATE + + +def _apply_censor_window(events, censor_window): + if censor_window is None: + return events + + start_expr = events.start_date + end_expr = events.end_date + + if censor_window.start_date: + start_bound = ibis.literal(censor_window.start_date).cast("date") + start_expr = ibis.greatest(events.start_date, start_bound) + + if 
censor_window.end_date: + end_bound = ibis.literal(censor_window.end_date).cast("date") + end_expr = ibis.least(events.end_date, end_bound) + + clipped = events.mutate(start_date=start_expr, end_date=end_expr) + return clipped.filter(clipped.start_date <= clipped.end_date) + + +def _collapse_era(intervals, era_pad: int): + padded = intervals.mutate(_padded_end_date=(intervals.end_date + ibis.interval(days=int(era_pad)))) + + ordering = [padded.start_date] + ordered_window = ibis.window(group_by=padded.person_id, order_by=ordering) + cumulative_window = ibis.cumulative_window(group_by=padded.person_id, order_by=ordering) + with_cummax = padded.mutate(_cummax_padded_end=padded._padded_end_date.max().over(cumulative_window)) + with_prev = with_cummax.mutate( + _prev_max_padded_end=with_cummax._cummax_padded_end.lag().over(ordered_window) + ) + marked = with_prev.mutate( + _is_new_group=ibis.ifelse( + with_prev._prev_max_padded_end.isnull() | (with_prev._prev_max_padded_end < with_prev.start_date), + ibis.literal(1, type="int64"), + ibis.literal(0, type="int64"), + ) + ) + + grouping_window = ibis.cumulative_window( + group_by=marked.person_id, + order_by=[marked.start_date, marked._is_new_group.desc()], + ) + group_index = marked._is_new_group.sum().over(grouping_window) + grouped = marked.mutate(_group_idx=group_index) + + collapsed = grouped.group_by(grouped.person_id, grouped._group_idx).aggregate( + start_date=grouped.start_date.min(), + _max_padded_end=grouped._padded_end_date.max(), + ) + return collapsed.select( + collapsed.person_id.cast("int64").name(PERSON_ID), + collapsed.start_date.cast("date").name(START_DATE), + (collapsed._max_padded_end - ibis.interval(days=int(era_pad))).cast("date").name(END_DATE), + ) + + +def collapse_events(events, collapse_settings, censor_window): + if collapse_settings is None: + return _apply_censor_window(events, censor_window) + + collapse_type = (collapse_settings.collapse_type or "era").lower() + if collapse_type == 
"no_collapse": + return _apply_censor_window(events, censor_window) + + intervals = events.select( + events.person_id.cast("int64").name(PERSON_ID), + events.start_date.cast("date").name(START_DATE), + events.end_date.cast("date").name(END_DATE), + ) + intervals = _apply_censor_window(intervals, censor_window) + return _collapse_era(intervals, collapse_settings.era_pad) diff --git a/circe/execution/engine/end_strategy.py b/circe/execution/engine/end_strategy.py new file mode 100644 index 00000000..a099985b --- /dev/null +++ b/circe/execution/engine/end_strategy.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import ibis + +from ..errors import UnsupportedFeatureError +from ..plan.schema import END_DATE, PERSON_ID, START_DATE + + +def attach_observation_bounds(events, ctx): + observation_period = ctx.table("observation_period").select( + PERSON_ID, + "observation_period_start_date", + "observation_period_end_date", + ) + joined = events.join( + observation_period, + (events[PERSON_ID] == observation_period[PERSON_ID]) + & (events[START_DATE] >= observation_period.observation_period_start_date.cast("date")) + & (events[START_DATE] <= observation_period.observation_period_end_date.cast("date")), + ) + return joined.select( + *[joined[c] for c in events.columns], + observation_period.observation_period_start_date.cast("date").name("op_start_date"), + observation_period.observation_period_end_date.cast("date").name("op_end_date"), + ).distinct() + + +def _apply_date_offset_strategy(with_bounds, strategy): + offset = int(strategy.payload.get("offset", 0)) + date_field = str(strategy.payload.get("date_field", START_DATE)).lower() + + if date_field in {"startdate", START_DATE}: + base_date = with_bounds[START_DATE] + elif date_field in {"enddate", END_DATE}: + base_date = with_bounds[END_DATE] + else: + raise UnsupportedFeatureError( + f"Ibis executor end-strategy error: unsupported date_offset date field {date_field!r}." 
+ ) + + candidate = base_date + ibis.interval(days=offset) + return ibis.least(candidate, with_bounds.op_end_date) + + +def _replace_end_date(events, with_bounds, new_end_expr): + projected = with_bounds.mutate(_new_end_date=new_end_expr) + selected = projected.select( + *[ + projected[c] if c != END_DATE else projected._new_end_date.cast("date").name(END_DATE) + for c in events.columns + ] + ) + return selected + + +def apply_end_strategy(events, strategy, ctx): + with_bounds = attach_observation_bounds(events, ctx) + + if strategy is None: + return _replace_end_date(events, with_bounds, with_bounds.op_end_date) + + if strategy.kind == "date_offset": + end_date_expr = _apply_date_offset_strategy(with_bounds, strategy) + return _replace_end_date(events, with_bounds, end_date_expr) + + if strategy.kind == "custom_era": + raise UnsupportedFeatureError("Ibis executor end-strategy error: custom_era is not supported.") + + # Fallback: preserve default semantics of op_end_date clipping. + return _replace_end_date(events, with_bounds, with_bounds.op_end_date) diff --git a/circe/execution/engine/group_demographics.py b/circe/execution/engine/group_demographics.py new file mode 100644 index 00000000..bc5920aa --- /dev/null +++ b/circe/execution/engine/group_demographics.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import ibis + +from ..errors import UnsupportedFeatureError +from ..ibis.context import ExecutionContext +from ..normalize.groups import NormalizedDemographicCriteria +from ..plan.schema import EVENT_ID, PERSON_ID +from ..typing import Table + + +def _apply_numeric_predicate(expr, predicate): + op = (predicate.op or "eq").lower() + value = predicate.value + extent = predicate.extent + + if value is None: + return ibis.literal(True) + + if op in {"eq", "="}: + return expr == value + if op in {"neq", "!=", "ne"}: + return expr != value + if op in {"gt", ">"}: + return expr > value + if op in {"gte", ">="}: + return expr >= value + if op in {"lt", 
"<"}: + return expr < value + if op in {"lte", "<="}: + return expr <= value + if op in {"bt", "between"}: + if extent is None: + raise UnsupportedFeatureError( + "Ibis executor group evaluation error: demographic numeric range " + "'between' requires an extent value." + ) + lower = min(value, extent) + upper = max(value, extent) + return (expr >= lower) & (expr <= upper) + raise UnsupportedFeatureError( + f"Ibis executor group evaluation error: unsupported demographic numeric range op {predicate.op!r}." + ) + + +def _apply_date_predicate(expr, predicate): + op = (predicate.op or "eq").lower() + value = predicate.value + extent = predicate.extent + + if value is None: + return ibis.literal(True) + + value_expr = ibis.literal(value).cast("date") + date_expr = expr.cast("date") + if op in {"eq", "="}: + return date_expr == value_expr + if op in {"neq", "!=", "ne"}: + return date_expr != value_expr + if op in {"gt", ">"}: + return date_expr > value_expr + if op in {"gte", ">="}: + return date_expr >= value_expr + if op in {"lt", "<"}: + return date_expr < value_expr + if op in {"lte", "<="}: + return date_expr <= value_expr + if op in {"bt", "between"}: + if extent is None: + raise UnsupportedFeatureError( + "Ibis executor group evaluation error: demographic date range " + "'between' requires an extent value." + ) + extent_expr = ibis.literal(extent).cast("date") + lower = ibis.least(value_expr, extent_expr) + upper = ibis.greatest(value_expr, extent_expr) + return (date_expr >= lower) & (date_expr <= upper) + raise UnsupportedFeatureError( + f"Ibis executor group evaluation error: unsupported demographic date range op {predicate.op!r}." 
+ ) + + +def _demographic_concept_ids( + *, + explicit_ids: tuple[int, ...], + codeset_id: int | None, + ctx: ExecutionContext, +) -> tuple[int, ...]: + all_ids = list(explicit_ids) + if codeset_id is not None: + for concept_id in ctx.concept_ids_for_codeset(codeset_id): + if concept_id not in all_ids: + all_ids.append(concept_id) + return tuple(all_ids) + + +def demographic_match_keys( + index_events: Table, + demographic: NormalizedDemographicCriteria, + ctx: ExecutionContext, +) -> Table: + person_table = ctx.table("person") + person = person_table.select( + person_table.person_id.name("p_person_id"), + "year_of_birth", + "gender_concept_id", + "race_concept_id", + "ethnicity_concept_id", + ) + joined = index_events.join(person, index_events.person_id == person.p_person_id) + + predicates = [ibis.literal(True)] + if demographic.age is not None: + event_date = joined.start_date.cast("date") + age_years = event_date.year() - joined.year_of_birth + predicates.append(_apply_numeric_predicate(age_years, demographic.age)) + + gender_ids = _demographic_concept_ids( + explicit_ids=demographic.gender_concept_ids, + codeset_id=demographic.gender_codeset_id, + ctx=ctx, + ) + if gender_ids: + predicates.append(joined.gender_concept_id.isin(gender_ids)) + + race_ids = _demographic_concept_ids( + explicit_ids=demographic.race_concept_ids, + codeset_id=demographic.race_codeset_id, + ctx=ctx, + ) + if race_ids: + predicates.append(joined.race_concept_id.isin(race_ids)) + + ethnicity_ids = _demographic_concept_ids( + explicit_ids=demographic.ethnicity_concept_ids, + codeset_id=demographic.ethnicity_codeset_id, + ctx=ctx, + ) + if ethnicity_ids: + predicates.append(joined.ethnicity_concept_id.isin(ethnicity_ids)) + + if demographic.occurrence_start_date is not None: + predicates.append( + _apply_date_predicate( + joined.start_date, + demographic.occurrence_start_date, + ) + ) + if demographic.occurrence_end_date is not None: + predicates.append( + _apply_date_predicate( + 
joined.end_date, + demographic.occurrence_end_date, + ) + ) + + predicate = predicates[0] + for part in predicates[1:]: + predicate = predicate & part + + matched = joined.filter(predicate) + return matched.select( + matched.person_id.name(PERSON_ID), + matched.event_id.name(EVENT_ID), + ).distinct() diff --git a/circe/execution/engine/group_keys.py b/circe/execution/engine/group_keys.py new file mode 100644 index 00000000..4de323d6 --- /dev/null +++ b/circe/execution/engine/group_keys.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from ..plan.schema import EVENT_ID, PERSON_ID +from ..typing import Table + + +def union_all(tables: list[Table]) -> Table: + current = tables[0] + for table in tables[1:]: + current = current.union(table, distinct=False) + return current + + +def event_keys(events: Table) -> Table: + return events.select( + events.person_id.cast("int64").name(PERSON_ID), + events.event_id.cast("int64").name(EVENT_ID), + ).distinct() diff --git a/circe/execution/engine/group_operators.py b/circe/execution/engine/group_operators.py new file mode 100644 index 00000000..3ed7c8dd --- /dev/null +++ b/circe/execution/engine/group_operators.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import ibis + +from ..errors import UnsupportedFeatureError +from ..ibis.compiler import compile_event_plan +from ..ibis.context import ExecutionContext +from ..lower.criteria import lower_criterion +from ..normalize.groups import NormalizedCorrelatedCriteria +from ..plan.schema import ( + CONCEPT_ID, + DAYS_SUPPLY, + DURATION, + END_DATE, + EVENT_ID, + GAP_DAYS, + OCCURRENCE_COUNT, + PERSON_ID, + QUANTITY, + RANGE_HIGH, + RANGE_LOW, + REFILLS, + SOURCE_CONCEPT_ID, + START_DATE, + UNIT_CONCEPT_ID, + VALUE_AS_NUMBER, + VISIT_DETAIL_ID, + VISIT_OCCURRENCE_ID, +) +from ..typing import Table +from .group_keys import event_keys +from .group_windows import apply_window_constraints + + +def resolve_distinct_count_column(count_column: str | None) -> str: + if 
count_column is None: + return f"a_{CONCEPT_ID}" + + normalized = count_column.lower() + mapping = { + "domain_concept_id": f"a_{CONCEPT_ID}", + "domain_source_concept_id": f"a_{SOURCE_CONCEPT_ID}", + VISIT_OCCURRENCE_ID: f"a_{VISIT_OCCURRENCE_ID}", + "visit_id": f"a_{VISIT_OCCURRENCE_ID}", + "visit_detail_id": f"a_{VISIT_DETAIL_ID}", + START_DATE: f"a_{START_DATE}", + END_DATE: f"a_{END_DATE}", + "duration": f"a_{DURATION}", + "quantity": f"a_{QUANTITY}", + "days_supply": f"a_{DAYS_SUPPLY}", + "refills": f"a_{REFILLS}", + "range_low": f"a_{RANGE_LOW}", + "range_high": f"a_{RANGE_HIGH}", + "value_as_number": f"a_{VALUE_AS_NUMBER}", + "unit_concept_id": f"a_{UNIT_CONCEPT_ID}", + "occurrence_count": f"a_{OCCURRENCE_COUNT}", + "gap_days": f"a_{GAP_DAYS}", + } + if normalized in mapping: + return mapping[normalized] + + raise UnsupportedFeatureError( + "Ibis executor group evaluation error: unsupported distinct count column " + f"{count_column!r} for correlated criteria." + ) + + +def occurrence_predicate(match_count_expr, occurrence_type: int, occurrence_count: int): + if occurrence_type == 0: + return match_count_expr == occurrence_count + if occurrence_type == 1: + return match_count_expr <= occurrence_count + if occurrence_type == 2: + return match_count_expr >= occurrence_count + raise UnsupportedFeatureError( + f"Ibis executor group evaluation error: unsupported correlated occurrence type {occurrence_type}." 
+ ) + + +def group_predicate(match_count_expr, mode: str, count: int | None, child_count: int): + normalized_mode = (mode or "ALL").upper() + if normalized_mode == "ALL": + return match_count_expr == child_count + if normalized_mode == "ANY": + return match_count_expr > 0 + if normalized_mode == "AT_LEAST": + threshold = 0 if count is None else int(count) + return match_count_expr >= threshold + if normalized_mode == "AT_MOST": + threshold = 0 if count is None else int(count) + return match_count_expr <= threshold + raise UnsupportedFeatureError( + f"Ibis executor group evaluation error: unsupported criteria group mode {mode!r}." + ) + + +def _compile_correlated_events( + correlated: NormalizedCorrelatedCriteria, + *, + criterion_index: int, + ctx: ExecutionContext, +) -> Table: + event_plan = lower_criterion(correlated.criterion, criterion_index=criterion_index) + events = compile_event_plan(event_plan, ctx) + + nested_group = correlated.criterion.correlated_criteria + if nested_group is None or nested_group.is_empty(): + return events + + # Correlated criteria can themselves carry nested correlated criteria. + # Re-apply the same group evaluator used for primary/additional criteria. 
+ from .groups import apply_additional_criteria + + return apply_additional_criteria(events, nested_group, ctx) + + +def correlated_match_keys( + index_events: Table, + correlated: NormalizedCorrelatedCriteria, + *, + criterion_index: int, + ctx: ExecutionContext, +) -> Table: + correlated_events = _compile_correlated_events( + correlated, + criterion_index=criterion_index, + ctx=ctx, + ) + + p = index_events.select( + index_events[PERSON_ID].name("p_person_id"), + index_events[EVENT_ID].name("p_event_id"), + index_events[START_DATE].name("p_start_date"), + index_events[END_DATE].name("p_end_date"), + index_events[VISIT_OCCURRENCE_ID].name("p_visit_occurrence_id"), + index_events.op_start_date.name("p_op_start_date"), + index_events.op_end_date.name("p_op_end_date"), + ) + a = correlated_events.select( + correlated_events[PERSON_ID].name("a_person_id"), + correlated_events[EVENT_ID].name("a_event_id"), + correlated_events[START_DATE].name("a_start_date"), + correlated_events[END_DATE].name("a_end_date"), + correlated_events[VISIT_OCCURRENCE_ID].name("a_visit_occurrence_id"), + correlated_events[VISIT_DETAIL_ID].name("a_visit_detail_id"), + correlated_events[CONCEPT_ID].name("a_concept_id"), + correlated_events[SOURCE_CONCEPT_ID].name("a_source_concept_id"), + correlated_events[QUANTITY].name("a_quantity"), + correlated_events[DAYS_SUPPLY].name("a_days_supply"), + correlated_events[REFILLS].name("a_refills"), + correlated_events[RANGE_LOW].name("a_range_low"), + correlated_events[RANGE_HIGH].name("a_range_high"), + correlated_events[VALUE_AS_NUMBER].name("a_value_as_number"), + correlated_events[UNIT_CONCEPT_ID].name("a_unit_concept_id"), + correlated_events[OCCURRENCE_COUNT].name("a_occurrence_count"), + correlated_events[GAP_DAYS].name("a_gap_days"), + correlated_events[DURATION].name("a_duration"), + ) + + joined = p.join( + a, + predicates=[p.p_person_id == a.a_person_id], + ) + constrained = apply_window_constraints(joined, correlated) + + if 
correlated.occurrence_is_distinct: + distinct_col = resolve_distinct_count_column(correlated.occurrence_count_column) + counts = constrained.group_by( + constrained.p_person_id, + constrained.p_event_id, + ).aggregate(match_count=constrained[distinct_col].nunique()) + else: + counts = constrained.group_by( + constrained.p_person_id, + constrained.p_event_id, + ).aggregate(match_count=constrained.a_event_id.count()) + + keys = event_keys(index_events) + joined_counts = keys.left_join( + counts, + predicates=[(keys.person_id == counts.p_person_id) & (keys.event_id == counts.p_event_id)], + ) + counted = joined_counts.mutate(match_count=ibis.coalesce(joined_counts.match_count, ibis.literal(0))) + + predicate = occurrence_predicate( + counted.match_count, + int(correlated.occurrence_type), + int(correlated.occurrence_count), + ) + return counted.filter(predicate).select( + counted.person_id.name(PERSON_ID), + counted.event_id.name(EVENT_ID), + ) diff --git a/circe/execution/engine/group_windows.py b/circe/execution/engine/group_windows.py new file mode 100644 index 00000000..e0b126df --- /dev/null +++ b/circe/execution/engine/group_windows.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import ibis + +from ..ibis.context import ExecutionContext +from ..normalize.groups import NormalizedCorrelatedCriteria +from ..normalize.windows import NormalizedWindow, NormalizedWindowBound +from ..plan.schema import END_DATE, EVENT_ID, PERSON_ID, START_DATE, VISIT_OCCURRENCE_ID +from ..typing import Table + + +def attach_observation_period(events: Table, ctx: ExecutionContext) -> Table: + observation_period = ctx.table("observation_period").select( + PERSON_ID, + "observation_period_start_date", + "observation_period_end_date", + ) + + joined = events.join( + observation_period, + (events[PERSON_ID] == observation_period[PERSON_ID]) + & (events[START_DATE] >= observation_period.observation_period_start_date.cast("date")) + & (events[START_DATE] <= 
observation_period.observation_period_end_date.cast("date")), + ) + + return joined.select( + events[PERSON_ID].name(PERSON_ID), + events[EVENT_ID].name(EVENT_ID), + events[START_DATE].name(START_DATE), + events[END_DATE].name(END_DATE), + events[VISIT_OCCURRENCE_ID].name(VISIT_OCCURRENCE_ID), + observation_period.observation_period_start_date.cast("date").name("op_start_date"), + observation_period.observation_period_end_date.cast("date").name("op_end_date"), + ).distinct() + + +def window_bound_expression( + bound: NormalizedWindowBound | None, + *, + index_anchor_expr, + use_observation_period: bool, + op_start_expr, + op_end_expr, +): + if bound is None: + return None + + if bound.days is not None: + return index_anchor_expr + ibis.interval(days=int(bound.coeff) * int(bound.days)) + + if not use_observation_period: + return None + + return op_start_expr if int(bound.coeff) == -1 else op_end_expr + + +def apply_window_constraints(joined, correlated: NormalizedCorrelatedCriteria): + predicate = joined.a_person_id == joined.p_person_id + + if not correlated.ignore_observation_period: + predicate = predicate & (joined.a_start_date >= joined.p_op_start_date) + predicate = predicate & (joined.a_start_date <= joined.p_op_end_date) + + start_window: NormalizedWindow | None = correlated.start_window + if start_window is not None: + start_index_anchor = joined.p_end_date if bool(start_window.use_index_end) else joined.p_start_date + start_event_date = ( + joined.a_end_date + if (start_window.use_event_end is not None and start_window.use_event_end) + else joined.a_start_date + ) + + start_lower = window_bound_expression( + start_window.start, + index_anchor_expr=start_index_anchor, + use_observation_period=(not correlated.ignore_observation_period), + op_start_expr=joined.p_op_start_date, + op_end_expr=joined.p_op_end_date, + ) + if start_lower is not None: + predicate = predicate & (start_event_date >= start_lower) + + start_upper = window_bound_expression( + 
start_window.end, + index_anchor_expr=start_index_anchor, + use_observation_period=(not correlated.ignore_observation_period), + op_start_expr=joined.p_op_start_date, + op_end_expr=joined.p_op_end_date, + ) + if start_upper is not None: + predicate = predicate & (start_event_date <= start_upper) + + end_window: NormalizedWindow | None = correlated.end_window + if end_window is not None: + end_index_anchor = joined.p_end_date if bool(end_window.use_index_end) else joined.p_start_date + end_event_date = ( + joined.a_end_date + if (end_window.use_event_end is None or end_window.use_event_end) + else joined.a_start_date + ) + + end_lower = window_bound_expression( + end_window.start, + index_anchor_expr=end_index_anchor, + use_observation_period=(not correlated.ignore_observation_period), + op_start_expr=joined.p_op_start_date, + op_end_expr=joined.p_op_end_date, + ) + if end_lower is not None: + predicate = predicate & (end_event_date >= end_lower) + + end_upper = window_bound_expression( + end_window.end, + index_anchor_expr=end_index_anchor, + use_observation_period=(not correlated.ignore_observation_period), + op_start_expr=joined.p_op_start_date, + op_end_expr=joined.p_op_end_date, + ) + if end_upper is not None: + predicate = predicate & (end_event_date <= end_upper) + + if correlated.restrict_visit: + predicate = predicate & (joined.a_visit_occurrence_id == joined.p_visit_occurrence_id) + + return joined.filter(predicate) diff --git a/circe/execution/engine/groups.py b/circe/execution/engine/groups.py new file mode 100644 index 00000000..d630df29 --- /dev/null +++ b/circe/execution/engine/groups.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import ibis + +from ..ibis.context import ExecutionContext +from ..normalize.groups import NormalizedCriteriaGroup +from ..plan.schema import EVENT_ID, PERSON_ID +from ..typing import Table +from .group_demographics import demographic_match_keys +from .group_keys import event_keys, union_all +from 
.group_operators import correlated_match_keys, group_predicate +from .group_windows import attach_observation_period + + +def _evaluate_group( + index_events: Table, + group: NormalizedCriteriaGroup, + ctx: ExecutionContext, +) -> Table: + keys = event_keys(index_events) + + if group.is_empty(): + return keys + + child_results: list[Table] = [] + index_id = 0 + + for correlated in group.criteria: + correlated_matches = correlated_match_keys( + index_events, + correlated, + criterion_index=index_id, + ctx=ctx, + ) + child_results.append(correlated_matches.mutate(index_id=ibis.literal(index_id, type="int64"))) + index_id += 1 + + for demographic in group.demographics: + demographic_matches = demographic_match_keys(index_events, demographic, ctx) + child_results.append(demographic_matches.mutate(index_id=ibis.literal(index_id, type="int64"))) + index_id += 1 + + for child_group in group.groups: + child_group_matches = _evaluate_group(index_events, child_group, ctx) + child_results.append(child_group_matches.mutate(index_id=ibis.literal(index_id, type="int64"))) + index_id += 1 + + if not child_results: + return keys + + unioned = union_all(child_results) + group_counts = unioned.group_by(unioned.person_id, unioned.event_id).aggregate( + matched_children=unioned.index_id.nunique() + ) + + joined_counts = keys.left_join( + group_counts, + predicates=[(keys.person_id == group_counts.person_id) & (keys.event_id == group_counts.event_id)], + ) + counted = joined_counts.mutate( + matched_children=ibis.coalesce(joined_counts.matched_children, ibis.literal(0)) + ) + + predicate = group_predicate( + counted.matched_children, + group.mode, + group.count, + index_id, + ) + return counted.filter(predicate).select( + counted.person_id.name(PERSON_ID), + counted.event_id.name(EVENT_ID), + ) + + +def apply_additional_criteria( + events: Table, + group: NormalizedCriteriaGroup | None, + ctx: ExecutionContext, +) -> Table: + if group is None or group.is_empty(): + return events + + 
index_events = attach_observation_period(events, ctx) + matched_keys = _evaluate_group(index_events, group, ctx) + + filtered = events.join( + matched_keys, + predicates=[ + (events.person_id == matched_keys.person_id) & (events.event_id == matched_keys.event_id) + ], + ) + return filtered.select(*[filtered[c] for c in events.columns]) diff --git a/circe/execution/engine/inclusion.py b/circe/execution/engine/inclusion.py new file mode 100644 index 00000000..7e957842 --- /dev/null +++ b/circe/execution/engine/inclusion.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from ..normalize.groups import NormalizedInclusionRule +from .groups import apply_additional_criteria + + +def apply_inclusion_rules( + events, + inclusion_rules: tuple[NormalizedInclusionRule, ...], + ctx, +): + if not inclusion_rules: + return events + + included = events + for rule in inclusion_rules: + included = apply_additional_criteria(included, rule.expression, ctx) + return included diff --git a/circe/execution/engine/limits.py b/circe/execution/engine/limits.py new file mode 100644 index 00000000..e8d3f936 --- /dev/null +++ b/circe/execution/engine/limits.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import ibis + +from ..plan.schema import DOMAIN, EVENT_ID, PERSON_ID, START_DATE + + +def apply_result_limit(events, limit_type: str): + normalized = (limit_type or "all").lower() + if normalized in {"all", ""}: + return events + + descending = normalized == "last" + order_by = [events[START_DATE], events[EVENT_ID]] + if DOMAIN in events.columns: + order_by.append(events[DOMAIN]) + + if descending: + order_by = [expr.desc() for expr in order_by] + + window = ibis.window( + group_by=events[PERSON_ID], + order_by=order_by, + ) + ranked = events.mutate(_limit_rn=ibis.row_number().over(window)) + return ranked.filter(ranked._limit_rn == 0).drop("_limit_rn") diff --git a/circe/execution/engine/primary.py b/circe/execution/engine/primary.py new file mode 100644 index 
00000000..07f0ec41 --- /dev/null +++ b/circe/execution/engine/primary.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import ibis + +from ..errors import ExecutionNormalizationError +from ..ibis.compiler import compile_event_plan +from ..ibis.context import ExecutionContext +from ..normalize.windows import NormalizedObservationWindow +from ..plan.cohort import CohortPlan +from ..plan.schema import DOMAIN, EVENT_ID, PERSON_ID, START_DATE +from ..typing import Table +from .groups import apply_additional_criteria +from .limits import apply_result_limit + + +def _union_all(tables): + current = tables[0] + for table in tables[1:]: + current = current.union(table, distinct=False) + return current + + +def _assign_primary_event_ids(events): + ordering = [events[START_DATE], events[EVENT_ID], events[DOMAIN]] + person_window = ibis.window(group_by=events[PERSON_ID], order_by=ordering) + ranked = events.mutate(_primary_rn=ibis.row_number().over(person_window)) + return ranked.mutate(**{EVENT_ID: ranked._primary_rn + 1}).drop("_primary_rn") + + +def _apply_observation_window( + events, + ctx: ExecutionContext, + window: NormalizedObservationWindow, +): + observation_period = ctx.table("observation_period").select( + PERSON_ID, + "observation_period_start_date", + "observation_period_end_date", + ) + joined = events.join( + observation_period, + events[PERSON_ID] == observation_period[PERSON_ID], + ) + lower = joined.observation_period_start_date + ibis.interval(days=window.prior_days) + upper = joined.observation_period_end_date - ibis.interval(days=window.post_days) + filtered = joined.filter((joined[START_DATE] >= lower) & (joined[START_DATE] <= upper)) + return filtered.select(*[filtered[c] for c in events.columns]) + + +def build_primary_events(plan: CohortPlan, ctx: ExecutionContext) -> Table: + if not plan.primary_event_plans: + raise ExecutionNormalizationError( + "Ibis executor primary build error: no primary criteria were lowered to executable plans." 
+ ) + + compiled = [] + for primary in plan.primary_event_plans: + events = compile_event_plan(primary.event_plan, ctx) + events = apply_additional_criteria(events, primary.correlated_criteria, ctx) + compiled.append(events) + + events = _union_all(compiled) + events = _assign_primary_event_ids(events) + + if plan.observation_window is not None: + events = _apply_observation_window(events, ctx, plan.observation_window) + + events = apply_result_limit(events, plan.primary_limit_type) + return events diff --git a/circe/execution/errors.py b/circe/execution/errors.py new file mode 100644 index 00000000..ab2ddba6 --- /dev/null +++ b/circe/execution/errors.py @@ -0,0 +1,21 @@ +from __future__ import annotations + + +class ExecutionError(RuntimeError): + """Base execution subsystem error.""" + + +class ExecutionNormalizationError(ExecutionError): + """Raised when expression normalization fails structurally.""" + + +class UnsupportedCriterionError(ExecutionError): + """Raised when a criterion type is unsupported by the executor.""" + + +class UnsupportedFeatureError(ExecutionError): + """Raised when requested executor semantics are unsupported.""" + + +class CompilationError(ExecutionError): + """Raised when lowering/compilation to Ibis cannot proceed.""" diff --git a/circe/execution/ibis.py b/circe/execution/ibis.py deleted file mode 100644 index e0d3d2e7..00000000 --- a/circe/execution/ibis.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Experimental ibis execution API.""" - -from __future__ import annotations - -from dataclasses import replace -from typing import TYPE_CHECKING, Any - -from ..io import ExpressionInput, load_expression -from .options import ExecutionOptions, SchemaName, schema_to_str - -if TYPE_CHECKING: - import pandas as pd - import polars as pl - - -class IbisExecutor: - """Execute cohort expressions against an ibis backend. - - Notes: - - This API is experimental. - - `build()` returns an ibis table expression (lazy relation). 
- - Materialization happens in `to_polars()` / `to_pandas()` / `write()`. - """ - - def __init__(self, conn: Any, options: ExecutionOptions | None = None): - self._conn = conn - self._options = options or ExecutionOptions() - self._open_contexts: list[Any] = [] - - @property - def conn(self) -> Any: - return self._conn - - @property - def options(self) -> ExecutionOptions: - return self._options - - def build(self, expression: ExpressionInput) -> Any: - """Build a lazy ibis relation for the final cohort rows.""" - cohort_expression = load_expression(expression) - self.close() - return self._build_native(cohort_expression) - - def to_polars(self, expression: ExpressionInput) -> pl.DataFrame: - """Execute cohort expression and collect to Polars.""" - table = self.build(expression) - if not hasattr(table, "to_polars"): - raise RuntimeError("The returned ibis table does not support to_polars() on this backend.") - return table.to_polars() - - def to_pandas(self, expression: ExpressionInput) -> pd.DataFrame: - """Execute cohort expression and collect to pandas.""" - table = self.build(expression) - if not hasattr(table, "to_pandas"): - raise RuntimeError("The returned ibis table does not support to_pandas() on this backend.") - return table.to_pandas() - - def write( - self, - expression: ExpressionInput, - *, - table: str, - schema: SchemaName | None = None, - overwrite: bool = True, - append: bool = False, - cohort_id: int | None = None, - ) -> Any: - """Persist cohort rows to a cohort table and return a backend table handle.""" - if append and overwrite: - raise ValueError("`append=True` and `overwrite=True` cannot be used together.") - cohort_expression = load_expression(expression) - self.close() - events, ctx = self._build_with_context_native(cohort_expression, cohort_id_override=cohort_id) - self._open_contexts.append(ctx) - return ctx.write_cohort_table( - events, - table_name=table, - database=schema_to_str(schema) or schema_to_str(self._options.result_schema), 
- overwrite=overwrite, - append=append, - ) - - def captured_sql(self) -> list[tuple[str, str]]: - """Return captured staged SQL snippets when capture_sql is enabled.""" - captured: list[tuple[str, str]] = [] - for ctx in self._open_contexts: - if hasattr(ctx, "captured_sql"): - captured.extend(ctx.captured_sql()) - return captured - - def close(self) -> None: - """Release temporary resources held by execution contexts.""" - while self._open_contexts: - ctx = self._open_contexts.pop() - try: - ctx.close() - except Exception as exc: - print(f"Warning: failed to close execution context: {exc}") - - def __enter__(self) -> IbisExecutor: - return self - - def __exit__(self, exc_type, exc, tb) -> None: - self.close() - - def _build_native(self, cohort_expression: Any) -> Any: - events, ctx = self._build_with_context_native(cohort_expression) - self._open_contexts.append(ctx) - return events - - def _build_with_context_native( - self, - cohort_expression: Any, - cohort_id_override: int | None = None, - ) -> Any: - try: - from .build_context import ( - BuildContext, - CohortBuildOptions, - compile_codesets, - ) - from .builders.pipeline import build_primary_events - except ModuleNotFoundError as exc: - raise RuntimeError( - "Ibis execution requires optional dependencies. " - "Install `ohdsi-circe-python-alpha[ibis]` plus a backend extra, " - "for example `[ibis-duckdb]`." 
- ) from exc - - backend = self._infer_backend_name(self._conn) - options = CohortBuildOptions( - cdm_schema=schema_to_str(self._options.cdm_schema), - vocabulary_schema=schema_to_str(self._options.vocabulary_schema), - result_schema=schema_to_str(self._options.result_schema), - cohort_id=(cohort_id_override if cohort_id_override is not None else self._options.cohort_id), - materialize_stages=self._options.materialize_stages, - materialize_codesets=self._options.materialize_codesets, - temp_emulation_schema=schema_to_str(self._options.temp_emulation_schema), - profile_dir=self._options.profile_dir, - capture_sql=self._options.capture_sql, - backend=backend, - ) - resource = compile_codesets(self._conn, cohort_expression.concept_sets or [], options) - ctx = BuildContext(self._conn, options, resource) - events = build_primary_events(cohort_expression, ctx) - if events is None: - raise RuntimeError("No primary events were generated for the supplied cohort expression.") - return events, ctx - - @staticmethod - def _infer_backend_name(conn: Any) -> str | None: - backend_name = getattr(conn, "name", None) - if isinstance(backend_name, str) and backend_name: - return backend_name.lower() - class_name = conn.__class__.__name__.lower() - if "duckdb" in class_name: - return "duckdb" - if "postgres" in class_name: - return "postgres" - if "databricks" in class_name: - return "databricks" - return None - - -def build_ibis( - expression: ExpressionInput, - conn: Any, - options: ExecutionOptions | None = None, -) -> Any: - """Convenience wrapper for IbisExecutor.build().""" - with IbisExecutor(conn, options) as executor: - return executor.build(expression) - - -def to_polars( - expression: ExpressionInput, - conn: Any, - options: ExecutionOptions | None = None, -) -> pl.DataFrame: - """Convenience wrapper for IbisExecutor.to_polars().""" - with IbisExecutor(conn, options) as executor: - return executor.to_polars(expression) - - -def write_cohort( - expression: ExpressionInput, - 
conn: Any, - *, - table: str, - schema: SchemaName | None = None, - overwrite: bool = True, - append: bool = False, - cohort_id: int | None = None, - options: ExecutionOptions | None = None, -) -> Any: - """Convenience wrapper for IbisExecutor.write().""" - effective_options = options - if cohort_id is not None: - effective_options = replace(options or ExecutionOptions(), cohort_id=cohort_id) - - with IbisExecutor(conn, effective_options) as executor: - return executor.write( - expression, - table=table, - schema=schema, - overwrite=overwrite, - append=append, - cohort_id=cohort_id, - ) diff --git a/circe/execution/ibis/__init__.py b/circe/execution/ibis/__init__.py new file mode 100644 index 00000000..5f0cdbac --- /dev/null +++ b/circe/execution/ibis/__init__.py @@ -0,0 +1,11 @@ +from ..plan.schema import STANDARD_EVENT_COLUMNS +from .compiler import compile_event_plan +from .context import ExecutionContext +from .standardize import standardize_event_table + +__all__ = [ + "ExecutionContext", + "compile_event_plan", + "STANDARD_EVENT_COLUMNS", + "standardize_event_table", +] diff --git a/circe/execution/ibis/codesets.py b/circe/execution/ibis/codesets.py new file mode 100644 index 00000000..d286c7dd --- /dev/null +++ b/circe/execution/ibis/codesets.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from collections.abc import Callable, Mapping +from typing import Any + +from ..errors import CompilationError +from ..normalize.cohort import NormalizedConceptSet, NormalizedConceptSetItem +from ..plan.schema import CONCEPT_ID +from ..typing import Table + + +class CachedConceptSetResolver: + """Resolve concept sets to concrete concept IDs using vocabulary tables.""" + + def __init__( + self, + *, + table_getter: Callable[[str, str | None], Table], + vocabulary_schema: str | None, + concept_sets: Mapping[int, NormalizedConceptSet], + ) -> None: + self._table_getter = table_getter + self._vocabulary_schema = vocabulary_schema + self._concept_sets = concept_sets 
+ self._cache: dict[int, tuple[int, ...]] = {} + + def resolve_codeset(self, codeset_id: int) -> tuple[int, ...]: + normalized_id = int(codeset_id) + if normalized_id in self._cache: + return self._cache[normalized_id] + + concept_set = self._concept_sets.get(normalized_id) + if concept_set is None or not concept_set.items: + return () + + include_ids: set[int] = set() + exclude_ids: set[int] = set() + for item in concept_set.items: + expanded = self._expand_item(item) + if item.is_excluded: + exclude_ids.update(expanded) + else: + include_ids.update(expanded) + + resolved = tuple(sorted(include_ids - exclude_ids)) + self._cache[normalized_id] = resolved + return resolved + + def _expand_item(self, item: NormalizedConceptSetItem) -> set[int]: + base_ids: set[int] = {int(item.concept_id)} + if item.include_descendants: + base_ids.update(self._descendant_ids(base_ids)) + + expanded = set(base_ids) + if item.include_mapped: + expanded.update(self._mapped_ids(base_ids)) + return expanded + + def _vocabulary_table(self, table_name: str) -> Table: + try: + return self._table_getter(table_name, self._vocabulary_schema) + except Exception as exc: # pragma: no cover - backend specific error types + raise CompilationError( + f"Ibis executor compilation error: failed to access vocabulary table '{table_name}'." 
+ ) from exc + + def _descendant_ids(self, ancestor_ids: set[int]) -> set[int]: + if not ancestor_ids: + return set() + + concept = self._vocabulary_table("concept") + concept_ancestor = self._vocabulary_table("concept_ancestor") + query = ( + concept_ancestor.join( + concept, + concept_ancestor.descendant_concept_id == concept.concept_id, + ) + .filter(concept_ancestor.ancestor_concept_id.isin(tuple(ancestor_ids))) + .filter(concept.invalid_reason.isnull()) + .select(concept_ancestor.descendant_concept_id.name(CONCEPT_ID)) + .distinct() + ) + return self._execute_concept_id_query(query) + + def _mapped_ids(self, input_ids: set[int]) -> set[int]: + if not input_ids: + return set() + + concept_relationship = self._vocabulary_table("concept_relationship") + query = ( + concept_relationship.filter(concept_relationship.concept_id_2.isin(tuple(input_ids))) + .filter(concept_relationship.relationship_id == "Maps to") + .filter(concept_relationship.invalid_reason.isnull()) + .select(concept_relationship.concept_id_1.name(CONCEPT_ID)) + .distinct() + ) + return self._execute_concept_id_query(query) + + def _execute_concept_id_query(self, query: Table) -> set[int]: + try: + rows = query.execute() + except Exception as exc: # pragma: no cover - backend specific error types + raise CompilationError( + "Ibis executor compilation error: failed executing concept-set expansion query." 
+ ) from exc + + values: list[Any] + if hasattr(rows, "columns"): # pandas DataFrame + values = rows[CONCEPT_ID].tolist() if CONCEPT_ID in rows.columns else rows.iloc[:, 0].tolist() + elif isinstance(rows, (list, tuple, set)): + values = list(rows) + else: + values = [rows] + + output: set[int] = set() + for value in values: + if value is None: + continue + output.add(int(value)) + return output diff --git a/circe/execution/ibis/compile_steps.py b/circe/execution/ibis/compile_steps.py new file mode 100644 index 00000000..0c6ad844 --- /dev/null +++ b/circe/execution/ibis/compile_steps.py @@ -0,0 +1,365 @@ +from __future__ import annotations + +import ibis + +from ..errors import CompilationError, UnsupportedFeatureError +from ..plan.events import ( + ApplyDateAdjustment, + FilterByCareSite, + FilterByCareSiteLocationRegion, + FilterByCodeset, + FilterByConceptSet, + FilterByDateRange, + FilterByNumericRange, + FilterByPersonAge, + FilterByPersonEthnicity, + FilterByPersonGender, + FilterByPersonRace, + FilterByProviderSpecialty, + FilterByText, + FilterByVisit, + JoinLocationRegion, + KeepFirstPerPerson, + RestrictToCorrelatedWindow, + StandardizeEventShape, +) +from ..plan.predicates import DateRangePredicate, NumericRangePredicate +from ..plan.schema import END_DATE, PERSON_ID, START_DATE +from .context import ExecutionContext +from .person_filters import ( + apply_person_age_filter, + apply_person_ethnicity_filter, + apply_person_gender_filter, + apply_person_race_filter, +) +from .standardize import standardize_event_table + + +def _apply_numeric_predicate(expr, predicate: NumericRangePredicate): + op = (predicate.op or "eq").lower() + value = predicate.value + extent = predicate.extent + + if value is None: + return ibis.literal(True) + + if op in {"eq", "="}: + return expr == value + if op in {"neq", "!=", "ne"}: + return expr != value + if op in {"gt", ">"}: + return expr > value + if op in {"gte", ">="}: + return expr >= value + if op in {"lt", "<"}: + 
return expr < value + if op in {"lte", "<="}: + return expr <= value + if op in {"bt", "between"}: + if extent is None: + raise CompilationError( + "Ibis executor compilation error: numeric range 'between' requires an extent value." + ) + lower = min(value, extent) + upper = max(value, extent) + return (expr >= lower) & (expr <= upper) + + raise CompilationError(f"Ibis executor compilation error: unsupported numeric range op {predicate.op!r}.") + + +def _apply_date_predicate(expr, predicate: DateRangePredicate): + op = (predicate.op or "eq").lower() + value = predicate.value + extent = predicate.extent + + if value is None: + return ibis.literal(True) + + value_expr = ibis.literal(value).cast("date") + + if op in {"eq", "="}: + return expr.cast("date") == value_expr + if op in {"neq", "!=", "ne"}: + return expr.cast("date") != value_expr + if op in {"gt", ">"}: + return expr.cast("date") > value_expr + if op in {"gte", ">="}: + return expr.cast("date") >= value_expr + if op in {"lt", "<"}: + return expr.cast("date") < value_expr + if op in {"lte", "<="}: + return expr.cast("date") <= value_expr + if op in {"bt", "between"}: + if extent is None: + raise CompilationError( + "Ibis executor compilation error: date range 'between' requires an extent value." 
+ ) + extent_expr = ibis.literal(extent).cast("date") + lower = ibis.least(value_expr, extent_expr) + upper = ibis.greatest(value_expr, extent_expr) + return (expr.cast("date") >= lower) & (expr.cast("date") <= upper) + + raise CompilationError(f"Ibis executor compilation error: unsupported date range op {predicate.op!r}.") + + +def _resolve_concept_ids( + *, + direct_ids: tuple[int, ...], + codeset_id: int | None, + ctx: ExecutionContext, +) -> tuple[int, ...]: + all_ids = list(direct_ids) + if codeset_id is not None: + for cid in ctx.concept_ids_for_codeset(codeset_id): + if cid not in all_ids: + all_ids.append(cid) + return tuple(all_ids) + + +def _select_original_columns(table, joined): + return joined.select(*[joined[c] for c in table.columns]) + + +def _filter_visit_concepts(table, ctx: ExecutionContext, *, step: FilterByVisit): + visit = ctx.table("visit_occurrence") + visit_lookup = visit.select( + visit.visit_occurrence_id.name("_visit_occurrence_id"), + visit.person_id.name("_visit_person_id"), + visit.visit_concept_id.name("_visit_concept_id"), + ) + joined = table.join( + visit_lookup, + predicates=[ + table[step.visit_occurrence_column] == visit_lookup._visit_occurrence_id, + table[PERSON_ID] == visit_lookup._visit_person_id, + ], + ) + concept_ids = _resolve_concept_ids( + direct_ids=step.concept_ids, + codeset_id=step.codeset_id, + ctx=ctx, + ) + predicate = joined._visit_concept_id.isin(concept_ids) + filtered = joined.filter(~predicate if step.exclude else predicate) + return _select_original_columns(table, filtered) + + +def _filter_provider_specialty( + table, + ctx: ExecutionContext, + *, + step: FilterByProviderSpecialty, +): + provider = ctx.table("provider") + provider_lookup = provider.select( + provider.provider_id.name("_provider_id"), + provider.specialty_concept_id.name("_specialty_concept_id"), + ) + joined = table.join( + provider_lookup, + predicates=[table[step.provider_id_column] == provider_lookup._provider_id], + ) + concept_ids = 
_resolve_concept_ids( + direct_ids=step.concept_ids, + codeset_id=step.codeset_id, + ctx=ctx, + ) + predicate = joined._specialty_concept_id.isin(concept_ids) + filtered = joined.filter(~predicate if step.exclude else predicate) + return _select_original_columns(table, filtered) + + +def _filter_care_site(table, ctx: ExecutionContext, *, step: FilterByCareSite): + care_site = ctx.table("care_site") + care_site_lookup = care_site.select( + care_site.care_site_id.name("_care_site_id"), + care_site.place_of_service_concept_id.name("_place_of_service_concept_id"), + ) + joined = table.join( + care_site_lookup, + predicates=[table[step.care_site_id_column] == care_site_lookup._care_site_id], + ) + concept_ids = _resolve_concept_ids( + direct_ids=step.concept_ids, + codeset_id=step.codeset_id, + ctx=ctx, + ) + predicate = joined._place_of_service_concept_id.isin(concept_ids) + filtered = joined.filter(~predicate if step.exclude else predicate) + return _select_original_columns(table, filtered) + + +def _filter_care_site_location_region( + table, + ctx: ExecutionContext, + *, + step: FilterByCareSiteLocationRegion, +): + region_ids = ctx.concept_ids_for_codeset(step.codeset_id) + if not region_ids: + return table.limit(0) + + location_history = ctx.table("location_history") + history_lookup = location_history.select( + location_history.entity_id.name("_care_site_id"), + location_history.location_id.name("_history_location_id"), + location_history.domain_id.name("_history_domain_id"), + location_history.start_date.name("_history_start_date"), + location_history.end_date.name("_history_end_date"), + ) + joined_history = table.join( + history_lookup, + predicates=[table[step.care_site_id_column] == history_lookup._care_site_id], + ) + history_end = ibis.coalesce( + joined_history._history_end_date.cast("date"), + ibis.literal("2099-12-31").cast("date"), + ) + joined_history = joined_history.filter( + (joined_history._history_domain_id == "CARE_SITE") + & ( + 
joined_history[step.start_date_column].cast("date") + >= joined_history._history_start_date.cast("date") + ) + & (joined_history[step.end_date_column].cast("date") <= history_end) + ) + + location = ctx.table("location") + location_lookup = location.select( + location.location_id.name("_location_id"), + location.region_concept_id.name("_region_concept_id"), + ) + joined = joined_history.join( + location_lookup, + predicates=[joined_history._history_location_id == location_lookup._location_id], + ) + filtered = joined.filter(joined._region_concept_id.isin(region_ids)) + return _select_original_columns(table, filtered) + + +def apply_step(step, *, table, source, ctx: ExecutionContext): + if isinstance(step, JoinLocationRegion): + location = ctx.table("location").select( + "location_id", + step.region_column, + ) + joined = table.join( + location, + predicates=[table[step.location_id_column] == location.location_id], + ) + return joined.select( + *[joined[c] for c in table.columns], + location[step.region_column].name(step.region_column), + ) + + if isinstance(step, FilterByCodeset): + concept_ids = ctx.concept_ids_for_codeset(step.codeset_id) + if not concept_ids: + return table if step.exclude else table.limit(0) + predicate = table[step.column].isin(concept_ids) + return table.filter(~predicate if step.exclude else predicate) + + if isinstance(step, FilterByConceptSet): + if not step.concept_ids: + return table if step.exclude else table.limit(0) + predicate = table[step.column].isin(step.concept_ids) + return table.filter(~predicate if step.exclude else predicate) + + if isinstance(step, FilterByVisit): + return _filter_visit_concepts(table, ctx, step=step) + + if isinstance(step, FilterByProviderSpecialty): + return _filter_provider_specialty(table, ctx, step=step) + + if isinstance(step, FilterByCareSite): + return _filter_care_site(table, ctx, step=step) + + if isinstance(step, FilterByCareSiteLocationRegion): + return _filter_care_site_location_region(table, 
ctx, step=step) + + if isinstance(step, FilterByDateRange): + return table.filter(_apply_date_predicate(table[step.column], step.predicate)) + + if isinstance(step, FilterByNumericRange): + return table.filter(_apply_numeric_predicate(table[step.column], step.predicate)) + + if isinstance(step, FilterByText): + op = (step.op or "eq").lower() + if step.text is None: + return table + if op in {"eq", "="}: + return table.filter(table[step.column] == step.text) + if op in {"neq", "!=", "ne"}: + return table.filter(table[step.column] != step.text) + if op in {"contains", "like"}: + return table.filter(table[step.column].contains(step.text)) + raise CompilationError(f"Ibis executor compilation error: unsupported text filter op {step.op!r}.") + + if isinstance(step, FilterByPersonAge): + return apply_person_age_filter( + table, + ctx, + date_column=step.date_column, + predicate=step.predicate, + ) + + if isinstance(step, FilterByPersonGender): + return apply_person_gender_filter( + table, + ctx, + concept_ids=step.concept_ids, + codeset_id=step.codeset_id, + ) + + if isinstance(step, FilterByPersonRace): + return apply_person_race_filter( + table, + ctx, + concept_ids=step.concept_ids, + codeset_id=step.codeset_id, + ) + + if isinstance(step, FilterByPersonEthnicity): + return apply_person_ethnicity_filter( + table, + ctx, + concept_ids=step.concept_ids, + codeset_id=step.codeset_id, + ) + + if isinstance(step, KeepFirstPerPerson): + order_by = [table[c] for c in step.order_by if c in table.columns] + window = ibis.window(group_by=table[PERSON_ID], order_by=order_by) + ranked = table.mutate(_exec_rn=ibis.row_number().over(window)) + return ranked.filter(ranked._exec_rn == 0).drop("_exec_rn") + + if isinstance(step, ApplyDateAdjustment): + start_anchor = table[START_DATE] if step.start_with == START_DATE else table[END_DATE] + end_anchor = table[START_DATE] if step.end_with == START_DATE else table[END_DATE] + return table.mutate( + **{ + START_DATE: start_anchor + 
ibis.interval(days=step.start_offset_days), + END_DATE: end_anchor + ibis.interval(days=step.end_offset_days), + } + ) + + if isinstance(step, RestrictToCorrelatedWindow): + raise UnsupportedFeatureError( + "Ibis executor compilation error: RestrictToCorrelatedWindow step is not implemented." + ) + + if isinstance(step, StandardizeEventShape): + return standardize_event_table( + table, + source=source, + criterion_type=step.criterion_type, + criterion_index=step.criterion_index, + start_offset_days=step.start_offset_days, + end_offset_days=step.end_offset_days, + start_with=step.start_with, + end_with=step.end_with, + ) + + raise CompilationError( + f"Ibis executor compilation error: unsupported plan step {step.__class__.__name__}." + ) diff --git a/circe/execution/ibis/compiler.py b/circe/execution/ibis/compiler.py new file mode 100644 index 00000000..daf9a400 --- /dev/null +++ b/circe/execution/ibis/compiler.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from ..plan.events import EventPlan +from ..typing import Table +from .compile_steps import apply_step +from .context import ExecutionContext + + +def compile_event_plan(plan: EventPlan, ctx: ExecutionContext) -> Table: + table = ctx.table(plan.source.table_name) + for step in plan.steps: + table = apply_step(step, table=table, source=plan.source, ctx=ctx) + return table diff --git a/circe/execution/ibis/context.py b/circe/execution/ibis/context.py new file mode 100644 index 00000000..57dab8e4 --- /dev/null +++ b/circe/execution/ibis/context.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from collections.abc import Mapping + +from .._dataclass import frozen_slots_dataclass +from ..normalize.cohort import NormalizedConceptSet +from ..typing import IbisBackendLike, Table +from .codesets import CachedConceptSetResolver + + +def _table_with_schema_fallback( + backend: IbisBackendLike, + table_name: str, + schema: str | None, +) -> Table: + try: + if schema is not None: + return 
backend.table(table_name, database=schema) + except TypeError: + pass + return backend.table(table_name) + + +@frozen_slots_dataclass +class ExecutionContext: + backend: IbisBackendLike + cdm_schema: str + results_schema: str | None + vocabulary_schema: str | None + codeset_resolver: CachedConceptSetResolver + + def table(self, table_name: str) -> Table: + return self._table_from_schema(table_name, self.cdm_schema) + + def vocabulary_table(self, table_name: str) -> Table: + return self._table_from_schema( + table_name, + self.vocabulary_schema or self.cdm_schema, + ) + + def _table_from_schema(self, table_name: str, schema: str | None) -> Table: + return _table_with_schema_fallback(self.backend, table_name, schema) + + def concept_ids_for_codeset(self, codeset_id: int) -> tuple[int, ...]: + return self.codeset_resolver.resolve_codeset(codeset_id) + + +def make_execution_context( + *, + backend: IbisBackendLike, + cdm_schema: str, + concept_sets: Mapping[int, NormalizedConceptSet], + results_schema: str | None = None, + vocabulary_schema: str | None = None, +) -> ExecutionContext: + """Construct an executor context from API-level wiring arguments.""" + vocabulary_schema = vocabulary_schema or cdm_schema + + def _table_getter(table_name: str, schema: str | None) -> Table: + return _table_with_schema_fallback(backend, table_name, schema) + + resolver = CachedConceptSetResolver( + table_getter=_table_getter, + vocabulary_schema=vocabulary_schema, + concept_sets=concept_sets, + ) + return ExecutionContext( + backend=backend, + cdm_schema=cdm_schema, + results_schema=results_schema, + vocabulary_schema=vocabulary_schema, + codeset_resolver=resolver, + ) diff --git a/circe/execution/ibis/materialize.py b/circe/execution/ibis/materialize.py new file mode 100644 index 00000000..5df9e473 --- /dev/null +++ b/circe/execution/ibis/materialize.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from ..typing import Table + + +def project_to_ohdsi_cohort_table(relation: 
Table, *, cohort_id: int | None) -> Table: + """Project a generic cohort relation into OHDSI cohort-table shape.""" + import ibis + + cohort_id_expr = ( + ibis.literal(int(cohort_id), type="int64") if cohort_id is not None else ibis.null().cast("int64") + ) + return relation.select( + cohort_id_expr.name("cohort_definition_id"), + relation.person_id.cast("int64").name("subject_id"), + relation.start_date.cast("date").name("cohort_start_date"), + relation.end_date.cast("date").name("cohort_end_date"), + ) diff --git a/circe/execution/ibis/operations.py b/circe/execution/ibis/operations.py new file mode 100644 index 00000000..c8ba111e --- /dev/null +++ b/circe/execution/ibis/operations.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import sqlglot as sg +import sqlglot.expressions as sge + +from ..errors import ExecutionError +from ..typing import IbisBackendLike + + +def _call_with_optional_database(method, *args, database: str | None, **kwargs): + if database is not None: + try: + return method(*args, database=database, **kwargs) + except TypeError: + pass + return method(*args, **kwargs) + + +def table_exists( + backend: IbisBackendLike, + *, + table_name: str, + schema: str | None, +) -> bool: + """Return whether a backend table exists.""" + list_tables = getattr(backend, "list_tables", None) + if callable(list_tables): + if schema is not None: + try: + return table_name in list_tables(database=schema) + except TypeError: + return table_name in list_tables() + return table_name in list_tables() + + try: + read_table(backend, table_name=table_name, schema=schema) + except Exception: + return False + return True + + +def read_table( + backend: IbisBackendLike, + *, + table_name: str, + schema: str | None, +): + """Read a backend table as an Ibis relation.""" + return _call_with_optional_database( + backend.table, + table_name, + database=schema, + ) + + +def create_table( + backend: IbisBackendLike, + *, + table_name: str, + schema: str | None, + 
**kwargs, +) -> None: + """Create or overwrite a backend table with schema fallback.""" + _call_with_optional_database( + backend.create_table, + table_name, + database=schema, + **kwargs, + ) + + +def cohort_rows_exist( + backend: IbisBackendLike, + *, + cohort_table: str, + results_schema: str | None, + cohort_id: int, +) -> bool: + """Return whether a cohort table already contains rows for a cohort id.""" + import ibis + + try: + table = read_table(backend, table_name=cohort_table, schema=results_schema) + cohort_id_expr = ibis.literal(int(cohort_id), type="int64") + matching = table.filter(table.cohort_definition_id.cast("int64") == cohort_id_expr) + return len(matching.limit(1).execute()) > 0 + except Exception as exc: + raise ExecutionError( + f"Ibis executor write error: failed checking existing rows for cohort_id={cohort_id}." + ) from exc + + +def delete_cohort_rows( + backend: IbisBackendLike, + *, + cohort_table: str, + results_schema: str | None, + cohort_id: int, +) -> None: + """Delete existing cohort-table rows for a single cohort id.""" + raw_sql = getattr(backend, "raw_sql", None) + if not callable(raw_sql): + raise ExecutionError( + "Ibis executor write error: backend does not support raw_sql for cohort-table deletes." + ) + + catalog, database = _catalog_db_tuple(backend, results_schema) + quoted = getattr(getattr(backend, "compiler", None), "quoted", False) + statement = sge.delete(sg.table(cohort_table, db=database, catalog=catalog, quoted=quoted)).where( + sg.column("cohort_definition_id", quoted=quoted).eq(sge.convert(int(cohort_id))) + ) + + try: + raw_sql(statement) + except Exception as exc: + raise ExecutionError( + "Ibis executor write error: failed deleting existing cohort rows from " + f"'{cohort_table}' for cohort_id={cohort_id}." 
+ ) from exc + + +def supports_transactional_replace(backend: IbisBackendLike) -> bool: + """Return whether cohort-scoped delete+insert can run transactionally.""" + return getattr(backend, "name", None) in {"duckdb", "postgres"} + + +def replace_cohort_rows_transactionally( + relation, + *, + backend: IbisBackendLike, + cohort_table: str, + results_schema: str | None, + cohort_id: int, +) -> None: + """Replace one cohort's rows atomically using delete+insert when supported.""" + if not supports_transactional_replace(backend): + raise ExecutionError( + "Ibis executor write error: backend does not support transactional cohort-table replace." + ) + + _run_transaction_control(backend, "BEGIN") + try: + delete_cohort_rows( + backend, + cohort_table=cohort_table, + results_schema=results_schema, + cohort_id=cohort_id, + ) + insert_relation( + relation, + backend=backend, + target_table=cohort_table, + target_schema=results_schema, + ) + except Exception: + _run_transaction_control(backend, "ROLLBACK") + raise + else: + _run_transaction_control(backend, "COMMIT") + + +def exclude_cohort_rows(table, *, cohort_id: int): + """Filter an existing cohort table to all cohort ids except one.""" + import ibis + + cohort_id_expr = ibis.literal(int(cohort_id), type="int64") + try: + return table.filter(table.cohort_definition_id.cast("int64") != cohort_id_expr) + except Exception as exc: + raise ExecutionError( + f"Ibis executor write error: failed removing existing rows for cohort_id={cohort_id}." + ) from exc + + +def insert_relation( + relation, + *, + backend: IbisBackendLike, + target_table: str, + target_schema: str | None, +) -> None: + """Insert an Ibis relation into an existing backend table.""" + insert = getattr(backend, "insert", None) + if not callable(insert): + raise ExecutionError( + "Ibis executor write error: backend does not support insert for cohort-table writes." 
+ ) + + try: + _call_with_optional_database( + insert, + target_table, + relation, + database=target_schema, + overwrite=False, + ) + except Exception as exc: + schema_label = target_schema if target_schema is not None else "" + raise ExecutionError( + "Ibis executor write error: failed inserting relation into " + f"table '{target_table}' in schema '{schema_label}'." + ) from exc + + +def _run_transaction_control(backend: IbisBackendLike, statement: str) -> None: + raw_sql = getattr(backend, "raw_sql", None) + if not callable(raw_sql): + raise ExecutionError( + "Ibis executor write error: backend does not support raw_sql for transactional cohort writes." + ) + + try: + raw_sql(statement) + except Exception as exc: + raise ExecutionError( + f"Ibis executor write error: failed executing transaction statement {statement!r}." + ) from exc + + +def _catalog_db_tuple(backend: IbisBackendLike, schema: str | None) -> tuple[str | None, str | None]: + if schema is None: + return None, None + + to_sqlglot_table = getattr(backend, "_to_sqlglot_table", None) + to_catalog_db_tuple = getattr(backend, "_to_catalog_db_tuple", None) + if callable(to_sqlglot_table) and callable(to_catalog_db_tuple): + try: + return to_catalog_db_tuple(to_sqlglot_table(schema)) + except Exception: + pass + + return None, schema diff --git a/circe/execution/ibis/person_filters.py b/circe/execution/ibis/person_filters.py new file mode 100644 index 00000000..b46bd998 --- /dev/null +++ b/circe/execution/ibis/person_filters.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import ibis + +from ..errors import CompilationError +from ..plan.predicates import NumericRangePredicate +from ..plan.schema import PERSON_ID +from .context import ExecutionContext + + +def _apply_numeric_predicate(expr, predicate: NumericRangePredicate): + op = (predicate.op or "eq").lower() + value = predicate.value + extent = predicate.extent + + if value is None: + return ibis.literal(True) + + if op in {"eq", "="}: + 
return expr == value + if op in {"neq", "!=", "ne"}: + return expr != value + if op in {"gt", ">"}: + return expr > value + if op in {"gte", ">="}: + return expr >= value + if op in {"lt", "<"}: + return expr < value + if op in {"lte", "<="}: + return expr <= value + if op in {"bt", "between"}: + if extent is None: + raise CompilationError( + "Ibis executor compilation error: person numeric range 'between' requires an extent value." + ) + lower = min(value, extent) + upper = max(value, extent) + return (expr >= lower) & (expr <= upper) + + raise CompilationError( + f"Ibis executor compilation error: unsupported person numeric range op {predicate.op!r}." + ) + + +def apply_person_age_filter(table, ctx: ExecutionContext, *, date_column: str, predicate): + person = ctx.table("person").select( + PERSON_ID, + "year_of_birth", + ) + joined = table.join(person, table[PERSON_ID] == person[PERSON_ID]) + event_date = joined[date_column].cast("date") + age_years = event_date.year() - joined.year_of_birth + filtered = joined.filter(_apply_numeric_predicate(age_years, predicate)) + return filtered.select(*[filtered[c] for c in table.columns]) + + +def apply_person_gender_filter( + table, + ctx: ExecutionContext, + *, + concept_ids: tuple[int, ...], + codeset_id: int | None, +): + all_ids = list(concept_ids) + if codeset_id is not None: + for cid in ctx.concept_ids_for_codeset(codeset_id): + if cid not in all_ids: + all_ids.append(cid) + + if not all_ids: + return table + + person = ctx.table("person").select(PERSON_ID, "gender_concept_id") + joined = table.join(person, table[PERSON_ID] == person[PERSON_ID]) + filtered = joined.filter(joined.gender_concept_id.isin(all_ids)) + return filtered.select(*[filtered[c] for c in table.columns]) + + +def _apply_person_concept_filter( + table, + ctx: ExecutionContext, + *, + person_column: str, + concept_ids: tuple[int, ...], + codeset_id: int | None, +): + all_ids = list(concept_ids) + if codeset_id is not None: + for cid in 
ctx.concept_ids_for_codeset(codeset_id): + if cid not in all_ids: + all_ids.append(cid) + + if not all_ids: + return table + + person = ctx.table("person").select(PERSON_ID, person_column) + joined = table.join(person, table[PERSON_ID] == person[PERSON_ID]) + filtered = joined.filter(joined[person_column].isin(all_ids)) + return filtered.select(*[filtered[c] for c in table.columns]) + + +def apply_person_race_filter( + table, + ctx: ExecutionContext, + *, + concept_ids: tuple[int, ...], + codeset_id: int | None, +): + return _apply_person_concept_filter( + table, + ctx, + person_column="race_concept_id", + concept_ids=concept_ids, + codeset_id=codeset_id, + ) + + +def apply_person_ethnicity_filter( + table, + ctx: ExecutionContext, + *, + concept_ids: tuple[int, ...], + codeset_id: int | None, +): + return _apply_person_concept_filter( + table, + ctx, + person_column="ethnicity_concept_id", + concept_ids=concept_ids, + codeset_id=codeset_id, + ) diff --git a/circe/execution/ibis/standardize.py b/circe/execution/ibis/standardize.py new file mode 100644 index 00000000..a04435c3 --- /dev/null +++ b/circe/execution/ibis/standardize.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import ibis + +from ..plan.events import EventSource +from ..plan.schema import ( + CONCEPT_ID, + CRITERION_INDEX, + CRITERION_TYPE, + DAYS_SUPPLY, + DOMAIN, + DURATION, + END_DATE, + EVENT_ID, + GAP_DAYS, + OCCURRENCE_COUNT, + PERSON_ID, + QUANTITY, + RANGE_HIGH, + RANGE_LOW, + REFILLS, + SOURCE_CONCEPT_ID, + SOURCE_TABLE, + START_DATE, + UNIT_CONCEPT_ID, + VALUE_AS_NUMBER, + VISIT_DETAIL_ID, + VISIT_OCCURRENCE_ID, +) + + +def _typed_optional_column(table, column_name: str | None, dtype: str): + if column_name and column_name in table.columns: + return table[column_name].cast(dtype) + return ibis.null().cast(dtype) + + +def _base_start_expr(table, *, source: EventSource): + return table[source.start_date_column].cast("date") + + +def _base_end_expr(table, *, source: EventSource, 
start_expr): + raw_end_expr = _typed_optional_column(table, source.end_date_column, "date") + + if source.table_name == "condition_occurrence": + return ibis.coalesce(raw_end_expr, start_expr + ibis.interval(days=1)) + if source.table_name == "drug_exposure": + days_supply_expr = _typed_optional_column(table, "days_supply", "int64") + supply_end_expr = start_expr + days_supply_expr.as_interval("D") + return ibis.coalesce(raw_end_expr, supply_end_expr, start_expr + ibis.interval(days=1)) + if source.table_name == "device_exposure": + return ibis.coalesce(raw_end_expr, start_expr + ibis.interval(days=1)) + if source.table_name in {"procedure_occurrence", "measurement", "observation", "death"}: + return start_expr + ibis.interval(days=1) + if source.table_name == "specimen": + return start_expr + + return raw_end_expr + + +def _adjust_dates( + start_expr, + end_expr, + *, + start_offset_days: int, + end_offset_days: int, + start_with: str, + end_with: str, +): + start_anchor = start_expr if start_with == START_DATE else end_expr + end_anchor = start_expr if end_with == START_DATE else end_expr + adjusted_start = start_anchor + ibis.interval(days=int(start_offset_days)) + adjusted_end = end_anchor + ibis.interval(days=int(end_offset_days)) + return adjusted_start, adjusted_end + + +def _duration_expr(*, source: EventSource, start_expr, end_expr): + if source.table_name in {"measurement", "observation"}: + return ibis.null().cast("int64") + if source.table_name in {"death", "specimen"}: + return ibis.literal(1, type="int64") + return end_expr.delta(start_expr, unit="day").cast("int64") + + +def _supplemental_exprs(table, *, source: EventSource, start_expr, end_expr) -> dict[str, object]: + value_as_number_expr = _typed_optional_column(table, "value_as_number", "float64") + if source.table_name == "dose_era" and "dose_value" in table.columns: + value_as_number_expr = table["dose_value"].cast("float64") + + unit_concept_expr = _typed_optional_column(table, 
"unit_concept_id", "int64") + if source.table_name == "drug_exposure" and "dose_unit_concept_id" in table.columns: + unit_concept_expr = table["dose_unit_concept_id"].cast("int64") + + occurrence_count_expr = ibis.null().cast("int64") + if "occurrence_count" in table.columns: + occurrence_count_expr = table["occurrence_count"].cast("int64") + elif "condition_occurrence_count" in table.columns: + occurrence_count_expr = table["condition_occurrence_count"].cast("int64") + elif "drug_exposure_count" in table.columns: + occurrence_count_expr = table["drug_exposure_count"].cast("int64") + + return { + QUANTITY: _typed_optional_column(table, "quantity", "float64"), + DAYS_SUPPLY: _typed_optional_column(table, "days_supply", "float64"), + REFILLS: _typed_optional_column(table, "refills", "float64"), + RANGE_LOW: _typed_optional_column(table, "range_low", "float64"), + RANGE_HIGH: _typed_optional_column(table, "range_high", "float64"), + VALUE_AS_NUMBER: value_as_number_expr, + UNIT_CONCEPT_ID: unit_concept_expr, + VISIT_DETAIL_ID: _typed_optional_column(table, "visit_detail_id", "int64"), + OCCURRENCE_COUNT: occurrence_count_expr, + GAP_DAYS: _typed_optional_column(table, "gap_days", "int64"), + DURATION: _duration_expr(source=source, start_expr=start_expr, end_expr=end_expr), + } + + +def standardize_event_table( + table, + *, + source: EventSource, + criterion_type: str, + criterion_index: int, + start_offset_days: int = 0, + end_offset_days: int = 0, + start_with: str = START_DATE, + end_with: str = END_DATE, +): + base_start_expr = _base_start_expr(table, source=source) + base_end_expr = _base_end_expr(table, source=source, start_expr=base_start_expr) + start_expr, end_expr = _adjust_dates( + base_start_expr, + base_end_expr, + start_offset_days=start_offset_days, + end_offset_days=end_offset_days, + start_with=start_with, + end_with=end_with, + ) + + concept_expr = ibis.null().cast("int64") + if source.concept_column and source.concept_column in table.columns: + 
concept_expr = table[source.concept_column].cast("int64") + if source.table_name == "death": + concept_expr = ibis.coalesce(concept_expr, ibis.literal(0, type="int64")) + + source_concept_expr = ibis.null().cast("int64") + if source.source_concept_column and source.source_concept_column in table.columns: + source_concept_expr = table[source.source_concept_column].cast("int64") + + visit_occ_expr = ibis.null().cast("int64") + if source.visit_occurrence_column and source.visit_occurrence_column in table.columns: + visit_occ_expr = table[source.visit_occurrence_column].cast("int64") + + supplemental_exprs = _supplemental_exprs( + table, + source=source, + start_expr=start_expr, + end_expr=end_expr, + ) + standardized = table.select( + table[source.person_id_column].cast("int64").name(PERSON_ID), + table[source.event_id_column].cast("int64").name(EVENT_ID), + start_expr.name(START_DATE), + end_expr.name(END_DATE), + ibis.literal(source.domain).name(DOMAIN), + concept_expr.name(CONCEPT_ID), + source_concept_expr.name(SOURCE_CONCEPT_ID), + visit_occ_expr.name(VISIT_OCCURRENCE_ID), + supplemental_exprs[VISIT_DETAIL_ID].name(VISIT_DETAIL_ID), + supplemental_exprs[QUANTITY].name(QUANTITY), + supplemental_exprs[DAYS_SUPPLY].name(DAYS_SUPPLY), + supplemental_exprs[REFILLS].name(REFILLS), + supplemental_exprs[RANGE_LOW].name(RANGE_LOW), + supplemental_exprs[RANGE_HIGH].name(RANGE_HIGH), + supplemental_exprs[VALUE_AS_NUMBER].name(VALUE_AS_NUMBER), + supplemental_exprs[UNIT_CONCEPT_ID].name(UNIT_CONCEPT_ID), + supplemental_exprs[OCCURRENCE_COUNT].name(OCCURRENCE_COUNT), + supplemental_exprs[GAP_DAYS].name(GAP_DAYS), + supplemental_exprs[DURATION].name(DURATION), + ibis.literal(int(criterion_index), type="int64").name(CRITERION_INDEX), + ibis.literal(criterion_type).name(CRITERION_TYPE), + ibis.literal(source.table_name).name(SOURCE_TABLE), + ) + return standardized diff --git a/circe/execution/ibis_compat.py b/circe/execution/ibis_compat.py index d5f215f9..2e4b379e 100644 --- 
a/circe/execution/ibis_compat.py +++ b/circe/execution/ibis_compat.py @@ -1,36 +1,49 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Iterable, Mapping, Sequence +from typing import Any import ibis import ibis.expr.operations as ops -import ibis.expr.types as ir from ibis.common.collections import FrozenOrderedDict +from .typing import IbisBackendLike, Table -def table_from_literal_list( - values: Iterable[int], + +def _is_nullish(value: Any) -> bool: + if value is None: + return True + try: + return bool(value != value) + except Exception: + return False + + +def _typed_literal(value: Any, *, dtype: str) -> Any: + if _is_nullish(value): + return ibis.null().cast(dtype) + return ibis.literal(value).cast(dtype) + + +def literal_column_relation( + values: Iterable[Any], *, column_name: str, - element_type: str = "int64", -) -> ir.Table: - """ - Build a 1-column table from a Python list without using `ibis.memtable`. - - This avoids Databricks' memtable upload machinery (which depends on a writable - Unity Catalog volume) while still producing a pure Ibis expression. 
- """ + dtype: str, + backend: IbisBackendLike | None = None, +) -> Table: + """Build a 1-column relation from Python literals without `ibis.memtable(...)`.""" + _ = backend values_list = list(values) if not values_list: dummy = ops.DummyTable( - values=FrozenOrderedDict({column_name: ibis.null().cast(element_type).op()}) + values=FrozenOrderedDict({column_name: ibis.null().cast(dtype).op()}) ).to_expr() return dummy.select(dummy[column_name]).filter(ibis.literal(False)) - array_type = f"array<{element_type}>" - arr = ibis.literal(values_list, type=array_type) - - dummy = ops.DummyTable(values=FrozenOrderedDict({"__values__": arr.op()})).to_expr() + array_type = f"array<{dtype}>" + literal_array = ibis.literal(values_list, type=array_type) + dummy = ops.DummyTable(values=FrozenOrderedDict({"__values__": literal_array.op()})).to_expr() unnested = ops.TableUnnest( dummy.op(), dummy["__values__"].op(), @@ -39,3 +52,52 @@ def table_from_literal_list( False, ).to_expr() return unnested.select(unnested[column_name]) + + +def _single_row_relation( + row: Mapping[str, Any], + *, + schema: Mapping[str, str], +) -> Table: + return ops.DummyTable( + values=FrozenOrderedDict( + {column: _typed_literal(row.get(column), dtype=dtype).op() for column, dtype in schema.items()} + ) + ).to_expr() + + +def literal_rows_relation( + rows: Sequence[Mapping[str, Any]], + *, + schema: Mapping[str, str], + backend: IbisBackendLike | None = None, +) -> Table: + """Build a typed relation from row dictionaries without `ibis.memtable(...)`.""" + _ = backend + if not schema: + raise ValueError("literal_rows_relation requires a non-empty schema.") + + if not rows: + empty_row = _single_row_relation( + dict.fromkeys(schema), + schema=schema, + ) + return empty_row.filter(ibis.literal(False)) + + relation: Table = _single_row_relation(rows[0], schema=schema) + for row in rows[1:]: + relation = relation.union(_single_row_relation(row, schema=schema), distinct=False) + return relation + + +def 
table_from_literal_list( + values: Iterable[int], + *, + column_name: str, + element_type: str = "int64", +) -> Table: + """Backward-compatible wrapper over `literal_column_relation`.""" + return literal_column_relation(values, column_name=column_name, dtype=element_type) + + +__all__ = ["literal_column_relation", "literal_rows_relation", "table_from_literal_list"] diff --git a/circe/execution/lower/__init__.py b/circe/execution/lower/__init__.py new file mode 100644 index 00000000..95d11a79 --- /dev/null +++ b/circe/execution/lower/__init__.py @@ -0,0 +1,3 @@ +from .criteria import LOWERERS, lower_criterion + +__all__ = ["LOWERERS", "lower_criterion"] diff --git a/circe/execution/lower/common.py b/circe/execution/lower/common.py new file mode 100644 index 00000000..3b474989 --- /dev/null +++ b/circe/execution/lower/common.py @@ -0,0 +1,375 @@ +from __future__ import annotations + +from ...cohortdefinition.core import ConceptSetSelection, NumericRange, TextFilter +from ...vocabulary.concept import Concept +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import ( + EventPlan, + EventSource, + FilterByCareSite, + FilterByCareSiteLocationRegion, + FilterByCodeset, + FilterByConceptSet, + FilterByDateRange, + FilterByNumericRange, + FilterByPersonAge, + FilterByPersonEthnicity, + FilterByPersonGender, + FilterByPersonRace, + FilterByProviderSpecialty, + FilterByText, + FilterByVisit, + KeepFirstPerPerson, + PlanStep, + StandardizeEventShape, +) +from ..plan.predicates import DateRangePredicate, NumericRangePredicate +from ..plan.schema import DURATION, END_DATE, START_DATE + + +def lower_common_steps(criterion: NormalizedCriterion) -> list[PlanStep]: + steps: list[PlanStep] = [] + + if criterion.codeset_id is not None and criterion.concept_column is not None: + steps.append( + FilterByCodeset( + column=criterion.concept_column, + codeset_id=int(criterion.codeset_id), + ) + ) + + if criterion.person_filters.gender_concept_ids or 
criterion.person_filters.gender_codeset_id is not None: + steps.append( + FilterByPersonGender( + concept_ids=criterion.person_filters.gender_concept_ids, + codeset_id=criterion.person_filters.gender_codeset_id, + ) + ) + + if criterion.person_filters.race_concept_ids or criterion.person_filters.race_codeset_id is not None: + steps.append( + FilterByPersonRace( + concept_ids=criterion.person_filters.race_concept_ids, + codeset_id=criterion.person_filters.race_codeset_id, + ) + ) + + if ( + criterion.person_filters.ethnicity_concept_ids + or criterion.person_filters.ethnicity_codeset_id is not None + ): + steps.append( + FilterByPersonEthnicity( + concept_ids=criterion.person_filters.ethnicity_concept_ids, + codeset_id=criterion.person_filters.ethnicity_codeset_id, + ) + ) + + if criterion.first: + steps.append( + KeepFirstPerPerson( + order_by=(criterion.start_date_column, criterion.event_id_column), + ) + ) + + return steps + + +def concept_ids(values: list[Concept] | None) -> tuple[int, ...]: + if not values: + return () + output: list[int] = [] + for concept in values: + if concept is None or concept.concept_id is None: + continue + cid = int(concept.concept_id) + if cid not in output: + output.append(cid) + return tuple(output) + + +def append_numeric_filter( + steps: list[PlanStep], + *, + column: str, + value: NumericRange | None, +) -> None: + if value is None: + return + steps.append( + FilterByNumericRange( + column=column, + predicate=NumericRangePredicate( + op=value.op, + value=value.value, + extent=value.extent, + ), + ) + ) + + +def append_text_filter( + steps: list[PlanStep], + *, + column: str, + value: TextFilter | None, +) -> None: + if value is None: + return + steps.append( + FilterByText( + column=column, + op=value.op, + text=value.text, + ) + ) + + +def append_concept_filters( + steps: list[PlanStep], + *, + column: str, + concepts: list[Concept] | None = None, + codeset_selection: ConceptSetSelection | None = None, + exclude: bool = False, 
+) -> None: + ids = concept_ids(concepts) + if ids: + steps.append( + FilterByConceptSet( + column=column, + concept_ids=ids, + exclude=bool(exclude), + ) + ) + + if codeset_selection and codeset_selection.codeset_id is not None: + steps.append( + FilterByCodeset( + column=column, + codeset_id=int(codeset_selection.codeset_id), + exclude=bool(codeset_selection.is_exclusion) or bool(exclude), + ) + ) + + +def append_visit_filters( + steps: list[PlanStep], + *, + visit_occurrence_column: str, + concepts: list[Concept] | None = None, + codeset_selection: ConceptSetSelection | None = None, + exclude: bool = False, +) -> None: + ids = concept_ids(concepts) + if ids: + steps.append( + FilterByVisit( + visit_occurrence_column=visit_occurrence_column, + concept_ids=ids, + exclude=bool(exclude), + ) + ) + + if codeset_selection and codeset_selection.codeset_id is not None: + steps.append( + FilterByVisit( + visit_occurrence_column=visit_occurrence_column, + codeset_id=int(codeset_selection.codeset_id), + exclude=bool(codeset_selection.is_exclusion), + ) + ) + + +def append_provider_specialty_filters( + steps: list[PlanStep], + *, + provider_id_column: str = "provider_id", + concepts: list[Concept] | None = None, + codeset_selection: ConceptSetSelection | None = None, +) -> None: + ids = concept_ids(concepts) + if ids: + steps.append( + FilterByProviderSpecialty( + provider_id_column=provider_id_column, + concept_ids=ids, + ) + ) + + if codeset_selection and codeset_selection.codeset_id is not None: + steps.append( + FilterByProviderSpecialty( + provider_id_column=provider_id_column, + codeset_id=int(codeset_selection.codeset_id), + exclude=bool(codeset_selection.is_exclusion), + ) + ) + + +def append_care_site_filters( + steps: list[PlanStep], + *, + care_site_id_column: str = "care_site_id", + concepts: list[Concept] | None = None, + codeset_selection: ConceptSetSelection | None = None, +) -> None: + ids = concept_ids(concepts) + if ids: + steps.append( + FilterByCareSite( 
+ care_site_id_column=care_site_id_column, + concept_ids=ids, + ) + ) + + if codeset_selection and codeset_selection.codeset_id is not None: + steps.append( + FilterByCareSite( + care_site_id_column=care_site_id_column, + codeset_id=int(codeset_selection.codeset_id), + exclude=bool(codeset_selection.is_exclusion), + ) + ) + + +def append_care_site_location_region_filter( + steps: list[PlanStep], + *, + care_site_id_column: str = "care_site_id", + start_date_column: str, + end_date_column: str, + codeset_id: int | None, +) -> None: + if codeset_id is None: + return + steps.append( + FilterByCareSiteLocationRegion( + care_site_id_column=care_site_id_column, + start_date_column=start_date_column, + end_date_column=end_date_column, + codeset_id=int(codeset_id), + ) + ) + + +def append_post_standardization_common_steps( + criterion: NormalizedCriterion, + *, + steps: list[PlanStep], +) -> None: + if criterion.person_filters.age is not None: + steps.append( + FilterByPersonAge( + date_column=START_DATE, + predicate=NumericRangePredicate( + op=criterion.person_filters.age.op, + value=criterion.person_filters.age.value, + extent=criterion.person_filters.age.extent, + ), + ) + ) + + if criterion.occurrence_start_date is not None: + steps.append( + FilterByDateRange( + column=START_DATE, + predicate=DateRangePredicate( + op=criterion.occurrence_start_date.op, + value=criterion.occurrence_start_date.value, + extent=criterion.occurrence_start_date.extent, + ), + ) + ) + + if criterion.occurrence_end_date is not None: + steps.append( + FilterByDateRange( + column=END_DATE, + predicate=DateRangePredicate( + op=criterion.occurrence_end_date.op, + value=criterion.occurrence_end_date.value, + extent=criterion.occurrence_end_date.extent, + ), + ) + ) + + +def append_duration_filter( + steps: list[PlanStep], + *, + value: NumericRange | None, +) -> None: + append_numeric_filter(steps, column=DURATION, value=value) + + +def build_standard_domain_plan( + criterion: NormalizedCriterion, 
+ *, + criterion_index: int, + steps: list[PlanStep], + post_standardize_steps: list[PlanStep] | None = None, +) -> EventPlan: + plan_steps = list(steps) + date_adjustment = getattr(criterion.raw_criteria, "date_adjustment", None) + start_with = START_DATE + end_with = END_DATE + start_offset_days = 0 + end_offset_days = 0 + if date_adjustment is not None: + start_with = ( + date_adjustment.start_with.value + if getattr(date_adjustment, "start_with", None) is not None + else START_DATE + ) + end_with = ( + date_adjustment.end_with.value + if getattr(date_adjustment, "end_with", None) is not None + else END_DATE + ) + start_offset_days = int(date_adjustment.start_offset) + end_offset_days = int(date_adjustment.end_offset) + + plan_steps.append( + StandardizeEventShape( + criterion_type=criterion.criterion_type, + criterion_index=criterion_index, + start_offset_days=start_offset_days, + end_offset_days=end_offset_days, + start_with=start_with, + end_with=end_with, + ) + ) + + standard_post_steps = list(post_standardize_steps or []) + append_post_standardization_common_steps(criterion, steps=standard_post_steps) + plan_steps.extend(standard_post_steps) + + return EventPlan( + source=EventSource( + table_name=criterion.source_table, + domain=criterion.domain, + event_id_column=criterion.event_id_column, + start_date_column=criterion.start_date_column, + end_date_column=criterion.end_date_column, + concept_column=criterion.concept_column, + source_concept_column=criterion.source_concept_column, + visit_occurrence_column=criterion.visit_occurrence_column, + ), + criterion_type=criterion.criterion_type, + criterion_index=criterion_index, + steps=tuple(plan_steps), + ) + + +def lower_standard_domain_plan( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + steps = lower_common_steps(criterion) + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + ) diff --git 
a/circe/execution/lower/condition_era.py b/circe/execution/lower/condition_era.py new file mode 100644 index 00000000..9162e510 --- /dev/null +++ b/circe/execution/lower/condition_era.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan +from ..plan.schema import OCCURRENCE_COUNT +from .common import ( + append_duration_filter, + append_numeric_filter, + build_standard_domain_plan, + lower_common_steps, +) + + +def lower_condition_era( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + steps = lower_common_steps(criterion) + post_standardize_steps = [] + raw = criterion.raw_criteria + + append_numeric_filter( + post_standardize_steps, + column=OCCURRENCE_COUNT, + value=raw.occurrence_count, + ) + append_duration_filter(post_standardize_steps, value=raw.era_length) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + post_standardize_steps=post_standardize_steps, + ) diff --git a/circe/execution/lower/condition_occurrence.py b/circe/execution/lower/condition_occurrence.py new file mode 100644 index 00000000..8d11c2dc --- /dev/null +++ b/circe/execution/lower/condition_occurrence.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from ...cohortdefinition.criteria import ConditionOccurrence +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan +from .common import ( + append_concept_filters, + append_provider_specialty_filters, + append_text_filter, + append_visit_filters, + build_standard_domain_plan, + lower_common_steps, +) + + +def lower_condition_occurrence( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + raw = criterion.raw_criteria + if not isinstance(raw, ConditionOccurrence): + raise TypeError("lower_condition_occurrence requires ConditionOccurrence criteria") + + steps = lower_common_steps(criterion) + + 
append_concept_filters( + steps, + column="condition_type_concept_id", + concepts=raw.condition_type, + codeset_selection=raw.condition_type_cs, + exclude=bool(raw.condition_type_exclude), + ) + append_text_filter(steps, column="stop_reason", value=raw.stop_reason) + append_provider_specialty_filters( + steps, + concepts=raw.provider_specialty, + codeset_selection=raw.provider_specialty_cs, + ) + append_visit_filters( + steps, + visit_occurrence_column="visit_occurrence_id", + concepts=raw.visit_type, + codeset_selection=raw.visit_type_cs, + ) + append_concept_filters( + steps, + column="condition_status_concept_id", + concepts=raw.condition_status, + codeset_selection=raw.condition_status_cs, + ) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + ) diff --git a/circe/execution/lower/criteria.py b/circe/execution/lower/criteria.py new file mode 100644 index 00000000..90171f78 --- /dev/null +++ b/circe/execution/lower/criteria.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import Protocol + +from ...cohortdefinition.criteria import ( + ConditionEra, + ConditionOccurrence, + Criteria, + Death, + DeviceExposure, + DoseEra, + DrugEra, + DrugExposure, + LocationRegion, + Measurement, + Observation, + ObservationPeriod, + PayerPlanPeriod, + ProcedureOccurrence, + Specimen, + VisitDetail, + VisitOccurrence, +) +from ..errors import UnsupportedCriterionError +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan +from .condition_era import lower_condition_era +from .condition_occurrence import lower_condition_occurrence +from .death import lower_death +from .device_exposure import lower_device_exposure +from .dose_era import lower_dose_era +from .drug_era import lower_drug_era +from .drug_exposure import lower_drug_exposure +from .location_region import lower_location_region +from .measurement import lower_measurement +from .observation import lower_observation 
+from .observation_period import lower_observation_period +from .payer_plan_period import lower_payer_plan_period +from .procedure_occurrence import lower_procedure_occurrence +from .specimen import lower_specimen +from .visit_detail import lower_visit_detail +from .visit_occurrence import lower_visit_occurrence + + +class LowerFn(Protocol): + def __call__( + self, + criterion: NormalizedCriterion, + *, + criterion_index: int, + ) -> EventPlan: ... + + +LOWERERS: dict[type[Criteria], LowerFn] = { + ConditionOccurrence: lower_condition_occurrence, + DrugExposure: lower_drug_exposure, + VisitOccurrence: lower_visit_occurrence, + Measurement: lower_measurement, + ProcedureOccurrence: lower_procedure_occurrence, + Observation: lower_observation, + VisitDetail: lower_visit_detail, + DeviceExposure: lower_device_exposure, + Specimen: lower_specimen, + Death: lower_death, + ObservationPeriod: lower_observation_period, + PayerPlanPeriod: lower_payer_plan_period, + ConditionEra: lower_condition_era, + DrugEra: lower_drug_era, + DoseEra: lower_dose_era, + LocationRegion: lower_location_region, +} + + +def lower_criterion( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + lowerer = LOWERERS.get(type(criterion.raw_criteria)) + if lowerer is not None: + return lowerer(criterion, criterion_index=criterion_index) + raise UnsupportedCriterionError( + f"Ibis executor lowering error: no lowerer registered for {criterion.criterion_type}." 
+ ) diff --git a/circe/execution/lower/death.py b/circe/execution/lower/death.py new file mode 100644 index 00000000..0c95c21d --- /dev/null +++ b/circe/execution/lower/death.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from ...cohortdefinition.criteria import Death +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan +from .common import append_concept_filters, build_standard_domain_plan, lower_common_steps + + +def lower_death( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + raw = criterion.raw_criteria + if not isinstance(raw, Death): + raise TypeError("lower_death requires Death criteria") + + steps = lower_common_steps(criterion) + + append_concept_filters( + steps, + column="death_type_concept_id", + concepts=raw.death_type, + codeset_selection=raw.death_type_cs, + exclude=bool(raw.death_type_exclude), + ) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + ) diff --git a/circe/execution/lower/device_exposure.py b/circe/execution/lower/device_exposure.py new file mode 100644 index 00000000..c5c2e19d --- /dev/null +++ b/circe/execution/lower/device_exposure.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from ...cohortdefinition.criteria import DeviceExposure +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan +from .common import ( + append_concept_filters, + append_numeric_filter, + append_provider_specialty_filters, + append_text_filter, + append_visit_filters, + build_standard_domain_plan, + lower_common_steps, +) + + +def lower_device_exposure( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + raw = criterion.raw_criteria + if not isinstance(raw, DeviceExposure): + raise TypeError("lower_device_exposure requires DeviceExposure criteria") + + steps = lower_common_steps(criterion) + + append_concept_filters( + steps, + 
column="device_type_concept_id", + concepts=raw.device_type, + codeset_selection=raw.device_type_cs, + exclude=bool(raw.device_type_exclude), + ) + append_text_filter(steps, column="unique_device_id", value=raw.unique_device_id) + append_numeric_filter(steps, column="quantity", value=raw.quantity) + append_provider_specialty_filters( + steps, + concepts=raw.provider_specialty, + codeset_selection=raw.provider_specialty_cs, + ) + append_visit_filters( + steps, + visit_occurrence_column="visit_occurrence_id", + concepts=raw.visit_type, + codeset_selection=raw.visit_type_cs, + ) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + ) diff --git a/circe/execution/lower/dose_era.py b/circe/execution/lower/dose_era.py new file mode 100644 index 00000000..503a5a8a --- /dev/null +++ b/circe/execution/lower/dose_era.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan +from .common import append_duration_filter, build_standard_domain_plan, lower_common_steps + + +def lower_dose_era( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + steps = lower_common_steps(criterion) + post_standardize_steps = [] + append_duration_filter(post_standardize_steps, value=criterion.raw_criteria.era_length) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + post_standardize_steps=post_standardize_steps, + ) diff --git a/circe/execution/lower/drug_era.py b/circe/execution/lower/drug_era.py new file mode 100644 index 00000000..02afe74c --- /dev/null +++ b/circe/execution/lower/drug_era.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan +from ..plan.schema import GAP_DAYS, OCCURRENCE_COUNT +from .common import ( + append_duration_filter, + append_numeric_filter, + 
build_standard_domain_plan, + lower_common_steps, +) + + +def lower_drug_era( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + steps = lower_common_steps(criterion) + post_standardize_steps = [] + raw = criterion.raw_criteria + + append_numeric_filter( + post_standardize_steps, + column=OCCURRENCE_COUNT, + value=raw.occurrence_count, + ) + append_numeric_filter( + post_standardize_steps, + column=GAP_DAYS, + value=raw.gap_days, + ) + append_duration_filter(post_standardize_steps, value=raw.era_length) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + post_standardize_steps=post_standardize_steps, + ) diff --git a/circe/execution/lower/drug_exposure.py b/circe/execution/lower/drug_exposure.py new file mode 100644 index 00000000..cdbbf93b --- /dev/null +++ b/circe/execution/lower/drug_exposure.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from ...cohortdefinition.criteria import DrugExposure +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan +from .common import ( + append_concept_filters, + append_numeric_filter, + append_provider_specialty_filters, + append_text_filter, + append_visit_filters, + build_standard_domain_plan, + lower_common_steps, +) + + +def lower_drug_exposure( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + raw = criterion.raw_criteria + if not isinstance(raw, DrugExposure): + raise TypeError("lower_drug_exposure requires DrugExposure criteria") + + steps = lower_common_steps(criterion) + + append_concept_filters( + steps, + column="drug_type_concept_id", + concepts=raw.drug_type, + codeset_selection=raw.drug_type_cs, + exclude=bool(raw.drug_type_exclude), + ) + append_text_filter(steps, column="stop_reason", value=raw.stop_reason) + append_concept_filters( + steps, + column="route_concept_id", + concepts=raw.route_concept, + codeset_selection=raw.route_concept_cs, + ) + 
append_concept_filters( + steps, + column="dose_unit_concept_id", + concepts=raw.dose_unit, + codeset_selection=raw.dose_unit_cs, + ) + append_text_filter(steps, column="lot_number", value=raw.lot_number) + append_numeric_filter(steps, column="refills", value=raw.refills) + append_numeric_filter(steps, column="quantity", value=raw.quantity) + append_numeric_filter(steps, column="days_supply", value=raw.days_supply) + append_provider_specialty_filters( + steps, + concepts=raw.provider_specialty, + codeset_selection=raw.provider_specialty_cs, + ) + append_visit_filters( + steps, + visit_occurrence_column="visit_occurrence_id", + concepts=raw.visit_type, + codeset_selection=raw.visit_type_cs, + ) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + ) diff --git a/circe/execution/lower/location_region.py b/circe/execution/lower/location_region.py new file mode 100644 index 00000000..21afdfae --- /dev/null +++ b/circe/execution/lower/location_region.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import ( + EventPlan, + EventSource, + FilterByCodeset, + FilterByText, + JoinLocationRegion, + StandardizeEventShape, +) + + +def lower_location_region( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + steps = [ + FilterByText(column="domain_id", op="eq", text="PERSON"), + JoinLocationRegion(location_id_column="location_id", region_column="region_concept_id"), + ] + if criterion.codeset_id is not None: + steps.append(FilterByCodeset(column="region_concept_id", codeset_id=int(criterion.codeset_id))) + steps.append( + StandardizeEventShape( + criterion_type=criterion.criterion_type, + criterion_index=criterion_index, + ) + ) + + return EventPlan( + source=EventSource( + table_name=criterion.source_table, + domain=criterion.domain, + event_id_column=criterion.event_id_column, + 
start_date_column=criterion.start_date_column, + end_date_column=criterion.end_date_column, + person_id_column="entity_id", + concept_column=criterion.concept_column, + source_concept_column=criterion.source_concept_column, + visit_occurrence_column=criterion.visit_occurrence_column, + ), + criterion_type=criterion.criterion_type, + criterion_index=criterion_index, + steps=tuple(steps), + ) diff --git a/circe/execution/lower/measurement.py b/circe/execution/lower/measurement.py new file mode 100644 index 00000000..d2491945 --- /dev/null +++ b/circe/execution/lower/measurement.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from ...cohortdefinition.criteria import Measurement +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan +from .common import ( + append_concept_filters, + append_numeric_filter, + append_provider_specialty_filters, + append_text_filter, + append_visit_filters, + build_standard_domain_plan, + lower_common_steps, +) + + +def lower_measurement( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + raw = criterion.raw_criteria + if not isinstance(raw, Measurement): + raise TypeError("lower_measurement requires Measurement criteria") + + steps = lower_common_steps(criterion) + + append_concept_filters( + steps, + column="measurement_type_concept_id", + concepts=raw.measurement_type, + codeset_selection=raw.measurement_type_cs, + exclude=bool(raw.measurement_type_exclude), + ) + append_concept_filters( + steps, + column="operator_concept_id", + concepts=raw.operator, + codeset_selection=raw.operator_cs, + ) + append_numeric_filter(steps, column="value_as_number", value=raw.value_as_number) + append_text_filter(steps, column="value_as_string", value=raw.value_as_string) + append_concept_filters( + steps, + column="value_as_concept_id", + concepts=raw.value_as_concept, + codeset_selection=raw.value_as_concept_cs, + ) + append_concept_filters( + steps, + 
column="unit_concept_id", + concepts=raw.unit, + codeset_selection=raw.unit_cs, + ) + append_numeric_filter(steps, column="range_low", value=raw.range_low) + append_numeric_filter(steps, column="range_high", value=raw.range_high) + append_provider_specialty_filters( + steps, + concepts=raw.provider_specialty, + codeset_selection=raw.provider_specialty_cs, + ) + append_visit_filters( + steps, + visit_occurrence_column="visit_occurrence_id", + concepts=raw.visit_type, + codeset_selection=raw.visit_type_cs, + ) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + ) diff --git a/circe/execution/lower/observation.py b/circe/execution/lower/observation.py new file mode 100644 index 00000000..7ff85c21 --- /dev/null +++ b/circe/execution/lower/observation.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from ...cohortdefinition.criteria import Observation +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan +from .common import ( + append_concept_filters, + append_numeric_filter, + append_provider_specialty_filters, + append_text_filter, + append_visit_filters, + build_standard_domain_plan, + lower_common_steps, +) + + +def lower_observation( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + raw = criterion.raw_criteria + if not isinstance(raw, Observation): + raise TypeError("lower_observation requires Observation criteria") + + steps = lower_common_steps(criterion) + + append_concept_filters( + steps, + column="observation_type_concept_id", + concepts=raw.observation_type, + codeset_selection=raw.observation_type_cs, + exclude=bool(raw.observation_type_exclude), + ) + append_numeric_filter(steps, column="value_as_number", value=raw.value_as_number) + append_text_filter(steps, column="value_as_string", value=raw.value_as_string) + append_concept_filters( + steps, + column="value_as_concept_id", + concepts=raw.value_as_concept, + 
codeset_selection=raw.value_as_concept_cs, + ) + append_concept_filters( + steps, + column="unit_concept_id", + concepts=raw.unit, + codeset_selection=raw.unit_cs, + ) + append_concept_filters( + steps, + column="qualifier_concept_id", + concepts=raw.qualifier, + codeset_selection=raw.qualifier_cs, + ) + append_provider_specialty_filters( + steps, + concepts=raw.provider_specialty, + codeset_selection=raw.provider_specialty_cs, + ) + append_visit_filters( + steps, + visit_occurrence_column="visit_occurrence_id", + concepts=raw.visit_type, + codeset_selection=raw.visit_type_cs, + ) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + ) diff --git a/circe/execution/lower/observation_period.py b/circe/execution/lower/observation_period.py new file mode 100644 index 00000000..ef2b8a90 --- /dev/null +++ b/circe/execution/lower/observation_period.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan +from .common import lower_standard_domain_plan + + +def lower_observation_period( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + return lower_standard_domain_plan(criterion, criterion_index=criterion_index) diff --git a/circe/execution/lower/payer_plan_period.py b/circe/execution/lower/payer_plan_period.py new file mode 100644 index 00000000..3a08d1ea --- /dev/null +++ b/circe/execution/lower/payer_plan_period.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan +from .common import lower_standard_domain_plan + + +def lower_payer_plan_period( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + return lower_standard_domain_plan(criterion, criterion_index=criterion_index) diff --git a/circe/execution/lower/procedure_occurrence.py 
b/circe/execution/lower/procedure_occurrence.py new file mode 100644 index 00000000..caded791 --- /dev/null +++ b/circe/execution/lower/procedure_occurrence.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from ...cohortdefinition.criteria import ProcedureOccurrence +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan +from .common import ( + append_concept_filters, + append_numeric_filter, + append_provider_specialty_filters, + append_visit_filters, + build_standard_domain_plan, + lower_common_steps, +) + + +def lower_procedure_occurrence( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + raw = criterion.raw_criteria + if not isinstance(raw, ProcedureOccurrence): + raise TypeError("lower_procedure_occurrence requires ProcedureOccurrence criteria") + + steps = lower_common_steps(criterion) + + append_concept_filters( + steps, + column="procedure_type_concept_id", + concepts=raw.procedure_type, + codeset_selection=raw.procedure_type_cs, + exclude=bool(raw.procedure_type_exclude), + ) + append_concept_filters( + steps, + column="modifier_concept_id", + concepts=raw.modifier, + codeset_selection=raw.modifier_cs, + ) + append_numeric_filter(steps, column="quantity", value=raw.quantity) + append_provider_specialty_filters( + steps, + concepts=raw.provider_specialty, + codeset_selection=raw.provider_specialty_cs, + ) + append_visit_filters( + steps, + visit_occurrence_column="visit_occurrence_id", + concepts=raw.visit_type, + codeset_selection=raw.visit_type_cs, + ) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + ) diff --git a/circe/execution/lower/specimen.py b/circe/execution/lower/specimen.py new file mode 100644 index 00000000..02710630 --- /dev/null +++ b/circe/execution/lower/specimen.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from ...cohortdefinition.criteria import Specimen +from ..normalize.criteria import 
NormalizedCriterion +from ..plan.events import EventPlan +from .common import ( + append_concept_filters, + append_numeric_filter, + append_text_filter, + build_standard_domain_plan, + lower_common_steps, +) + + +def lower_specimen( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + raw = criterion.raw_criteria + if not isinstance(raw, Specimen): + raise TypeError("lower_specimen requires Specimen criteria") + + steps = lower_common_steps(criterion) + + append_concept_filters( + steps, + column="specimen_type_concept_id", + concepts=raw.specimen_type, + codeset_selection=raw.specimen_type_cs, + exclude=bool(raw.specimen_type_exclude), + ) + append_numeric_filter(steps, column="quantity", value=raw.quantity) + append_concept_filters( + steps, + column="unit_concept_id", + concepts=raw.unit, + codeset_selection=raw.unit_cs, + ) + append_concept_filters( + steps, + column="anatomic_site_concept_id", + concepts=raw.anatomic_site, + codeset_selection=raw.anatomic_site_cs, + ) + append_concept_filters( + steps, + column="disease_status_concept_id", + concepts=raw.disease_status, + codeset_selection=raw.disease_status_cs, + ) + append_text_filter(steps, column="specimen_source_id", value=raw.source_id) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + ) diff --git a/circe/execution/lower/visit_detail.py b/circe/execution/lower/visit_detail.py new file mode 100644 index 00000000..73985026 --- /dev/null +++ b/circe/execution/lower/visit_detail.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from ...cohortdefinition.criteria import VisitDetail +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan, PlanStep +from .common import ( + append_care_site_filters, + append_care_site_location_region_filter, + append_concept_filters, + append_duration_filter, + append_provider_specialty_filters, + build_standard_domain_plan, + lower_common_steps, +) + + 
+def lower_visit_detail( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + raw = criterion.raw_criteria + if not isinstance(raw, VisitDetail): + raise TypeError("lower_visit_detail requires VisitDetail criteria") + + steps = lower_common_steps(criterion) + post_standardize_steps: list[PlanStep] = [] + + append_concept_filters( + steps, + column="visit_detail_type_concept_id", + concepts=raw.visit_detail_type, + codeset_selection=raw.visit_detail_type_cs, + exclude=bool(raw.visit_detail_type_exclude), + ) + append_concept_filters( + steps, + column="discharge_to_concept_id", + concepts=raw.discharge_to, + codeset_selection=raw.discharge_to_cs, + ) + append_provider_specialty_filters( + steps, + concepts=raw.provider_specialty, + codeset_selection=raw.provider_specialty_cs, + ) + append_care_site_filters( + steps, + concepts=raw.place_of_service, + codeset_selection=raw.place_of_service_cs, + ) + append_care_site_location_region_filter( + steps, + start_date_column=criterion.start_date_column, + end_date_column=criterion.end_date_column, + codeset_id=raw.place_of_service_location, + ) + append_duration_filter(post_standardize_steps, value=raw.visit_detail_length) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + post_standardize_steps=post_standardize_steps, + ) diff --git a/circe/execution/lower/visit_occurrence.py b/circe/execution/lower/visit_occurrence.py new file mode 100644 index 00000000..ef7e9d9e --- /dev/null +++ b/circe/execution/lower/visit_occurrence.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from ...cohortdefinition.criteria import VisitOccurrence +from ..normalize.criteria import NormalizedCriterion +from ..plan.events import EventPlan, PlanStep +from .common import ( + append_care_site_filters, + append_care_site_location_region_filter, + append_concept_filters, + append_duration_filter, + append_provider_specialty_filters, + 
build_standard_domain_plan, + lower_common_steps, +) + + +def lower_visit_occurrence( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + raw = criterion.raw_criteria + if not isinstance(raw, VisitOccurrence): + raise TypeError("lower_visit_occurrence requires VisitOccurrence criteria") + + steps = lower_common_steps(criterion) + post_standardize_steps: list[PlanStep] = [] + + append_concept_filters( + steps, + column="visit_type_concept_id", + concepts=raw.visit_type, + codeset_selection=raw.visit_type_cs, + exclude=bool(raw.visit_type_exclude), + ) + append_provider_specialty_filters( + steps, + concepts=raw.provider_specialty, + codeset_selection=raw.provider_specialty_cs, + ) + append_care_site_filters( + steps, + concepts=raw.place_of_service, + codeset_selection=raw.place_of_service_cs, + ) + append_care_site_location_region_filter( + steps, + start_date_column=criterion.start_date_column, + end_date_column=criterion.end_date_column, + codeset_id=raw.place_of_service_location, + ) + append_duration_filter(post_standardize_steps, value=raw.visit_length) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + post_standardize_steps=post_standardize_steps, + ) diff --git a/circe/execution/normalize/__init__.py b/circe/execution/normalize/__init__.py new file mode 100644 index 00000000..f5ec8a35 --- /dev/null +++ b/circe/execution/normalize/__init__.py @@ -0,0 +1,54 @@ +from .cohort import ( + NormalizedCohort, + NormalizedConceptSet, + NormalizedConceptSetItem, + NormalizedPrimaryCriteria, + normalize_cohort, +) +from .collapse import NormalizedCollapseSettings, normalize_collapse_settings +from .criteria import NormalizedCriterion, NormalizedPersonFilters, normalize_criterion +from .end_strategy import NormalizedEndStrategy +from .groups import ( + NormalizedCorrelatedCriteria, + NormalizedCriteriaGroup, + NormalizedDemographicCriteria, + NormalizedInclusionRule, + 
normalize_criteria_group, + normalize_inclusion_rule, +) +from .windows import ( + NormalizedDateRange, + NormalizedNumericRange, + NormalizedObservationWindow, + NormalizedPeriod, + NormalizedWindow, + NormalizedWindowBound, + normalize_period, +) + +__all__ = [ + "normalize_cohort", + "normalize_criterion", + "normalize_collapse_settings", + "normalize_period", + "NormalizedCohort", + "NormalizedConceptSet", + "NormalizedConceptSetItem", + "NormalizedPrimaryCriteria", + "NormalizedCollapseSettings", + "NormalizedCriterion", + "NormalizedPersonFilters", + "NormalizedEndStrategy", + "NormalizedCorrelatedCriteria", + "NormalizedCriteriaGroup", + "NormalizedDemographicCriteria", + "NormalizedInclusionRule", + "normalize_criteria_group", + "normalize_inclusion_rule", + "NormalizedDateRange", + "NormalizedNumericRange", + "NormalizedObservationWindow", + "NormalizedPeriod", + "NormalizedWindow", + "NormalizedWindowBound", +] diff --git a/circe/execution/normalize/cohort.py b/circe/execution/normalize/cohort.py new file mode 100644 index 00000000..5838fe64 --- /dev/null +++ b/circe/execution/normalize/cohort.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +from ...cohortdefinition import CohortExpression +from ...vocabulary.concept import ConceptSet +from .._dataclass import frozen_slots_dataclass +from ..errors import ExecutionNormalizationError, UnsupportedFeatureError +from .collapse import NormalizedCollapseSettings, normalize_collapse_settings +from .criteria import NormalizedCriterion, normalize_criterion +from .end_strategy import NormalizedEndStrategy, normalize_end_strategy +from .groups import ( + NormalizedCriteriaGroup, + NormalizedInclusionRule, + normalize_criteria_group, + normalize_inclusion_rule, +) +from .windows import ( + NormalizedObservationWindow, + NormalizedPeriod, + normalize_observation_window, + normalize_period, +) + + +@frozen_slots_dataclass +class NormalizedPrimaryCriteria: + criteria: tuple[NormalizedCriterion, ...] 
+ observation_window: NormalizedObservationWindow | None + primary_limit_type: str + + +@frozen_slots_dataclass +class NormalizedResultLimits: + qualified_limit_type: str + expression_limit_type: str + + +@frozen_slots_dataclass +class NormalizedConceptSetItem: + concept_id: int + is_excluded: bool + include_descendants: bool + include_mapped: bool + + +@frozen_slots_dataclass +class NormalizedConceptSet: + set_id: int + items: tuple[NormalizedConceptSetItem, ...] + + +@frozen_slots_dataclass +class NormalizedCohort: + title: str | None + concept_sets: dict[int, NormalizedConceptSet] + primary: NormalizedPrimaryCriteria + result_limits: NormalizedResultLimits + additional_criteria: NormalizedCriteriaGroup | None + inclusion_rules: tuple[NormalizedInclusionRule, ...] + censoring_criteria: tuple[NormalizedCriterion, ...] + censor_window: NormalizedPeriod | None + collapse_settings: NormalizedCollapseSettings | None + end_strategy: NormalizedEndStrategy | None + + +def _normalized_item( + *, + concept_id: int, + is_excluded: bool, + include_descendants: bool, + include_mapped: bool, +) -> NormalizedConceptSetItem: + return NormalizedConceptSetItem( + concept_id=int(concept_id), + is_excluded=bool(is_excluded), + include_descendants=bool(include_descendants), + include_mapped=bool(include_mapped), + ) + + +def _extract_codesets(concept_sets: list[ConceptSet]) -> dict[int, NormalizedConceptSet]: + output: dict[int, NormalizedConceptSet] = {} + + for concept_set in concept_sets or []: + if concept_set is None or concept_set.id is None: + continue + set_id = int(concept_set.id) + expression = concept_set.expression + if not expression: + continue + + items: list[NormalizedConceptSetItem] = [] + + if expression.concept is not None and expression.concept.concept_id is not None: + items.append( + _normalized_item( + concept_id=int(expression.concept.concept_id), + is_excluded=bool(expression.is_excluded), + include_descendants=bool(expression.include_descendants), + 
include_mapped=bool(expression.include_mapped), + ) + ) + + for item in expression.items or []: + if item is None: + continue + if item.concept is None or item.concept.concept_id is None: + continue + items.append( + _normalized_item( + concept_id=int(item.concept.concept_id), + is_excluded=bool(item.is_excluded), + include_descendants=bool(item.include_descendants), + include_mapped=bool(item.include_mapped), + ) + ) + + output[set_id] = NormalizedConceptSet( + set_id=set_id, + items=tuple(items), + ) + + return output + + +def normalize_cohort( + expression: CohortExpression, +) -> NormalizedCohort: + primary = expression.primary_criteria + if primary is None or not primary.criteria_list: + raise ExecutionNormalizationError( + "Ibis executor normalization error: CohortExpression must contain at least one primary criterion." + ) + + normalized_criteria = tuple(normalize_criterion(criteria) for criteria in primary.criteria_list) + normalized_primary = NormalizedPrimaryCriteria( + criteria=normalized_criteria, + observation_window=normalize_observation_window(primary.observation_window), + primary_limit_type=( + (primary.primary_limit.type if primary.primary_limit else "all") or "all" + ).lower(), + ) + normalized_limits = NormalizedResultLimits( + qualified_limit_type=( + (expression.qualified_limit.type if expression.qualified_limit else "all") or "all" + ).lower(), + expression_limit_type=( + (expression.expression_limit.type if expression.expression_limit else "all") or "all" + ).lower(), + ) + + normalized_end_strategy = normalize_end_strategy(expression.end_strategy) + if normalized_end_strategy is not None and normalized_end_strategy.kind == "custom_era": + raise UnsupportedFeatureError( + "Ibis executor normalization error: custom_era end strategy is not supported." 
+ ) + + return NormalizedCohort( + title=expression.title, + concept_sets=_extract_codesets(expression.concept_sets), + primary=normalized_primary, + result_limits=normalized_limits, + additional_criteria=normalize_criteria_group(expression.additional_criteria), + inclusion_rules=tuple(normalize_inclusion_rule(rule) for rule in expression.inclusion_rules), + censoring_criteria=tuple(normalize_criterion(criteria) for criteria in expression.censoring_criteria), + censor_window=normalize_period(expression.censor_window), + collapse_settings=normalize_collapse_settings(expression.collapse_settings), + end_strategy=normalized_end_strategy, + ) diff --git a/circe/execution/normalize/collapse.py b/circe/execution/normalize/collapse.py new file mode 100644 index 00000000..1ecd18c0 --- /dev/null +++ b/circe/execution/normalize/collapse.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from ...cohortdefinition.core import CollapseSettings +from .._dataclass import frozen_slots_dataclass + + +@frozen_slots_dataclass +class NormalizedCollapseSettings: + era_pad: int + collapse_type: str + + +def normalize_collapse_settings( + value: CollapseSettings | None, +) -> NormalizedCollapseSettings | None: + if value is None: + return None + collapse_type = "era" + if value.collapse_type is not None: + collapse_type = str(value.collapse_type).lower() + return NormalizedCollapseSettings( + era_pad=int(value.era_pad), + collapse_type=collapse_type, + ) diff --git a/circe/execution/normalize/criteria.py b/circe/execution/normalize/criteria.py new file mode 100644 index 00000000..f0af5979 --- /dev/null +++ b/circe/execution/normalize/criteria.py @@ -0,0 +1,493 @@ +from __future__ import annotations + +from dataclasses import replace +from typing import TYPE_CHECKING + +from ...cohortdefinition.criteria import ( + ConditionEra, + ConditionOccurrence, + Criteria, + Death, + DeviceExposure, + DoseEra, + DrugEra, + DrugExposure, + LocationRegion, + Measurement, + Observation, + 
ObservationPeriod, + PayerPlanPeriod, + ProcedureOccurrence, + Specimen, + VisitDetail, + VisitOccurrence, +) +from ...vocabulary.concept import Concept +from .._dataclass import frozen_slots_dataclass +from ..errors import UnsupportedCriterionError +from .windows import ( + NormalizedDateRange, + NormalizedNumericRange, + normalize_date_range, + normalize_numeric_range, +) + +if TYPE_CHECKING: + from .groups import NormalizedCriteriaGroup + + +@frozen_slots_dataclass +class NormalizedPersonFilters: + age: NormalizedNumericRange | None = None + gender_concept_ids: tuple[int, ...] = () + gender_codeset_id: int | None = None + race_concept_ids: tuple[int, ...] = () + race_codeset_id: int | None = None + ethnicity_concept_ids: tuple[int, ...] = () + ethnicity_codeset_id: int | None = None + + +@frozen_slots_dataclass +class NormalizedCriterion: + raw_criteria: Criteria + criterion_type: str + domain: str + source_table: str + event_id_column: str + start_date_column: str + end_date_column: str + concept_column: str | None + source_concept_column: str | None + visit_occurrence_column: str | None + codeset_id: int | None + first: bool + occurrence_start_date: NormalizedDateRange | None + occurrence_end_date: NormalizedDateRange | None + person_filters: NormalizedPersonFilters + correlated_criteria: NormalizedCriteriaGroup | None = None + + +def _concept_ids(values: list[Concept] | None) -> tuple[int, ...]: + if not values: + return () + output: list[int] = [] + for concept in values: + if concept is None or concept.concept_id is None: + continue + cid = int(concept.concept_id) + if cid not in output: + output.append(cid) + return tuple(output) + + +def _person_filters_from_criterion(criteria: Criteria) -> NormalizedPersonFilters: + return NormalizedPersonFilters( + age=normalize_numeric_range(getattr(criteria, "age", None)), + gender_concept_ids=_concept_ids(getattr(criteria, "gender", None)), + gender_codeset_id=( + int(criteria.gender_cs.codeset_id) + if 
getattr(criteria, "gender_cs", None) and criteria.gender_cs.codeset_id is not None + else None + ), + race_concept_ids=_concept_ids(getattr(criteria, "race", None)), + race_codeset_id=( + int(criteria.race_cs.codeset_id) + if getattr(criteria, "race_cs", None) and criteria.race_cs.codeset_id is not None + else None + ), + ethnicity_concept_ids=_concept_ids(getattr(criteria, "ethnicity", None)), + ethnicity_codeset_id=( + int(criteria.ethnicity_cs.codeset_id) + if getattr(criteria, "ethnicity_cs", None) and criteria.ethnicity_cs.codeset_id is not None + else None + ), + ) + + +def _build_normalized_criterion( + *, + criteria: Criteria, + criterion_type: str, + domain: str, + source_table: str, + event_id_column: str, + start_date_column: str, + end_date_column: str, + concept_column: str | None, + source_concept_column: str | None, + visit_occurrence_column: str | None, + codeset_id: int | None, + first: bool, + occurrence_start_date: NormalizedDateRange | None, + occurrence_end_date: NormalizedDateRange | None, +) -> NormalizedCriterion: + return NormalizedCriterion( + raw_criteria=criteria, + criterion_type=criterion_type, + domain=domain, + source_table=source_table, + event_id_column=event_id_column, + start_date_column=start_date_column, + end_date_column=end_date_column, + concept_column=concept_column, + source_concept_column=source_concept_column, + visit_occurrence_column=visit_occurrence_column, + codeset_id=codeset_id, + first=first, + occurrence_start_date=occurrence_start_date, + occurrence_end_date=occurrence_end_date, + person_filters=_person_filters_from_criterion(criteria), + ) + + +def _normalize_condition_occurrence(criteria: ConditionOccurrence) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="ConditionOccurrence", + domain="condition_occurrence", + source_table="condition_occurrence", + event_id_column="condition_occurrence_id", + start_date_column="condition_start_date", + 
end_date_column="condition_end_date", + concept_column="condition_concept_id", + source_concept_column="condition_source_concept_id", + visit_occurrence_column="visit_occurrence_id", + codeset_id=criteria.codeset_id, + first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.occurrence_start_date), + occurrence_end_date=normalize_date_range(criteria.occurrence_end_date), + ) + + +def _normalize_drug_exposure(criteria: DrugExposure) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="DrugExposure", + domain="drug_exposure", + source_table="drug_exposure", + event_id_column="drug_exposure_id", + start_date_column="drug_exposure_start_date", + end_date_column="drug_exposure_end_date", + concept_column="drug_concept_id", + source_concept_column="drug_source_concept_id", + visit_occurrence_column="visit_occurrence_id", + codeset_id=criteria.codeset_id, + first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.occurrence_start_date), + occurrence_end_date=normalize_date_range(criteria.occurrence_end_date), + ) + + +def _normalize_visit_occurrence(criteria: VisitOccurrence) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="VisitOccurrence", + domain="visit_occurrence", + source_table="visit_occurrence", + event_id_column="visit_occurrence_id", + start_date_column="visit_start_date", + end_date_column="visit_end_date", + concept_column="visit_concept_id", + source_concept_column="visit_source_concept_id", + visit_occurrence_column="visit_occurrence_id", + codeset_id=criteria.codeset_id, + first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.occurrence_start_date), + occurrence_end_date=normalize_date_range(criteria.occurrence_end_date), + ) + + +def _normalize_measurement(criteria: Measurement) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + 
criterion_type="Measurement", + domain="measurement", + source_table="measurement", + event_id_column="measurement_id", + start_date_column="measurement_date", + end_date_column="measurement_date", + concept_column="measurement_concept_id", + source_concept_column="measurement_source_concept_id", + visit_occurrence_column="visit_occurrence_id", + codeset_id=criteria.codeset_id, + first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.occurrence_start_date), + occurrence_end_date=normalize_date_range(criteria.occurrence_end_date), + ) + + +def _normalize_procedure_occurrence( + criteria: ProcedureOccurrence, +) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="ProcedureOccurrence", + domain="procedure_occurrence", + source_table="procedure_occurrence", + event_id_column="procedure_occurrence_id", + start_date_column="procedure_date", + end_date_column="procedure_date", + concept_column="procedure_concept_id", + source_concept_column="procedure_source_concept_id", + visit_occurrence_column="visit_occurrence_id", + codeset_id=criteria.codeset_id, + first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.occurrence_start_date), + occurrence_end_date=normalize_date_range(criteria.occurrence_end_date), + ) + + +def _normalize_observation(criteria: Observation) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="Observation", + domain="observation", + source_table="observation", + event_id_column="observation_id", + start_date_column="observation_date", + end_date_column="observation_date", + concept_column="observation_concept_id", + source_concept_column="observation_source_concept_id", + visit_occurrence_column="visit_occurrence_id", + codeset_id=criteria.codeset_id, + first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.occurrence_start_date), + 
occurrence_end_date=normalize_date_range(criteria.occurrence_end_date), + ) + + +def _normalize_visit_detail(criteria: VisitDetail) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="VisitDetail", + domain="visit_detail", + source_table="visit_detail", + event_id_column="visit_detail_id", + start_date_column="visit_detail_start_date", + end_date_column="visit_detail_end_date", + concept_column="visit_detail_concept_id", + source_concept_column="visit_detail_source_concept_id", + visit_occurrence_column="visit_occurrence_id", + codeset_id=criteria.codeset_id, + first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.visit_detail_start_date), + occurrence_end_date=normalize_date_range(criteria.visit_detail_end_date), + ) + + +def _normalize_device_exposure(criteria: DeviceExposure) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="DeviceExposure", + domain="device_exposure", + source_table="device_exposure", + event_id_column="device_exposure_id", + start_date_column="device_exposure_start_date", + end_date_column="device_exposure_end_date", + concept_column="device_concept_id", + source_concept_column="device_source_concept_id", + visit_occurrence_column="visit_occurrence_id", + codeset_id=criteria.codeset_id, + first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.occurrence_start_date), + occurrence_end_date=normalize_date_range(criteria.occurrence_end_date), + ) + + +def _normalize_specimen(criteria: Specimen) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="Specimen", + domain="specimen", + source_table="specimen", + event_id_column="specimen_id", + start_date_column="specimen_date", + end_date_column="specimen_date", + concept_column="specimen_concept_id", + source_concept_column="specimen_source_concept_id", + visit_occurrence_column="visit_occurrence_id", 
+ codeset_id=criteria.codeset_id, + first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.occurrence_start_date), + occurrence_end_date=normalize_date_range(criteria.occurrence_end_date), + ) + + +def _normalize_death(criteria: Death) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="Death", + domain="death", + source_table="death", + event_id_column="person_id", + start_date_column="death_date", + end_date_column="death_date", + concept_column="cause_concept_id", + source_concept_column="cause_source_concept_id", + visit_occurrence_column=None, + codeset_id=criteria.codeset_id, + first=False, + occurrence_start_date=normalize_date_range(criteria.occurrence_start_date), + occurrence_end_date=None, + ) + + +def _normalize_observation_period(criteria: ObservationPeriod) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="ObservationPeriod", + domain="observation_period", + source_table="observation_period", + event_id_column="observation_period_id", + start_date_column="observation_period_start_date", + end_date_column="observation_period_end_date", + concept_column="period_type_concept_id", + source_concept_column=None, + visit_occurrence_column=None, + codeset_id=None, + first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.period_start_date), + occurrence_end_date=normalize_date_range(criteria.period_end_date), + ) + + +def _normalize_payer_plan_period(criteria: PayerPlanPeriod) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="PayerPlanPeriod", + domain="payer_plan_period", + source_table="payer_plan_period", + event_id_column="payer_plan_period_id", + start_date_column="payer_plan_period_start_date", + end_date_column="payer_plan_period_end_date", + concept_column="payer_concept_id", + source_concept_column="payer_source_concept_id", + 
visit_occurrence_column=None, + codeset_id=None, + first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.period_start_date), + occurrence_end_date=normalize_date_range(criteria.period_end_date), + ) + + +def _normalize_condition_era(criteria: ConditionEra) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="ConditionEra", + domain="condition_era", + source_table="condition_era", + event_id_column="condition_era_id", + start_date_column="condition_era_start_date", + end_date_column="condition_era_end_date", + concept_column="condition_concept_id", + source_concept_column=None, + visit_occurrence_column=None, + codeset_id=criteria.codeset_id, + first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.era_start_date), + occurrence_end_date=normalize_date_range(criteria.era_end_date), + ) + + +def _normalize_drug_era(criteria: DrugEra) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="DrugEra", + domain="drug_era", + source_table="drug_era", + event_id_column="drug_era_id", + start_date_column="drug_era_start_date", + end_date_column="drug_era_end_date", + concept_column="drug_concept_id", + source_concept_column=None, + visit_occurrence_column=None, + codeset_id=criteria.codeset_id, + first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.era_start_date), + occurrence_end_date=normalize_date_range(criteria.era_end_date), + ) + + +def _normalize_dose_era(criteria: DoseEra) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="DoseEra", + domain="dose_era", + source_table="dose_era", + event_id_column="dose_era_id", + start_date_column="dose_era_start_date", + end_date_column="dose_era_end_date", + concept_column="drug_concept_id", + source_concept_column=None, + visit_occurrence_column=None, + codeset_id=criteria.codeset_id, + 
first=bool(criteria.first), + occurrence_start_date=normalize_date_range(criteria.era_start_date), + occurrence_end_date=normalize_date_range(criteria.era_end_date), + ) + + +def _normalize_location_region(criteria: LocationRegion) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="LocationRegion", + domain="location_region", + source_table="location_history", + event_id_column="location_id", + start_date_column="start_date", + end_date_column="end_date", + concept_column="region_concept_id", + source_concept_column=None, + visit_occurrence_column=None, + codeset_id=criteria.codeset_id, + first=False, + occurrence_start_date=None, + occurrence_end_date=None, + ) + + +def normalize_criterion(criteria: Criteria) -> NormalizedCriterion: + if isinstance(criteria, ConditionOccurrence): + normalized = _normalize_condition_occurrence(criteria) + elif isinstance(criteria, DrugExposure): + normalized = _normalize_drug_exposure(criteria) + elif isinstance(criteria, VisitOccurrence): + normalized = _normalize_visit_occurrence(criteria) + elif isinstance(criteria, Measurement): + normalized = _normalize_measurement(criteria) + elif isinstance(criteria, ProcedureOccurrence): + normalized = _normalize_procedure_occurrence(criteria) + elif isinstance(criteria, Observation): + normalized = _normalize_observation(criteria) + elif isinstance(criteria, VisitDetail): + normalized = _normalize_visit_detail(criteria) + elif isinstance(criteria, DeviceExposure): + normalized = _normalize_device_exposure(criteria) + elif isinstance(criteria, Specimen): + normalized = _normalize_specimen(criteria) + elif isinstance(criteria, Death): + normalized = _normalize_death(criteria) + elif isinstance(criteria, ObservationPeriod): + normalized = _normalize_observation_period(criteria) + elif isinstance(criteria, PayerPlanPeriod): + normalized = _normalize_payer_plan_period(criteria) + elif isinstance(criteria, ConditionEra): + normalized = 
_normalize_condition_era(criteria) + elif isinstance(criteria, DrugEra): + normalized = _normalize_drug_era(criteria) + elif isinstance(criteria, DoseEra): + normalized = _normalize_dose_era(criteria) + elif isinstance(criteria, LocationRegion): + normalized = _normalize_location_region(criteria) + else: + raise UnsupportedCriterionError( + f"Ibis executor normalization error: unsupported criterion type {criteria.__class__.__name__}." + ) + + if criteria.correlated_criteria is not None and not criteria.correlated_criteria.is_empty(): + from .groups import normalize_criteria_group + + normalized_group = normalize_criteria_group(criteria.correlated_criteria) + normalized = replace(normalized, correlated_criteria=normalized_group) + + return normalized diff --git a/circe/execution/normalize/end_strategy.py b/circe/execution/normalize/end_strategy.py new file mode 100644 index 00000000..62ff666b --- /dev/null +++ b/circe/execution/normalize/end_strategy.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Any + +from ...cohortdefinition.core import CustomEraStrategy, DateOffsetStrategy, EndStrategy +from .._dataclass import frozen_slots_dataclass + + +@frozen_slots_dataclass +class NormalizedEndStrategy: + kind: str + payload: dict[str, Any] + + +def normalize_end_strategy( + value: EndStrategy | DateOffsetStrategy | CustomEraStrategy | None, +) -> NormalizedEndStrategy | None: + if value is None: + return None + if isinstance(value, DateOffsetStrategy): + return NormalizedEndStrategy( + kind="date_offset", + payload={ + "offset": int(value.offset), + "date_field": str(value.date_field), + }, + ) + if isinstance(value, CustomEraStrategy): + return NormalizedEndStrategy( + kind="custom_era", + payload={ + "drug_codeset_id": value.drug_codeset_id, + "offset": int(value.offset), + "gap_days": int(value.gap_days), + }, + ) + return NormalizedEndStrategy(kind="end_strategy", payload={}) diff --git a/circe/execution/normalize/groups.py 
b/circe/execution/normalize/groups.py new file mode 100644 index 00000000..23c8d55d --- /dev/null +++ b/circe/execution/normalize/groups.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from ...cohortdefinition.criteria import ( + CorelatedCriteria, + CriteriaGroup, + DemographicCriteria, + InclusionRule, + Occurrence, +) +from ...vocabulary.concept import Concept +from .._dataclass import frozen_slots_dataclass +from .criteria import NormalizedCriterion, normalize_criterion +from .windows import ( + NormalizedDateRange, + NormalizedNumericRange, + NormalizedWindow, + normalize_date_range, + normalize_numeric_range, + normalize_window, +) + + +@frozen_slots_dataclass +class NormalizedDemographicCriteria: + age: NormalizedNumericRange | None = None + gender_codeset_id: int | None = None + gender_concept_ids: tuple[int, ...] = () + race_codeset_id: int | None = None + race_concept_ids: tuple[int, ...] = () + ethnicity_codeset_id: int | None = None + ethnicity_concept_ids: tuple[int, ...] = () + occurrence_start_date: NormalizedDateRange | None = None + occurrence_end_date: NormalizedDateRange | None = None + + +@frozen_slots_dataclass +class NormalizedCorrelatedCriteria: + criterion: NormalizedCriterion + occurrence_type: int + occurrence_count: int + occurrence_is_distinct: bool + occurrence_count_column: str | None + start_window: NormalizedWindow | None + end_window: NormalizedWindow | None + restrict_visit: bool + ignore_observation_period: bool + + +@frozen_slots_dataclass +class NormalizedCriteriaGroup: + mode: str + count: int | None = None + criteria: tuple[NormalizedCorrelatedCriteria, ...] = () + groups: tuple[NormalizedCriteriaGroup, ...] = () + demographics: tuple[NormalizedDemographicCriteria, ...] 
= () + + def is_empty(self) -> bool: + return not self.criteria and not self.groups and not self.demographics + + +@frozen_slots_dataclass +class NormalizedInclusionRule: + name: str | None + description: str | None + expression: NormalizedCriteriaGroup | None + + +def _concept_ids(values: list[Concept] | None) -> tuple[int, ...]: + if not values: + return () + output: list[int] = [] + for concept in values: + if concept is None or concept.concept_id is None: + continue + cid = int(concept.concept_id) + if cid not in output: + output.append(cid) + return tuple(output) + + +def _normalize_demographic( + demographic: DemographicCriteria, +) -> NormalizedDemographicCriteria: + return NormalizedDemographicCriteria( + age=normalize_numeric_range(demographic.age), + gender_codeset_id=( + int(demographic.gender_cs.codeset_id) + if demographic.gender_cs and demographic.gender_cs.codeset_id is not None + else None + ), + gender_concept_ids=_concept_ids(demographic.gender), + race_codeset_id=( + int(demographic.race_cs.codeset_id) + if demographic.race_cs and demographic.race_cs.codeset_id is not None + else None + ), + race_concept_ids=_concept_ids(demographic.race), + ethnicity_codeset_id=( + int(demographic.ethnicity_cs.codeset_id) + if demographic.ethnicity_cs and demographic.ethnicity_cs.codeset_id is not None + else None + ), + ethnicity_concept_ids=_concept_ids(demographic.ethnicity), + occurrence_start_date=normalize_date_range(demographic.occurrence_start_date), + occurrence_end_date=normalize_date_range(demographic.occurrence_end_date), + ) + + +def _normalize_correlated_criteria( + correlated: CorelatedCriteria, +) -> NormalizedCorrelatedCriteria: + occurrence = correlated.occurrence or Occurrence( + type=Occurrence._AT_LEAST, + count=1, + is_distinct=False, + ) + + count_column = None + if occurrence.count_column is not None: + count_column = occurrence.count_column.value + + return NormalizedCorrelatedCriteria( + 
criterion=normalize_criterion(correlated.criteria), + occurrence_type=int(occurrence.type), + occurrence_count=int(occurrence.count), + occurrence_is_distinct=bool(occurrence.is_distinct), + occurrence_count_column=count_column, + start_window=normalize_window(correlated.start_window), + end_window=normalize_window(correlated.end_window), + restrict_visit=bool(correlated.restrict_visit), + ignore_observation_period=bool(correlated.ignore_observation_period), + ) + + +def normalize_criteria_group( + group: CriteriaGroup | None, +) -> NormalizedCriteriaGroup | None: + if group is None: + return None + + normalized_children: list[NormalizedCriteriaGroup] = [] + for child in group.groups or []: + normalized_child = normalize_criteria_group(child) + if normalized_child is not None: + normalized_children.append(normalized_child) + + return NormalizedCriteriaGroup( + mode=((group.type or "ALL").upper()), + count=(int(group.count) if group.count is not None else None), + criteria=tuple( + _normalize_correlated_criteria(correlated) for correlated in (group.criteria_list or []) + ), + groups=tuple(normalized_children), + demographics=tuple( + _normalize_demographic(demographic) for demographic in (group.demographic_criteria_list or []) + ), + ) + + +def normalize_inclusion_rule(rule: InclusionRule) -> NormalizedInclusionRule: + return NormalizedInclusionRule( + name=rule.name, + description=rule.description, + expression=normalize_criteria_group(rule.expression), + ) diff --git a/circe/execution/normalize/windows.py b/circe/execution/normalize/windows.py new file mode 100644 index 00000000..ab87aa8e --- /dev/null +++ b/circe/execution/normalize/windows.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from typing import Any + +from ...cohortdefinition.core import ( + DateRange, + NumericRange, + ObservationFilter, + Period, + Window, + WindowBound, +) +from .._dataclass import frozen_slots_dataclass + + +@frozen_slots_dataclass +class NormalizedDateRange: + op: str 
| None + value: Any + extent: Any + + +@frozen_slots_dataclass +class NormalizedNumericRange: + op: str | None + value: float | int | None + extent: float | int | None + + +@frozen_slots_dataclass +class NormalizedObservationWindow: + prior_days: int + post_days: int + + +@frozen_slots_dataclass +class NormalizedPeriod: + start_date: str | None + end_date: str | None + + +@frozen_slots_dataclass +class NormalizedWindowBound: + coeff: int + days: int | None + + +@frozen_slots_dataclass +class NormalizedWindow: + start: NormalizedWindowBound | None + end: NormalizedWindowBound | None + use_event_end: bool | None + use_index_end: bool | None + + +def normalize_date_range(value: DateRange | None) -> NormalizedDateRange | None: + if value is None: + return None + return NormalizedDateRange(op=value.op, value=value.value, extent=value.extent) + + +def normalize_numeric_range( + value: NumericRange | None, +) -> NormalizedNumericRange | None: + if value is None: + return None + return NormalizedNumericRange(op=value.op, value=value.value, extent=value.extent) + + +def normalize_observation_window( + value: ObservationFilter | None, +) -> NormalizedObservationWindow | None: + if value is None: + return None + return NormalizedObservationWindow( + prior_days=int(value.prior_days), + post_days=int(value.post_days), + ) + + +def normalize_period(value: Period | None) -> NormalizedPeriod | None: + if value is None: + return None + return NormalizedPeriod(start_date=value.start_date, end_date=value.end_date) + + +def normalize_window_bound( + value: WindowBound | None, +) -> NormalizedWindowBound | None: + if value is None: + return None + return NormalizedWindowBound(coeff=int(value.coeff), days=value.days) + + +def normalize_window(value: Window | None) -> NormalizedWindow | None: + if value is None: + return None + return NormalizedWindow( + start=normalize_window_bound(value.start), + end=normalize_window_bound(value.end), + use_event_end=value.use_event_end, + 
use_index_end=value.use_index_end, + ) diff --git a/circe/execution/options.py b/circe/execution/options.py deleted file mode 100644 index b88f1a6f..00000000 --- a/circe/execution/options.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Execution options for backend-native cohort execution.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Union - -SchemaName = Union[str, tuple[str, str]] - - -@dataclass(frozen=True) -class ExecutionOptions: - """Runtime options for backend execution via ibis. - - This API is experimental and may evolve while execution parity is built out. - """ - - cdm_schema: SchemaName | None = None - vocabulary_schema: SchemaName | None = None - result_schema: SchemaName | None = None - - cohort_id: int | None = None - - materialize_stages: bool = False - materialize_codesets: bool = True - temp_emulation_schema: SchemaName | None = None - - capture_sql: bool = False - profile_dir: str | None = None - - -def schema_to_str(schema: SchemaName | None) -> str | None: - """Normalize schema names to a string representation.""" - if schema is None: - return None - if isinstance(schema, tuple): - return ".".join(schema) - return schema diff --git a/circe/execution/plan/__init__.py b/circe/execution/plan/__init__.py new file mode 100644 index 00000000..72d5ac2b --- /dev/null +++ b/circe/execution/plan/__init__.py @@ -0,0 +1,103 @@ +from .cohort import CohortPlan, PrimaryEventInput +from .events import ( + ApplyDateAdjustment, + EventPlan, + EventSource, + FilterByCareSite, + FilterByCareSiteLocationRegion, + FilterByCodeset, + FilterByConceptSet, + FilterByDateRange, + FilterByNumericRange, + FilterByPersonAge, + FilterByPersonEthnicity, + FilterByPersonGender, + FilterByPersonRace, + FilterByProviderSpecialty, + FilterByText, + FilterByVisit, + FilterByVisitDetail, + JoinLocationRegion, + KeepFirstPerPerson, + RestrictToCorrelatedWindow, + StandardizeEventShape, +) +from .groups import GroupPredicate +from 
.predicates import DateRangePredicate, NumericRangePredicate +from .schema import ( + CONCEPT_ID, + CRITERION_INDEX, + CRITERION_TYPE, + DAYS_SUPPLY, + DOMAIN, + DURATION, + END_DATE, + EVENT_ID, + GAP_DAYS, + OCCURRENCE_COUNT, + PERSON_ID, + QUANTITY, + RANGE_HIGH, + RANGE_LOW, + REFILLS, + SOURCE_CONCEPT_ID, + SOURCE_TABLE, + STANDARD_EVENT_COLUMNS, + START_DATE, + UNIT_CONCEPT_ID, + VALUE_AS_NUMBER, + VISIT_DETAIL_ID, + VISIT_OCCURRENCE_ID, +) + +__all__ = [ + "CohortPlan", + "PrimaryEventInput", + "EventPlan", + "EventSource", + "GroupPredicate", + "DateRangePredicate", + "NumericRangePredicate", + "PERSON_ID", + "EVENT_ID", + "START_DATE", + "END_DATE", + "VISIT_OCCURRENCE_ID", + "DOMAIN", + "CONCEPT_ID", + "SOURCE_CONCEPT_ID", + "CRITERION_INDEX", + "CRITERION_TYPE", + "QUANTITY", + "DAYS_SUPPLY", + "REFILLS", + "RANGE_LOW", + "RANGE_HIGH", + "VALUE_AS_NUMBER", + "UNIT_CONCEPT_ID", + "VISIT_DETAIL_ID", + "OCCURRENCE_COUNT", + "GAP_DAYS", + "DURATION", + "SOURCE_TABLE", + "STANDARD_EVENT_COLUMNS", + "FilterByCareSite", + "FilterByCareSiteLocationRegion", + "FilterByCodeset", + "FilterByConceptSet", + "FilterByDateRange", + "FilterByNumericRange", + "FilterByText", + "FilterByVisit", + "FilterByVisitDetail", + "JoinLocationRegion", + "FilterByProviderSpecialty", + "FilterByPersonAge", + "FilterByPersonGender", + "FilterByPersonRace", + "FilterByPersonEthnicity", + "KeepFirstPerPerson", + "ApplyDateAdjustment", + "RestrictToCorrelatedWindow", + "StandardizeEventShape", +] diff --git a/circe/execution/plan/cohort.py b/circe/execution/plan/cohort.py new file mode 100644 index 00000000..0e3a922a --- /dev/null +++ b/circe/execution/plan/cohort.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from .._dataclass import frozen_slots_dataclass +from ..normalize.groups import NormalizedCriteriaGroup +from ..normalize.windows import NormalizedObservationWindow +from .events import EventPlan + + +@frozen_slots_dataclass +class PrimaryEventInput: + event_plan: 
EventPlan + correlated_criteria: NormalizedCriteriaGroup | None = None + + +@frozen_slots_dataclass +class CohortPlan: + primary_event_plans: tuple[PrimaryEventInput, ...] + observation_window: NormalizedObservationWindow | None + primary_limit_type: str + qualified_limit_type: str + expression_limit_type: str diff --git a/circe/execution/plan/events.py b/circe/execution/plan/events.py new file mode 100644 index 00000000..99652fd7 --- /dev/null +++ b/circe/execution/plan/events.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +from typing import Any, Union + +from .._dataclass import frozen_slots_dataclass +from .predicates import DateRangePredicate, NumericRangePredicate +from .schema import PERSON_ID + + +@frozen_slots_dataclass +class EventSource: + table_name: str + domain: str + event_id_column: str + start_date_column: str + end_date_column: str + person_id_column: str = PERSON_ID + concept_column: str | None = None + source_concept_column: str | None = None + visit_occurrence_column: str | None = None + + +@frozen_slots_dataclass +class FilterByCodeset: + column: str + codeset_id: int + exclude: bool = False + + +@frozen_slots_dataclass +class FilterByConceptSet: + column: str + concept_ids: tuple[int, ...] + exclude: bool = False + + +@frozen_slots_dataclass +class FilterByDateRange: + column: str + predicate: DateRangePredicate + + +@frozen_slots_dataclass +class FilterByNumericRange: + column: str + predicate: NumericRangePredicate + + +@frozen_slots_dataclass +class FilterByText: + column: str + op: str | None + text: str | None + + +@frozen_slots_dataclass +class JoinLocationRegion: + location_id_column: str = "location_id" + region_column: str = "region_concept_id" + + +@frozen_slots_dataclass +class FilterByVisit: + visit_occurrence_column: str = "visit_occurrence_id" + concept_ids: tuple[int, ...] 
= () + codeset_id: int | None = None + exclude: bool = False + + +@frozen_slots_dataclass +class FilterByVisitDetail: + visit_detail_codeset_id: int | None = None + + +@frozen_slots_dataclass +class FilterByProviderSpecialty: + provider_id_column: str = "provider_id" + concept_ids: tuple[int, ...] = () + codeset_id: int | None = None + exclude: bool = False + + +@frozen_slots_dataclass +class FilterByCareSite: + care_site_id_column: str = "care_site_id" + concept_ids: tuple[int, ...] = () + codeset_id: int | None = None + exclude: bool = False + + +@frozen_slots_dataclass +class FilterByCareSiteLocationRegion: + care_site_id_column: str = "care_site_id" + start_date_column: str = "start_date" + end_date_column: str = "end_date" + codeset_id: int = 0 + + +@frozen_slots_dataclass +class FilterByPersonAge: + date_column: str + predicate: NumericRangePredicate + + +@frozen_slots_dataclass +class FilterByPersonGender: + concept_ids: tuple[int, ...] = () + codeset_id: int | None = None + + +@frozen_slots_dataclass +class FilterByPersonRace: + concept_ids: tuple[int, ...] = () + codeset_id: int | None = None + + +@frozen_slots_dataclass +class FilterByPersonEthnicity: + concept_ids: tuple[int, ...] = () + codeset_id: int | None = None + + +@frozen_slots_dataclass +class KeepFirstPerPerson: + order_by: tuple[str, ...] 
+ + +@frozen_slots_dataclass +class ApplyDateAdjustment: + start_offset_days: int + end_offset_days: int + start_with: str = "start_date" + end_with: str = "end_date" + + +@frozen_slots_dataclass +class RestrictToCorrelatedWindow: + payload: dict[str, Any] + + +@frozen_slots_dataclass +class StandardizeEventShape: + criterion_type: str + criterion_index: int + start_offset_days: int = 0 + end_offset_days: int = 0 + start_with: str = "start_date" + end_with: str = "end_date" + + +PlanStep = Union[ + FilterByCodeset, + FilterByConceptSet, + FilterByDateRange, + FilterByNumericRange, + FilterByText, + JoinLocationRegion, + FilterByVisit, + FilterByVisitDetail, + FilterByProviderSpecialty, + FilterByCareSite, + FilterByCareSiteLocationRegion, + FilterByPersonAge, + FilterByPersonGender, + FilterByPersonRace, + FilterByPersonEthnicity, + KeepFirstPerPerson, + ApplyDateAdjustment, + RestrictToCorrelatedWindow, + StandardizeEventShape, +] + + +@frozen_slots_dataclass +class EventPlan: + source: EventSource + criterion_type: str + criterion_index: int + steps: tuple[PlanStep, ...] diff --git a/circe/execution/plan/groups.py b/circe/execution/plan/groups.py new file mode 100644 index 00000000..c6ba3958 --- /dev/null +++ b/circe/execution/plan/groups.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from .._dataclass import frozen_slots_dataclass + + +@frozen_slots_dataclass +class GroupPredicate: + mode: str + count: int | None = None + children: tuple[GroupPredicate, ...] 
= () diff --git a/circe/execution/plan/predicates.py b/circe/execution/plan/predicates.py new file mode 100644 index 00000000..cbfca913 --- /dev/null +++ b/circe/execution/plan/predicates.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from typing import Any + +from .._dataclass import frozen_slots_dataclass + + +@frozen_slots_dataclass +class DateRangePredicate: + op: str | None + value: Any + extent: Any + + +@frozen_slots_dataclass +class NumericRangePredicate: + op: str | None + value: float | int | None + extent: float | int | None diff --git a/circe/execution/plan/schema.py b/circe/execution/plan/schema.py new file mode 100644 index 00000000..061815f0 --- /dev/null +++ b/circe/execution/plan/schema.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +PERSON_ID = "person_id" +EVENT_ID = "event_id" +START_DATE = "start_date" +END_DATE = "end_date" +VISIT_OCCURRENCE_ID = "visit_occurrence_id" +VISIT_DETAIL_ID = "visit_detail_id" +DOMAIN = "domain" +CONCEPT_ID = "concept_id" +SOURCE_CONCEPT_ID = "source_concept_id" +QUANTITY = "quantity" +DAYS_SUPPLY = "days_supply" +REFILLS = "refills" +RANGE_LOW = "range_low" +RANGE_HIGH = "range_high" +VALUE_AS_NUMBER = "value_as_number" +UNIT_CONCEPT_ID = "unit_concept_id" +OCCURRENCE_COUNT = "occurrence_count" +GAP_DAYS = "gap_days" +DURATION = "duration" +CRITERION_INDEX = "criterion_index" +CRITERION_TYPE = "criterion_type" +SOURCE_TABLE = "source_table" + +STANDARD_EVENT_COLUMNS = ( + PERSON_ID, + EVENT_ID, + START_DATE, + END_DATE, + DOMAIN, + CONCEPT_ID, + SOURCE_CONCEPT_ID, + VISIT_OCCURRENCE_ID, + VISIT_DETAIL_ID, + QUANTITY, + DAYS_SUPPLY, + REFILLS, + RANGE_LOW, + RANGE_HIGH, + VALUE_AS_NUMBER, + UNIT_CONCEPT_ID, + OCCURRENCE_COUNT, + GAP_DAYS, + DURATION, + CRITERION_INDEX, + CRITERION_TYPE, + SOURCE_TABLE, +) diff --git a/circe/execution/typing.py b/circe/execution/typing.py new file mode 100644 index 00000000..edf6ba28 --- /dev/null +++ b/circe/execution/typing.py @@ -0,0 +1,28 @@ +from 
__future__ import annotations + +from typing import Any, Protocol + +from typing_extensions import TypeAlias + +# Ibis does not currently ship usable type information for its table expressions. +# Treat them as `Any` at the compatibility boundary rather than propagating +# `import-untyped` errors through the executor. +Table: TypeAlias = Any + + +class IbisBackendLike(Protocol): + """Minimal backend surface required by the Ibis executor.""" + + def table(self, name: str, database: str | None = None) -> Table: ... + + def create_table( + self, + name: str, + /, + obj: Any = None, + *, + schema: Any | None = None, + database: str | None = None, + temp: bool = False, + overwrite: bool = False, + ) -> Any: ... diff --git a/docs/developer/architecture.rst b/docs/developer/architecture.rst index 46736501..6b2c439e 100644 --- a/docs/developer/architecture.rst +++ b/docs/developer/architecture.rst @@ -12,6 +12,7 @@ Package Structure * **helper/** - Utility functions * **api.py** - High-level API * **cli.py** - Command-line interface +* **execution/** - Experimental Ibis-based cohort execution engine SQL Generation -------------- @@ -23,3 +24,125 @@ Validation Framework The validation framework uses a checker pattern with pluggable validators. +Execution Engine +---------------- + +The ``circe.execution`` package is an experimental, table-first Ibis executor +for ``CohortExpression`` models. It runs in parallel with the existing SQL +builder. 
+ +Public API +~~~~~~~~~~ + +The main execution entrypoints are: + +* ``build_cohort(...)`` - build a lazy Ibis relation in canonical execution shape +* ``write_cohort(...)`` - project to OHDSI cohort-table shape and write rows for one ``cohort_id`` + +The write contract is cohort-scoped: + +* ``if_exists="fail"`` errors only if rows already exist for that ``cohort_id`` +* ``if_exists="replace"`` replaces only that ``cohort_id`` and preserves other cohorts in the same table + +Layered Design +~~~~~~~~~~~~~~ + +The subsystem is intentionally split into five layers. + +1. ``normalize/`` + + * converts public cohort-definition models into frozen internal dataclasses + * removes aliasing and optional-shape noise from downstream code + * rejects explicitly unsupported semantics early + +2. ``lower/`` + + * turns normalized criteria into backend-agnostic execution plans + * encodes reusable event and predicate planning logic + * keeps domain-specific lowering separate from backend-specific compilation + +3. ``ibis/`` + + * compiles lowered plans into Ibis relations + * standardizes domain tables into the canonical event schema + * resolves concept sets and person filters + * provides backend operations used by the public write path + +4. ``engine/`` + + * evaluates cohort semantics over canonical event relations + * handles primary events, additional criteria, inclusion rules, censoring, + limits, collapse, and end strategy + +5. API materialization layer + + * connects public API calls to normalization, compilation, and engine execution + * projects final relations into OHDSI cohort-table shape + * handles backend table existence checks and cohort-scoped writes + +Canonical Event Schema +~~~~~~~~~~~~~~~~~~~~~~ + +Compiled domain event relations are standardized before engine orchestration. +The canonical columns are defined in ``circe/execution/plan/schema.py``. 
+Important columns include: + +* ``person_id`` +* ``event_id`` +* ``start_date`` +* ``end_date`` +* ``domain`` +* ``concept_id`` +* ``source_concept_id`` +* ``visit_occurrence_id`` +* ``criterion_index`` +* ``criterion_type`` +* ``source_table`` + +This standardization is one of the main design differences from the legacy +builder-based path. The engine operates on one event shape instead of many +domain-specific SQL-builder shapes. + +Data Flow +~~~~~~~~~ + +The end-to-end flow is: + +1. ``CohortExpression`` +2. normalize to frozen internal dataclasses +3. lower criteria into event and predicate plans +4. compile plans into canonical Ibis relations +5. run cohort semantics in ``engine/`` +6. optionally materialize to OHDSI cohort-table rows + +Codeset Resolution +~~~~~~~~~~~~~~~~~~ + +Codeset expansion is handled by ``CachedConceptSetResolver``. +Resolution semantics are: + +* direct inclusion +* descendant expansion through ``concept_ancestor`` +* mapped concept expansion through ``concept_relationship`` +* exclusion precedence after expansion + +The cache is scoped to one execution context run. + +Migration Notes +~~~~~~~~~~~~~~~ + +If you used the legacy execution prototype: + +* use ``build_cohort(...)`` to get the lazy relation +* use backend operations on that relation for inspection and collection +* use ``write_cohort(...)`` for cohort-table writes + +Current Execution Limitations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The executor should fail explicitly for unsupported execution semantics rather +than silently degrading behavior. + +Current explicit limitation: + +* ``custom_era`` end strategy is not implemented in this base execution branch diff --git a/docs/developer/testing.rst b/docs/developer/testing.rst index 71c09368..402956d0 100644 --- a/docs/developer/testing.rst +++ b/docs/developer/testing.rst @@ -1,7 +1,7 @@ Testing ======= -CIRCE Python has comprehensive test coverage (71%, 896 tests). +CIRCE Python has comprehensive test coverage. 
Running Tests ------------- @@ -25,5 +25,113 @@ Tests are organized by module in the ``tests/`` directory. Writing Tests ------------- -Follow existing test patterns. See CONTRIBUTING.md for guidelines. +Follow existing test patterns. See ``docs/developer/contributing.rst`` for +contribution guidelines. +Execution Engine Testing +------------------------ + +The ``circe.execution`` subsystem should be tested in layers, with each layer +optimized for a different failure mode. + +Goals +~~~~~ + +* keep the engine safe to refactor while the design is still evolving +* make regressions easy to localize to one layer +* avoid turning the test suite into a single large DuckDB integration harness + +Test Layers +~~~~~~~~~~~ + +1. Pure normalization and lowering unit tests + + * Scope: ``normalize/``, ``lower/``, ``plan/``, and small pure helpers + * Style: no backend, no SQL execution, frozen dataclass assertions + * Current files: + + * ``tests/execution/test_normalize.py`` + * ``tests/execution/test_normalize_contracts.py`` + * ``tests/execution/test_lowering.py`` + * ``tests/execution/test_lower_contracts.py`` + * ``tests/execution/test_compile_contracts.py`` + +2. Ibis helper unit tests + + * Scope: ``ibis/codesets.py``, ``ibis/operations.py``, ``ibis/context.py``, + ``ibis/standardize.py``, and engine helpers that do not need full cohort runs + * Style: fake backends where possible; DuckDB only when expression execution is + the thing under test + * Current files: + + * ``tests/execution/test_context_wiring.py`` + * ``tests/execution/test_operations.py`` + * ``tests/execution/test_ibis_compat.py`` + * ``tests/execution/test_group_demographics.py`` + * ``tests/execution/test_person_filters.py`` + +3. 
Engine semantics integration tests + + * Scope: primary events, correlated criteria, groups, inclusion rules, result + limits, end strategy, censoring, and parity-sensitive orchestration + * Style: minimal DuckDB fixtures with only the columns required for the + behavior under test + * Current files: + + * ``tests/execution/test_groups.py`` + * ``tests/execution/test_inclusion.py`` + * ``tests/execution/test_result_limits.py`` + * ``tests/execution/test_end_strategy_censoring.py`` + * ``tests/execution/test_parity_regressions.py`` + +4. Public API and wiring tests + + * Scope: ``build_cohort``, ``write_cohort``, package exports, and compat shims + * Style: verify entrypoint behavior, argument handling, and write semantics + without duplicating engine internals + * Current files: + + * ``tests/execution/test_api_public.py`` + * ``tests/execution/test_api_ibis.py`` + * ``tests/execution/test_scaffolding.py`` + +5. Error and limitation tests + + * Scope: explicit unsupported features, validation messages, and backend + capability failures + * Style: assert on error type and message text where the API contract matters + * Current files: + + * ``tests/execution/test_error_messages.py`` + +Rules +~~~~~ + +* each new execution module should get at least one direct test file in the same + layer as its responsibility +* prefer fake backends for capability and error branches, and DuckDB for + relational behavior +* keep fixtures local to a test file unless three or more files need the same setup +* when adding a new feature, add: + + * one layer-local unit or helper test + * one end-to-end or API-level assertion if the feature crosses layers + +* parity and regression tests should stay small and named after the bug or + contract they protect + +Local Gate +~~~~~~~~~~ + +Use this as the normal execution-engine check: + +.. code-block:: bash + + uv run pre-commit run --all-files + uv run pytest tests/execution -q + +Before merging broader refactors, also run: + +.. 
code-block:: bash + + uv run pytest diff --git a/pyproject.toml b/pyproject.toml index 5ae4b738..401d3dd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,6 @@ ibis = [ ] ibis-duckdb = [ "ibis-framework[duckdb]>=11.0.0; python_version >= '3.9'", - "polars>=0.20.0; python_version >= '3.9'", ] ibis-postgres = [ "ibis-framework[postgres]>=11.0.0; python_version >= '3.9'", diff --git a/tests/execution/_assertions.py b/tests/execution/_assertions.py new file mode 100644 index 00000000..cd16efda --- /dev/null +++ b/tests/execution/_assertions.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from circe.execution.plan.schema import STANDARD_EVENT_COLUMNS + + +def assert_standard_event_columns(columns) -> None: + """Assert a table-like object exposes the canonical standard event schema.""" + normalized = tuple(columns) + assert normalized == STANDARD_EVENT_COLUMNS diff --git a/tests/execution/_domain_cases.py b/tests/execution/_domain_cases.py new file mode 100644 index 00000000..c508140d --- /dev/null +++ b/tests/execution/_domain_cases.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from collections.abc import Callable + +from circe.cohortdefinition import ( + ConditionEra, + ConditionOccurrence, + Death, + DeviceExposure, + DoseEra, + DrugEra, + DrugExposure, + LocationRegion, + Measurement, + Observation, + ObservationPeriod, + PayerPlanPeriod, + ProcedureOccurrence, + Specimen, + VisitDetail, + VisitOccurrence, +) + +CriteriaFactory = Callable[[], object] + + +def domain_criteria_cases() -> list[tuple[str, CriteriaFactory, int | None]]: + """Domain criteria factories + default concept id used for codeset filters.""" + return [ + ("condition_occurrence", lambda: ConditionOccurrence(codeset_id=1), 111), + ("drug_exposure", lambda: DrugExposure(codeset_id=1), 222), + ("visit_occurrence", lambda: VisitOccurrence(codeset_id=1), 333), + ("measurement", lambda: Measurement(codeset_id=1), 444), + ("procedure_occurrence", lambda: 
ProcedureOccurrence(codeset_id=1), 555), + ("observation", lambda: Observation(codeset_id=1), 666), + ("visit_detail", lambda: VisitDetail(codeset_id=1), 777), + ("device_exposure", lambda: DeviceExposure(codeset_id=1), 888), + ("specimen", lambda: Specimen(codeset_id=1), 999), + ("death", lambda: Death(codeset_id=1), 1001), + ("observation_period", lambda: ObservationPeriod(), None), + ("payer_plan_period", lambda: PayerPlanPeriod(), None), + ("condition_era", lambda: ConditionEra(codeset_id=1), 1201), + ("drug_era", lambda: DrugEra(codeset_id=1), 1301), + ("dose_era", lambda: DoseEra(codeset_id=1), 1401), + ("location_history", lambda: LocationRegion(codeset_id=1), 15151), + ] diff --git a/tests/execution/test_api_ibis.py b/tests/execution/test_api_ibis.py new file mode 100644 index 00000000..ef0a73e4 --- /dev/null +++ b/tests/execution/test_api_ibis.py @@ -0,0 +1,1222 @@ +from __future__ import annotations + +import pytest + +from circe.api import build_cohort +from circe.cohortdefinition import ( + CohortExpression, + ConditionEra, + ConditionOccurrence, + CorelatedCriteria, + CriteriaGroup, + Death, + DeviceExposure, + DoseEra, + DrugEra, + DrugExposure, + LocationRegion, + Measurement, + Observation, + ObservationPeriod, + Occurrence, + PayerPlanPeriod, + PrimaryCriteria, + ProcedureOccurrence, + Specimen, + VisitDetail, + VisitOccurrence, +) +from circe.cohortdefinition.core import CustomEraStrategy, NumericRange +from circe.execution.errors import UnsupportedFeatureError +from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem + + +def _make_concept_set(set_id: int, concept_id: int) -> ConceptSet: + return ConceptSet( + id=set_id, + expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=concept_id))]), + ) + + +def _seed_common_tables(conn, ibis): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1, 2], + "year_of_birth": [1980, 2015], + "gender_concept_id": [8507, 8507], + } + 
), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": [1, 2], + "observation_period_id": [10, 11], + "observation_period_start_date": ["2019-01-01", "2019-01-01"], + "observation_period_end_date": ["2021-12-31", "2021-12-31"], + } + ), + overwrite=True, + ) + + +def _seed_vocabulary_tables(conn, ibis): + conn.create_table( + "concept", + obj=ibis.memtable( + { + "concept_id": [100, 101, 102, 200, 201], + "invalid_reason": [None, None, "D", None, None], + } + ), + overwrite=True, + ) + conn.create_table( + "concept_ancestor", + obj=ibis.memtable( + { + "ancestor_concept_id": [100, 100], + "descendant_concept_id": [101, 102], + } + ), + overwrite=True, + ) + conn.create_table( + "concept_relationship", + obj=ibis.memtable( + { + "concept_id_1": [200, 201], + "concept_id_2": [100, 101], + "relationship_id": ["Maps to", "Maps to"], + "invalid_reason": [None, "D"], + } + ), + overwrite=True, + ) + + +def test_build_cohort_condition_occurrence(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 2], + "condition_occurrence_id": [100, 101, 102], + "condition_concept_id": [111, 111, 999], + "condition_start_date": ["2020-01-01", "2020-02-01", "2020-01-05"], + "condition_end_date": ["2020-01-02", "2020-02-02", "2020-01-06"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + ConditionOccurrence( + codeset_id=1, + first=True, + age=NumericRange(op="gte", value=18), + ) + ] + ), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.columns) >= { + "person_id", + "event_id", + "start_date", + "end_date", + "domain", + "criterion_type", + } + assert 
set(result.person_id) == {1} + assert len(result) == 1 + + +def test_build_cohort_condition_occurrence_with_race_and_ethnicity_filters(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1, 2], + "year_of_birth": [1980, 1980], + "gender_concept_id": [8507, 8507], + "race_concept_id": [8527, 8516], + "ethnicity_concept_id": [38003564, 38003563], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 2], + "condition_occurrence_id": [150, 151], + "condition_concept_id": [111, 111], + "condition_start_date": ["2020-01-01", "2020-01-01"], + "condition_end_date": ["2020-01-01", "2020-01-01"], + } + ), + overwrite=True, + ) + + criteria = ConditionOccurrence(codeset_id=1) + criteria.__dict__["race"] = [Concept(conceptId=8527)] + criteria.__dict__["ethnicity"] = [Concept(conceptId=38003564)] + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + primary_criteria=PrimaryCriteria(criteria_list=[criteria]), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1} + + +def test_build_cohort_applies_criterion_local_correlated_criteria(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 2], + "condition_occurrence_id": [160, 161, 260], + "condition_concept_id": [111, 222, 111], + "condition_start_date": ["2020-01-01", "2020-01-03", "2020-01-01"], + "condition_end_date": ["2020-01-01", "2020-01-03", "2020-01-01"], + "visit_occurrence_id": [10, 10, 20], + } + ), + overwrite=True, + ) + + criteria = ConditionOccurrence( + codeset_id=1, + correlated_criteria=CriteriaGroup( + 
type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=2), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + ) + ], + ), + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111), _make_concept_set(2, 222)], + primary_criteria=PrimaryCriteria(criteria_list=[criteria]), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1} + + +def test_build_cohort_concept_set_resolves_descendants_and_mapped(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + _seed_vocabulary_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 1, 1, 1, 2], + "condition_occurrence_id": [1000, 1001, 1002, 1003, 1004, 1005], + "condition_concept_id": [100, 101, 102, 200, 201, 999], + "condition_start_date": [ + "2020-01-01", + "2020-01-02", + "2020-01-03", + "2020-01-04", + "2020-01-05", + "2020-01-01", + ], + "condition_end_date": [ + "2020-01-01", + "2020-01-02", + "2020-01-03", + "2020-01-04", + "2020-01-05", + "2020-01-01", + ], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + ConceptSet( + id=1, + expression=ConceptSetExpression( + items=[ + ConceptSetItem( + concept=Concept(conceptId=100), + includeDescendants=True, + includeMapped=True, + ), + ConceptSetItem( + concept=Concept(conceptId=101), + isExcluded=True, + includeMapped=True, + ), + ] + ), + ) + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1} + assert set(result.concept_id) == {100, 200} + + +def test_build_cohort_uses_vocabulary_schema_option_for_expansion(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = 
ibis.duckdb.connect() + conn.raw_sql("CREATE SCHEMA vocab") + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1], + "condition_occurrence_id": [2000, 2001], + "condition_concept_id": [100, 101], + "condition_start_date": ["2020-01-01", "2020-01-02"], + "condition_end_date": ["2020-01-01", "2020-01-02"], + } + ), + overwrite=True, + ) + conn.create_table( + "concept", + obj=ibis.memtable({"concept_id": [100, 101, 102], "invalid_reason": [None, None, "D"]}), + database="vocab", + overwrite=True, + ) + conn.create_table( + "concept_ancestor", + obj=ibis.memtable({"ancestor_concept_id": [100], "descendant_concept_id": [101]}), + database="vocab", + overwrite=True, + ) + conn.create_table( + "concept_relationship", + obj=ibis.memtable( + { + "concept_id_1": [9999, 9998], + "concept_id_2": [100, 101], + "relationship_id": ["Maps to", "Maps to"], + "invalid_reason": [None, "D"], + } + ), + database="vocab", + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + ConceptSet( + id=1, + expression=ConceptSetExpression( + items=[ + ConceptSetItem( + concept=Concept(conceptId=100), + includeDescendants=True, + ) + ] + ), + ) + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + ) + + result = build_cohort( + expression, + backend=conn, + cdm_schema="main", + vocabulary_schema="vocab", + ).execute() + assert set(result.concept_id) == {100, 101} + + +def test_build_cohort_drug_exposure(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 2], + "drug_exposure_id": [200, 201], + "drug_concept_id": [222, 999], + "drug_exposure_start_date": ["2020-03-01", "2020-03-01"], + "drug_exposure_end_date": ["2020-03-02", "2020-03-02"], + } + ), + overwrite=True, + ) + + expression = 
CohortExpression( + concept_sets=[_make_concept_set(2, 222)], + primary_criteria=PrimaryCriteria(criteria_list=[DrugExposure(codeset_id=2)]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "drug_exposure") + + +def test_build_cohort_visit_occurrence(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "visit_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 2], + "visit_occurrence_id": [300, 301], + "visit_concept_id": [333, 999], + "visit_start_date": ["2020-05-01", "2020-05-01"], + "visit_end_date": ["2020-05-02", "2020-05-02"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(3, 333)], + primary_criteria=PrimaryCriteria(criteria_list=[VisitOccurrence(codeset_id=3)]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "visit_occurrence") + + +def test_build_cohort_measurement(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "measurement", + obj=ibis.memtable( + { + "person_id": [1, 2], + "measurement_id": [400, 401], + "measurement_concept_id": [444, 999], + "measurement_date": ["2020-06-01", "2020-06-01"], + "visit_occurrence_id": [10, 11], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(4, 444)], + primary_criteria=PrimaryCriteria(criteria_list=[Measurement(codeset_id=4)]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "measurement") + + +def 
test_build_cohort_measurement_with_value_and_unit_filters(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "measurement", + obj=ibis.memtable( + { + "person_id": [1, 2], + "measurement_id": [410, 411], + "measurement_concept_id": [444, 444], + "measurement_date": ["2020-06-01", "2020-06-01"], + "visit_occurrence_id": [10, 11], + "value_as_number": [5.0, 15.0], + "unit_concept_id": [9001, 9002], + "value_as_concept_id": [7001, 7002], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(4, 444)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + Measurement( + codeset_id=4, + value_as_number=NumericRange(op="gte", value=10), + unit=[Concept(conceptId=9002)], + value_as_concept=[Concept(conceptId=7002)], + ) + ] + ), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {2} + assert all(result.domain == "measurement") + + +def test_build_cohort_procedure_occurrence(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "procedure_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 2], + "procedure_occurrence_id": [500, 501], + "procedure_concept_id": [555, 999], + "procedure_date": ["2020-07-01", "2020-07-01"], + "visit_occurrence_id": [10, 11], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(5, 555)], + primary_criteria=PrimaryCriteria(criteria_list=[ProcedureOccurrence(codeset_id=5)]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "procedure_occurrence") + + +def test_build_cohort_procedure_occurrence_with_domain_filters(): + ibis 
= pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "procedure_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 2], + "procedure_occurrence_id": [510, 511], + "procedure_concept_id": [555, 555], + "procedure_date": ["2020-07-01", "2020-07-01"], + "visit_occurrence_id": [10, 11], + "procedure_type_concept_id": [901, 902], + "modifier_concept_id": [1001, 1002], + "quantity": [1, 5], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(5, 555)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + ProcedureOccurrence( + codeset_id=5, + procedure_type=[Concept(conceptId=902)], + quantity=NumericRange(op="gte", value=5), + ) + ] + ), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {2} + assert all(result.domain == "procedure_occurrence") + + +def test_build_cohort_observation(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "observation", + obj=ibis.memtable( + { + "person_id": [1, 2], + "observation_id": [600, 601], + "observation_concept_id": [666, 999], + "observation_date": ["2020-08-01", "2020-08-01"], + "visit_occurrence_id": [10, 11], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(6, 666)], + primary_criteria=PrimaryCriteria(criteria_list=[Observation(codeset_id=6)]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "observation") + + +def test_build_cohort_observation_with_domain_filters(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + 
conn.create_table( + "observation", + obj=ibis.memtable( + { + "person_id": [1, 2], + "observation_id": [610, 611], + "observation_concept_id": [666, 666], + "observation_date": ["2020-08-01", "2020-08-01"], + "visit_occurrence_id": [10, 11], + "observation_type_concept_id": [2001, 2002], + "value_as_number": [1.0, 20.0], + "value_as_string": ["low", "high"], + "value_as_concept_id": [3001, 3002], + "unit_concept_id": [4001, 4002], + "qualifier_concept_id": [5001, 5002], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(6, 666)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + Observation( + codeset_id=6, + observation_type=[Concept(conceptId=2002)], + value_as_number=NumericRange(op="gte", value=10), + value_as_concept=[Concept(conceptId=3002)], + unit=[Concept(conceptId=4002)], + ) + ] + ), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {2} + assert all(result.domain == "observation") + + +def test_build_cohort_visit_detail(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "visit_detail", + obj=ibis.memtable( + { + "person_id": [1, 2], + "visit_detail_id": [700, 701], + "visit_detail_concept_id": [777, 999], + "visit_detail_start_date": ["2020-09-01", "2020-09-01"], + "visit_detail_end_date": ["2020-09-02", "2020-09-02"], + "visit_occurrence_id": [10, 11], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(7, 777)], + primary_criteria=PrimaryCriteria(criteria_list=[VisitDetail(codeset_id=7)]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "visit_detail") + + +def test_build_cohort_visit_detail_with_domain_filters(): + ibis = 
pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "visit_detail", + obj=ibis.memtable( + { + "person_id": [1, 2], + "visit_detail_id": [710, 711], + "visit_detail_concept_id": [777, 777], + "visit_detail_start_date": ["2020-09-01", "2020-09-01"], + "visit_detail_end_date": ["2020-09-02", "2020-09-02"], + "visit_occurrence_id": [10, 11], + "visit_detail_type_concept_id": [6001, 6002], + "discharge_to_concept_id": [7001, 7002], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(7, 777)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + VisitDetail( + codeset_id=7, + visit_detail_type=[Concept(conceptId=6002)], + discharge_to=[Concept(conceptId=7002)], + ) + ] + ), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {2} + assert all(result.domain == "visit_detail") + + +def test_build_cohort_device_exposure(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "device_exposure", + obj=ibis.memtable( + { + "person_id": [1, 2], + "device_exposure_id": [800, 801], + "device_concept_id": [888, 999], + "device_exposure_start_date": ["2020-10-01", "2020-10-01"], + "device_exposure_end_date": ["2020-10-02", "2020-10-02"], + "visit_occurrence_id": [10, 11], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(8, 888)], + primary_criteria=PrimaryCriteria(criteria_list=[DeviceExposure(codeset_id=8)]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "device_exposure") + + +def test_build_cohort_specimen(): + ibis = pytest.importorskip("ibis") + _ = 
pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "specimen", + obj=ibis.memtable( + { + "person_id": [1, 2], + "specimen_id": [900, 901], + "specimen_concept_id": [9990, 9991], + "specimen_date": ["2020-11-01", "2020-11-01"], + "visit_occurrence_id": [10, 11], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(9, 9990)], + primary_criteria=PrimaryCriteria(criteria_list=[Specimen(codeset_id=9)]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "specimen") + + +def test_build_cohort_death(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "death", + obj=ibis.memtable( + { + "person_id": [1, 2], + "cause_concept_id": [10001, 10002], + "cause_source_concept_id": [20001, 20002], + "death_date": ["2020-12-01", "2020-12-01"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(10, 10001)], + primary_criteria=PrimaryCriteria(criteria_list=[Death(codeset_id=10)]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "death") + + +def test_build_cohort_observation_period(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + expression = CohortExpression( + primary_criteria=PrimaryCriteria(criteria_list=[ObservationPeriod()]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1, 2} + assert all(result.domain == "observation_period") + + +def test_build_cohort_payer_plan_period(): + 
ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "payer_plan_period", + obj=ibis.memtable( + { + "person_id": [1], + "payer_plan_period_id": [1100], + "payer_concept_id": [12345], + "payer_source_concept_id": [54321], + "payer_plan_period_start_date": ["2020-01-01"], + "payer_plan_period_end_date": ["2020-12-31"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + primary_criteria=PrimaryCriteria(criteria_list=[PayerPlanPeriod()]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "payer_plan_period") + + +def test_build_cohort_condition_era(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_era", + obj=ibis.memtable( + { + "person_id": [1, 2], + "condition_era_id": [1200, 1201], + "condition_concept_id": [12121, 99999], + "condition_era_start_date": ["2020-01-01", "2020-01-01"], + "condition_era_end_date": ["2020-02-01", "2020-02-01"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(11, 12121)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionEra(codeset_id=11)]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "condition_era") + + +def test_build_cohort_condition_era_applies_era_length_and_occurrence_count(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_era", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "condition_era_id": [1200, 1201, 1202], + "condition_concept_id": [12121, 
12121, 12121], + "condition_era_start_date": ["2020-01-01", "2020-01-01", "2020-01-01"], + "condition_era_end_date": ["2020-02-15", "2020-01-20", "2020-02-15"], + "condition_occurrence_count": [4, 4, 1], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(11, 12121)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + ConditionEra( + codeset_id=11, + era_length=NumericRange(op="gte", value=30), + occurrence_count=NumericRange(op="gte", value=2), + ) + ] + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert set(result.person_id) == {1} + + +def test_build_cohort_drug_era(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "drug_era", + obj=ibis.memtable( + { + "person_id": [1, 2], + "drug_era_id": [1300, 1301], + "drug_concept_id": [13131, 99999], + "drug_era_start_date": ["2020-03-01", "2020-03-01"], + "drug_era_end_date": ["2020-04-01", "2020-04-01"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(12, 13131)], + primary_criteria=PrimaryCriteria(criteria_list=[DrugEra(codeset_id=12)]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "drug_era") + + +def test_build_cohort_drug_era_applies_era_length(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "drug_era", + obj=ibis.memtable( + { + "person_id": [1, 2], + "drug_era_id": [1300, 1301], + "drug_concept_id": [13131, 13131], + "drug_era_start_date": ["2020-03-01", "2020-03-01"], + "drug_era_end_date": ["2020-04-15", "2020-03-10"], + "drug_exposure_count": [2, 2], + "gap_days": [5, 5], + } + ), + overwrite=True, + ) + + 
expression = CohortExpression( + concept_sets=[_make_concept_set(12, 13131)], + primary_criteria=PrimaryCriteria( + criteria_list=[DrugEra(codeset_id=12, era_length=NumericRange(op="gte", value=30))] + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert set(result.person_id) == {1} + + +def test_build_cohort_drug_era_applies_occurrence_count_and_gap_days(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "drug_era", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "drug_era_id": [1300, 1301, 1302], + "drug_concept_id": [13131, 13131, 13131], + "drug_era_start_date": ["2020-03-01", "2020-03-01", "2020-03-01"], + "drug_era_end_date": ["2020-04-15", "2020-04-15", "2020-04-15"], + "drug_exposure_count": [4, 1, 4], + "gap_days": [8, 8, 2], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(12, 13131)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + DrugEra( + codeset_id=12, + occurrence_count=NumericRange(op="gte", value=2), + gap_days=NumericRange(op="gte", value=5), + ) + ] + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert set(result.person_id) == {1} + + +def test_build_cohort_dose_era(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "dose_era", + obj=ibis.memtable( + { + "person_id": [1, 2], + "dose_era_id": [1400, 1401], + "drug_concept_id": [14141, 99999], + "dose_era_start_date": ["2020-05-01", "2020-05-01"], + "dose_era_end_date": ["2020-06-01", "2020-06-01"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(13, 14141)], + primary_criteria=PrimaryCriteria(criteria_list=[DoseEra(codeset_id=13)]), + ) + + table = 
build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "dose_era") + + +def test_build_cohort_dose_era_applies_era_length(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "dose_era", + obj=ibis.memtable( + { + "person_id": [1, 2], + "dose_era_id": [1400, 1401], + "drug_concept_id": [14141, 14141], + "dose_era_start_date": ["2020-05-01", "2020-05-01"], + "dose_era_end_date": ["2020-06-15", "2020-05-10"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(13, 14141)], + primary_criteria=PrimaryCriteria( + criteria_list=[DoseEra(codeset_id=13, era_length=NumericRange(op="gte", value=30))] + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert set(result.person_id) == {1} + + +def test_build_cohort_location_region(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "location", + obj=ibis.memtable( + { + "location_id": [10, 20], + "region_concept_id": [15151, 99999], + } + ), + overwrite=True, + ) + conn.create_table( + "location_history", + obj=ibis.memtable( + { + "entity_id": [1, 2], + "location_id": [10, 20], + "start_date": ["2020-01-01", "2020-01-01"], + "end_date": ["2020-12-31", "2020-12-31"], + "domain_id": ["PERSON", "PERSON"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(14, 15151)], + primary_criteria=PrimaryCriteria(criteria_list=[LocationRegion(codeset_id=14)]), + ) + + table = build_cohort(expression, backend=conn, cdm_schema="main") + result = table.execute() + + assert set(result.person_id) == {1} + assert all(result.domain == "location_region") + + +def 
test_build_cohort_location_region_keeps_repeated_location_history_rows(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "location", + obj=ibis.memtable( + { + "location_id": [10], + "region_concept_id": [15151], + } + ), + overwrite=True, + ) + conn.create_table( + "location_history", + obj=ibis.memtable( + { + "entity_id": [1, 1], + "location_id": [10, 10], + "start_date": ["2020-01-01", "2020-02-01"], + "end_date": ["2020-01-31", "2020-02-28"], + "domain_id": ["PERSON", "PERSON"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(14, 15151)], + primary_criteria=PrimaryCriteria(criteria_list=[LocationRegion(codeset_id=14)]), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert len(result) == 2 + assert set(result.person_id) == {1} + assert sorted(result.start_date.astype(str).tolist()) == ["2020-01-01", "2020-02-01"] + + +def test_build_cohort_rejects_unsupported_features(): + expression = CohortExpression( + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence()]), + end_strategy=CustomEraStrategy(drug_codeset_id=1, gap_days=30, offset=0), + ) + with pytest.raises(UnsupportedFeatureError, match="custom_era"): + _ = build_cohort(expression, backend=object(), cdm_schema="main") diff --git a/tests/execution/test_api_public.py b/tests/execution/test_api_public.py new file mode 100644 index 00000000..a8f074d2 --- /dev/null +++ b/tests/execution/test_api_public.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +import pytest + +import circe.api as api +from circe.api import build_cohort, write_cohort +from circe.cohortdefinition import CohortExpression, ConditionOccurrence, PrimaryCriteria +from circe.execution.api import write_relation +from circe.execution.errors import ExecutionError +from circe.vocabulary import Concept, ConceptSet, 
ConceptSetExpression, ConceptSetItem + + +def _make_concept_set(set_id: int, concept_id: int) -> ConceptSet: + return ConceptSet( + id=set_id, + expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=concept_id))]), + ) + + +def _expression() -> CohortExpression: + return CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + ) + + +def _seed_tables(conn, ibis): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1, 2], + "year_of_birth": [1980, 1982], + "gender_concept_id": [8507, 8507], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": [1, 2], + "observation_period_id": [10, 11], + "observation_period_start_date": ["2019-01-01", "2019-01-01"], + "observation_period_end_date": ["2021-12-31", "2021-12-31"], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 2], + "condition_occurrence_id": [100, 101], + "condition_concept_id": [111, 111], + "condition_start_date": ["2020-01-01", "2020-01-02"], + "condition_end_date": ["2020-01-01", "2020-01-02"], + } + ), + overwrite=True, + ) + + +def test_public_execution_functions_are_exported(): + assert hasattr(api, "build_cohort") + assert hasattr(api, "write_cohort") + assert hasattr(api, "build_cohort_query") + + +def test_build_cohort_returns_relation(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_tables(conn, ibis) + + expression = _expression() + + relation = build_cohort(expression, backend=conn, cdm_schema="main") + + assert hasattr(relation, "execute") + assert len(relation.execute()) == 2 + + +def test_write_cohort_writes_result_table(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_tables(conn, ibis) + + 
write_cohort( + _expression(), + backend=conn, + cdm_schema="main", + cohort_table="cohort_out", + cohort_id=42, + if_exists="replace", + ) + result = conn.table("cohort_out").execute() + assert len(result) == 2 + assert list(result.columns) == [ + "cohort_definition_id", + "subject_id", + "cohort_start_date", + "cohort_end_date", + ] + assert set(result.cohort_definition_id) == {42} + + +def test_write_cohort_if_exists_fail_raises(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_tables(conn, ibis) + + write_cohort( + _expression(), + backend=conn, + cdm_schema="main", + cohort_table="cohort_out", + cohort_id=42, + if_exists="fail", + ) + with pytest.raises(ExecutionError, match="already contains rows for cohort_id=42"): + write_cohort( + _expression(), + backend=conn, + cdm_schema="main", + cohort_table="cohort_out", + cohort_id=42, + if_exists="fail", + ) + + +def test_write_cohort_if_exists_replace_overwrites(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_tables(conn, ibis) + expression = _expression() + + write_cohort( + expression, + backend=conn, + cdm_schema="main", + cohort_table="cohort_out", + cohort_id=10, + if_exists="replace", + ) + first = conn.table("cohort_out").execute() + assert len(first) == 2 + assert set(first.cohort_definition_id) == {10} + + write_cohort( + expression, + backend=conn, + cdm_schema="main", + cohort_table="cohort_out", + cohort_id=20, + if_exists="replace", + ) + combined = conn.table("cohort_out").execute() + assert len(combined) == 4 + assert set(combined.cohort_definition_id) == {10, 20} + + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": ["2020-01-01"], + "condition_end_date": ["2020-01-01"], + } + ), + overwrite=True, + ) + write_cohort( + expression, + 
backend=conn, + cdm_schema="main", + cohort_table="cohort_out", + cohort_id=10, + if_exists="replace", + ) + replaced = conn.table("cohort_out").execute() + replaced_10 = replaced[replaced.cohort_definition_id == 10] + replaced_20 = replaced[replaced.cohort_definition_id == 20] + assert set(replaced_10.subject_id) == {1} + assert set(replaced_20.subject_id) == {1, 2} + + +def test_write_cohort_respects_results_schema(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_tables(conn, ibis) + + write_cohort( + _expression(), + backend=conn, + cdm_schema="main", + results_schema="main", + cohort_table="cohort_schema", + cohort_id=7, + if_exists="replace", + ) + assert len(conn.table("cohort_schema", database="main").execute()) == 2 + + +def test_expression_first_build_modify_then_write_relation(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_tables(conn, ibis) + + relation = build_cohort(_expression(), backend=conn, cdm_schema="main") + modified = relation.filter(relation.person_id == 1) + + write_relation( + modified, + backend=conn, + target_table="cohort_filtered", + target_schema="main", + if_exists="replace", + ) + result = conn.table("cohort_filtered", database="main").execute() + assert set(result.person_id) == {1} + + +def test_write_cohort_rejects_invalid_if_exists(): + with pytest.raises(ValueError, match="if_exists must be one of"): + write_cohort( + _expression(), + backend=object(), + cdm_schema="main", + cohort_table="cohort_out", + cohort_id=1, + if_exists="append", + ) + + +def test_write_cohort_replace_uses_delete_then_insert(monkeypatch: pytest.MonkeyPatch): + import circe.execution.api as execution_api + + events: list[tuple[str, object]] = [] + + monkeypatch.setattr(execution_api, "build_cohort", lambda *args, **kwargs: object()) + monkeypatch.setattr( + execution_api, "project_to_ohdsi_cohort_table", lambda relation, 
*, cohort_id: relation + ) + monkeypatch.setattr(execution_api, "table_exists", lambda *args, **kwargs: True) + monkeypatch.setattr(execution_api, "supports_transactional_replace", lambda *args, **kwargs: True) + monkeypatch.setattr( + execution_api, + "replace_cohort_rows_transactionally", + lambda relation, *, backend, cohort_table, results_schema=None, cohort_id: events.append( + ("replace", cohort_table, results_schema, cohort_id) + ), + ) + monkeypatch.setattr( + execution_api, + "write_relation", + lambda *args, **kwargs: events.append(("create", kwargs["target_table"])), + ) + + write_cohort( + _expression(), + backend=object(), + cdm_schema="main", + results_schema="results", + cohort_table="cohort_out", + cohort_id=9, + if_exists="replace", + ) + + assert events == [("replace", "cohort_out", "results", 9)] + + +def test_write_cohort_replace_falls_back_to_safe_rewrite(monkeypatch: pytest.MonkeyPatch): + import circe.execution.api as execution_api + + events: list[tuple[str, object]] = [] + existing = object() + + class _Filtered: + def union(self, relation, distinct=False): + events.append(("union", distinct)) + return "merged" + + filtered = _Filtered() + + monkeypatch.setattr(execution_api, "build_cohort", lambda *args, **kwargs: object()) + monkeypatch.setattr( + execution_api, "project_to_ohdsi_cohort_table", lambda relation, *, cohort_id: relation + ) + monkeypatch.setattr(execution_api, "table_exists", lambda *args, **kwargs: True) + monkeypatch.setattr(execution_api, "supports_transactional_replace", lambda *args, **kwargs: False) + monkeypatch.setattr(execution_api, "read_table", lambda *args, **kwargs: existing) + monkeypatch.setattr( + execution_api, + "exclude_cohort_rows", + lambda relation, *, cohort_id: events.append(("filter", cohort_id)) or filtered, + ) + monkeypatch.setattr( + execution_api, + "write_relation", + lambda relation, *, backend, target_table, target_schema=None, if_exists="fail", temporary=False: ( + events.append(("write", 
relation, target_table, target_schema, if_exists)) + ), + ) + + write_cohort( + _expression(), + backend=object(), + cdm_schema="main", + results_schema="results", + cohort_table="cohort_out", + cohort_id=9, + if_exists="replace", + ) + + assert events == [ + ("filter", 9), + ("union", False), + ("write", "merged", "cohort_out", "results", "replace"), + ] + + +def test_write_relation_type_error_is_reported_as_generic_write_failure(): + class _Backend: + def create_table(self, name, **kwargs): + raise TypeError("boom") + + with pytest.raises(ExecutionError, match="failed writing relation to table 'cohort_out'"): + write_relation( + object(), + backend=_Backend(), + target_table="cohort_out", + target_schema="main", + if_exists="replace", + ) diff --git a/tests/execution/test_compile_contracts.py b/tests/execution/test_compile_contracts.py new file mode 100644 index 00000000..841497e3 --- /dev/null +++ b/tests/execution/test_compile_contracts.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import pytest + +from circe.cohortdefinition import CohortExpression, PrimaryCriteria +from circe.execution.ibis.compiler import compile_event_plan +from circe.execution.ibis.context import make_execution_context +from circe.execution.lower.criteria import lower_criterion +from circe.execution.normalize.cohort import normalize_cohort +from circe.execution.plan.schema import STANDARD_EVENT_COLUMNS +from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem +from tests.execution._domain_cases import domain_criteria_cases + + +def _seed_common_tables(conn, ibis): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1], + "year_of_birth": [1980], + "gender_concept_id": [8507], + "race_concept_id": [8527], + "ethnicity_concept_id": [38003564], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": [1], + "observation_period_id": [10], + 
"observation_period_start_date": ["2019-01-01"], + "observation_period_end_date": ["2022-12-31"], + } + ), + overwrite=True, + ) + + +@pytest.mark.parametrize(("source_table", "factory", "concept_id"), domain_criteria_cases()) +def test_compile_contract_emits_standard_schema(source_table, factory, concept_id): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + criteria = factory() + concept_sets = [] + if concept_id is not None: + concept_sets = [ + ConceptSet( + id=1, + expression=ConceptSetExpression( + items=[ConceptSetItem(concept=Concept(conceptId=concept_id))] + ), + ) + ] + + expression = CohortExpression( + concept_sets=concept_sets, + primary_criteria=PrimaryCriteria(criteria_list=[criteria]), + ) + normalized = normalize_cohort(expression) + normalized_criterion = normalized.primary.criteria[0] + plan = lower_criterion(normalized_criterion, criterion_index=0) + + source_data = { + plan.source.person_id_column: [1], + plan.source.event_id_column: [101], + plan.source.start_date_column: ["2020-01-01"], + plan.source.end_date_column: ["2020-01-01"], + } + if plan.source.visit_occurrence_column and plan.source.visit_occurrence_column not in source_data: + source_data[plan.source.visit_occurrence_column] = [10] + if ( + plan.source.concept_column + and plan.source.concept_column not in source_data + and source_table != "location_history" + ): + source_data[plan.source.concept_column] = [concept_id or 0] + if plan.source.source_concept_column and plan.source.source_concept_column not in source_data: + source_data[plan.source.source_concept_column] = [concept_id or 0] + if source_table == "location_history": + source_data["domain_id"] = ["PERSON"] + source_data["location_id"] = [10] + conn.create_table( + "location", + obj=ibis.memtable({"location_id": [10], "region_concept_id": [concept_id]}), + overwrite=True, + ) + + conn.create_table(source_table, 
obj=ibis.memtable(source_data), overwrite=True) + + ctx = make_execution_context( + backend=conn, + cdm_schema="main", + results_schema=None, + concept_sets=normalized.concept_sets, + ) + + result = compile_event_plan(plan, ctx).execute() + assert tuple(result.columns) == STANDARD_EVENT_COLUMNS + assert len(result) == 1 diff --git a/tests/execution/test_compile_steps_helpers.py b/tests/execution/test_compile_steps_helpers.py new file mode 100644 index 00000000..74a3d878 --- /dev/null +++ b/tests/execution/test_compile_steps_helpers.py @@ -0,0 +1,319 @@ +from __future__ import annotations + +from datetime import date +from types import SimpleNamespace + +import ibis +import pytest + +from circe.execution.engine.group_windows import apply_window_constraints, window_bound_expression +from circe.execution.errors import CompilationError, UnsupportedFeatureError +from circe.execution.ibis.compile_steps import ( + _apply_date_predicate, + _apply_numeric_predicate, + _resolve_concept_ids, + apply_step, +) +from circe.execution.normalize.windows import NormalizedWindow, NormalizedWindowBound +from circe.execution.plan.events import ( + ApplyDateAdjustment, + FilterByCareSiteLocationRegion, + FilterByCodeset, + FilterByConceptSet, + FilterByPersonGender, + FilterByText, + KeepFirstPerPerson, + RestrictToCorrelatedWindow, +) +from circe.execution.plan.predicates import DateRangePredicate, NumericRangePredicate +from circe.execution.plan.schema import END_DATE, EVENT_ID, PERSON_ID, START_DATE, VISIT_OCCURRENCE_ID + + +class _Context: + def __init__(self, conn=None, *, codesets: dict[int, tuple[int, ...]] | None = None): + self.conn = conn + self.codesets = codesets or {} + + def concept_ids_for_codeset(self, codeset_id: int) -> tuple[int, ...]: + return self.codesets.get(codeset_id, ()) + + def table(self, name: str): + if self.conn is None: + raise KeyError(name) + return self.conn.table(name) + + +def _events_table(conn): + conn.create_table( + "events", + obj=ibis.memtable( 
+ { + PERSON_ID: [1, 1, 2], + EVENT_ID: [10, 11, 20], + START_DATE: [ + date(2020, 1, 1), + date(2020, 1, 2), + date(2020, 1, 3), + ], + END_DATE: [ + date(2020, 1, 5), + date(2020, 1, 4), + date(2020, 1, 6), + ], + VISIT_OCCURRENCE_ID: [100, 101, 200], + "concept_id": [1, 2, 3], + "text_value": ["alpha", "beta", "gamma"], + } + ), + overwrite=True, + ) + return conn.table("events") + + +@pytest.mark.parametrize( + ("predicate", "expected"), + [ + (NumericRangePredicate(op=None, value=None, extent=None), [True, True, True]), + (NumericRangePredicate(op="eq", value=2, extent=None), [False, True, False]), + (NumericRangePredicate(op="neq", value=2, extent=None), [True, False, True]), + (NumericRangePredicate(op="gt", value=1, extent=None), [False, True, True]), + (NumericRangePredicate(op="gte", value=2, extent=None), [False, True, True]), + (NumericRangePredicate(op="lt", value=3, extent=None), [True, True, False]), + (NumericRangePredicate(op="lte", value=2, extent=None), [True, True, False]), + (NumericRangePredicate(op="between", value=2, extent=3), [False, True, True]), + ], +) +def test_apply_numeric_predicate_covers_supported_ops(predicate, expected): + table = ibis.memtable({"value": [1, 2, 3]}) + result = table.select(_apply_numeric_predicate(table.value, predicate).name("matched")).execute() + assert list(result.matched) == expected + + +def test_apply_numeric_predicate_rejects_invalid_ranges(): + expr = ibis.memtable({"value": [1]}).value + + with pytest.raises(CompilationError, match="numeric range 'between' requires an extent value"): + _apply_numeric_predicate(expr, NumericRangePredicate(op="between", value=1, extent=None)) + + with pytest.raises(CompilationError, match="unsupported numeric range op"): + _apply_numeric_predicate(expr, NumericRangePredicate(op="weird", value=1, extent=None)) + + +@pytest.mark.parametrize( + ("predicate", "expected"), + [ + (DateRangePredicate(op=None, value=None, extent=None), [True, True, True]), + 
(DateRangePredicate(op="eq", value="2020-01-02", extent=None), [False, True, False]), + (DateRangePredicate(op="neq", value="2020-01-02", extent=None), [True, False, True]), + (DateRangePredicate(op="gt", value="2020-01-01", extent=None), [False, True, True]), + (DateRangePredicate(op="gte", value="2020-01-02", extent=None), [False, True, True]), + (DateRangePredicate(op="lt", value="2020-01-03", extent=None), [True, True, False]), + (DateRangePredicate(op="lte", value="2020-01-02", extent=None), [True, True, False]), + ( + DateRangePredicate(op="between", value="2020-01-02", extent="2020-01-03"), + [False, True, True], + ), + ], +) +def test_apply_date_predicate_covers_supported_ops(predicate, expected): + table = ibis.memtable({"value": ["2020-01-01", "2020-01-02", "2020-01-03"]}) + result = table.select(_apply_date_predicate(table.value, predicate).name("matched")).execute() + assert list(result.matched) == expected + + +def test_apply_date_predicate_rejects_invalid_ranges(): + expr = ibis.memtable({"value": ["2020-01-01"]}).value + + with pytest.raises(CompilationError, match="date range 'between' requires an extent value"): + _apply_date_predicate(expr, DateRangePredicate(op="between", value="2020-01-01", extent=None)) + + with pytest.raises(CompilationError, match="unsupported date range op"): + _apply_date_predicate(expr, DateRangePredicate(op="weird", value="2020-01-01", extent=None)) + + +def test_resolve_concept_ids_deduplicates_codeset_ids(): + ctx = _Context(codesets={1: (2, 3, 4)}) + assert _resolve_concept_ids(direct_ids=(1, 2), codeset_id=1, ctx=ctx) == (1, 2, 3, 4) + + +def test_apply_step_covers_text_codeset_concept_and_adjustment_paths(): + ibis_mod = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis_mod.duckdb.connect() + table = _events_table(conn) + ctx = _Context(conn, codesets={1: (1, 3), 2: ()}) + + codeset_hit = apply_step( + FilterByCodeset(column="concept_id", codeset_id=1), + table=table, + source=None, + 
ctx=ctx, + ).execute() + assert set(codeset_hit.concept_id) == {1, 3} + + codeset_exclude = apply_step( + FilterByCodeset(column="concept_id", codeset_id=2, exclude=True), + table=table, + source=None, + ctx=ctx, + ).execute() + assert len(codeset_exclude) == 3 + + empty_concepts = apply_step( + FilterByConceptSet(column="concept_id", concept_ids=(), exclude=False), + table=table, + source=None, + ctx=ctx, + ).execute() + assert empty_concepts.empty + + text_eq = apply_step( + FilterByText(column="text_value", op="eq", text="alpha"), + table=table, + source=None, + ctx=ctx, + ).execute() + assert list(text_eq.text_value) == ["alpha"] + + text_neq = apply_step( + FilterByText(column="text_value", op="neq", text="alpha"), + table=table, + source=None, + ctx=ctx, + ).execute() + assert set(text_neq.text_value) == {"beta", "gamma"} + + text_none = apply_step( + FilterByText(column="text_value", op="contains", text=None), + table=table, + source=None, + ctx=ctx, + ) + assert text_none is table + + text_like = apply_step( + FilterByText(column="text_value", op="contains", text="a"), + table=table, + source=None, + ctx=ctx, + ).execute() + assert set(text_like.text_value) == {"alpha", "beta", "gamma"} + + adjusted = apply_step( + ApplyDateAdjustment(start_offset_days=2, end_offset_days=1, start_with=END_DATE, end_with=START_DATE), + table=table, + source=None, + ctx=ctx, + ).execute() + assert str(adjusted.iloc[0][START_DATE])[:10] == "2020-01-07" + assert str(adjusted.iloc[0][END_DATE])[:10] == "2020-01-02" + + +def test_apply_step_covers_keep_first_person_filter_and_error_paths(): + ibis_mod = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis_mod.duckdb.connect() + table = _events_table(conn) + conn.create_table( + "person", + obj=ibis_mod.memtable( + { + PERSON_ID: [1, 2], + "gender_concept_id": [8507, 8532], + } + ), + overwrite=True, + ) + ctx = _Context(conn, codesets={9: ()}) + + first = apply_step( + 
KeepFirstPerPerson(order_by=(START_DATE,)), + table=table, + source=None, + ctx=ctx, + ) + assert first.columns == table.columns + assert "row_number()" in ibis_mod.to_sql(first).lower() + + filtered = apply_step( + FilterByPersonGender(concept_ids=(8507,), codeset_id=None), + table=table, + source=None, + ctx=ctx, + ).execute() + assert set(filtered[PERSON_ID]) == {1} + + care_site_empty = apply_step( + FilterByCareSiteLocationRegion(codeset_id=9), + table=table, + source=None, + ctx=ctx, + ).execute() + assert care_site_empty.empty + + with pytest.raises(CompilationError, match="unsupported text filter op"): + apply_step(FilterByText(column="text_value", op="weird", text="x"), table=table, source=None, ctx=ctx) + + with pytest.raises(UnsupportedFeatureError, match="RestrictToCorrelatedWindow step is not implemented"): + apply_step(RestrictToCorrelatedWindow(payload={}), table=table, source=None, ctx=ctx) + + with pytest.raises(CompilationError, match="unsupported plan step"): + apply_step(SimpleNamespace(), table=table, source=None, ctx=ctx) + + +def test_window_bound_expression_and_end_window_constraints(): + ibis_mod = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + assert ( + window_bound_expression( + None, + index_anchor_expr=ibis_mod.literal("2020-01-01").cast("date"), + use_observation_period=True, + op_start_expr=ibis_mod.literal("2019-01-01").cast("date"), + op_end_expr=ibis_mod.literal("2020-12-31").cast("date"), + ) + is None + ) + assert ( + window_bound_expression( + NormalizedWindowBound(coeff=1, days=None), + index_anchor_expr=ibis_mod.literal("2020-01-01").cast("date"), + use_observation_period=False, + op_start_expr=ibis_mod.literal("2019-01-01").cast("date"), + op_end_expr=ibis_mod.literal("2020-12-31").cast("date"), + ) + is None + ) + + joined = ibis_mod.memtable( + { + "a_person_id": [1, 1], + "p_person_id": [1, 1], + "a_start_date": [date(2020, 1, 3), date(2020, 1, 20)], + "a_end_date": [date(2020, 1, 5), date(2020, 1, 
25)], + "p_start_date": [date(2020, 1, 1), date(2020, 1, 1)], + "p_end_date": [date(2020, 1, 10), date(2020, 1, 10)], + "p_op_start_date": [date(2019, 1, 1), date(2019, 1, 1)], + "p_op_end_date": [date(2020, 12, 31), date(2020, 12, 31)], + "a_visit_occurrence_id": [100, 101], + "p_visit_occurrence_id": [100, 100], + } + ) + correlated = SimpleNamespace( + ignore_observation_period=False, + restrict_visit=True, + start_window=None, + end_window=NormalizedWindow( + start=NormalizedWindowBound(coeff=1, days=0), + end=NormalizedWindowBound(coeff=1, days=10), + use_event_end=False, + use_index_end=False, + ), + ) + + result = apply_window_constraints(joined, correlated).execute() + assert len(result) == 1 + assert int(result.iloc[0]["a_visit_occurrence_id"]) == 100 diff --git a/tests/execution/test_context_wiring.py b/tests/execution/test_context_wiring.py new file mode 100644 index 00000000..1231f121 --- /dev/null +++ b/tests/execution/test_context_wiring.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from circe.execution.ibis.codesets import CachedConceptSetResolver +from circe.execution.ibis.context import ExecutionContext, make_execution_context +from circe.execution.normalize.cohort import NormalizedConceptSet, NormalizedConceptSetItem + + +class _BackendWithSchemaSupport: + def __init__(self): + self.calls: list[tuple[str, str | None]] = [] + + def table(self, name: str, database: str | None = None): + self.calls.append((name, database)) + return (name, database) + + +class _BackendWithoutSchemaSupport: + def __init__(self): + self.calls: list[tuple[str, str | None]] = [] + + def table(self, name: str, database: str | None = None): + self.calls.append((name, database)) + if database is not None: + raise TypeError("database kwarg not supported") + return (name, None) + + +def test_make_execution_context_uses_cdm_schema_as_vocabulary_fallback(): + backend = _BackendWithSchemaSupport() + ctx = make_execution_context( 
+ backend=backend, + cdm_schema="cdm", + concept_sets={}, + ) + + assert isinstance(ctx, ExecutionContext) + assert ctx.vocabulary_schema == "cdm" + assert isinstance(ctx.codeset_resolver, CachedConceptSetResolver) + assert ctx.table("person") == ("person", "cdm") + assert ctx.concept_ids_for_codeset(999) == () + + +def test_make_execution_context_honors_vocabulary_schema_option_and_backend_fallback(): + backend = _BackendWithoutSchemaSupport() + ctx = make_execution_context( + backend=backend, + cdm_schema="cdm", + concept_sets={}, + vocabulary_schema="vocab", + ) + + assert ctx.vocabulary_schema == "vocab" + assert ctx.vocabulary_table("concept") == ("concept", None) + assert backend.calls == [("concept", "vocab"), ("concept", None)] + + +def test_codeset_resolver_caches_expanded_results(monkeypatch): + resolver = CachedConceptSetResolver( + table_getter=lambda name, schema: (name, schema), + vocabulary_schema="vocab", + concept_sets={ + 1: NormalizedConceptSet( + set_id=1, + items=( + NormalizedConceptSetItem( + concept_id=123, + is_excluded=False, + include_descendants=False, + include_mapped=False, + ), + ), + ) + }, + ) + calls: list[int] = [] + + def _expand(item): + calls.append(item.concept_id) + return {item.concept_id} + + monkeypatch.setattr(resolver, "_expand_item", _expand) + + assert resolver.resolve_codeset(1) == (123,) + assert resolver.resolve_codeset(1) == (123,) + assert calls == [123] + + +def test_codeset_resolver_handles_empty_and_non_dataframe_query_results(): + resolver = CachedConceptSetResolver( + table_getter=lambda name, schema: (name, schema), + vocabulary_schema="vocab", + concept_sets={}, + ) + + assert resolver._descendant_ids(set()) == set() + assert resolver._mapped_ids(set()) == set() + assert resolver._execute_concept_id_query(SimpleNamespace(execute=lambda: [1, None, 2])) == {1, 2} + assert resolver._execute_concept_id_query(SimpleNamespace(execute=lambda: 3)) == {3} diff --git a/tests/execution/test_databricks_compat.py 
b/tests/execution/test_databricks_compat.py new file mode 100644 index 00000000..448bd3df --- /dev/null +++ b/tests/execution/test_databricks_compat.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import pytest + +from circe.execution.databricks_compat import ( + _backend_looks_like_databricks, + _is_memtable_volume_error, + _post_connect_needs_workaround, + apply_databricks_post_connect_workaround, + maybe_apply_databricks_post_connect_workaround, +) + + +def test_databricks_post_connect_workaround_swallows_memtable_volume_error(): + class FakeDatabricksBackend: + def _post_connect(self): + raise RuntimeError("CREATE VOLUME IF NOT EXISTS my_catalog.my_schema.memtable") + + patched = apply_databricks_post_connect_workaround(backend_cls=FakeDatabricksBackend) + assert patched is True + + backend = FakeDatabricksBackend() + assert backend._post_connect() is None + + +def test_databricks_post_connect_workaround_keeps_non_volume_errors(): + class FakeDatabricksBackend: + def _post_connect(self): + _ = "CREATE VOLUME IF NOT EXISTS my_catalog.my_schema.memtable" + raise RuntimeError("different setup error") + + patched = apply_databricks_post_connect_workaround(backend_cls=FakeDatabricksBackend) + assert patched is True + + backend = FakeDatabricksBackend() + with pytest.raises(RuntimeError, match="different setup error"): + backend._post_connect() + + +def test_post_connect_needs_workaround_handles_missing_source_and_false_pattern(monkeypatch): + def _plain_post_connect(): + return None + + monkeypatch.setattr("inspect.getsource", lambda _fn: "plain setup") + assert _post_connect_needs_workaround(_plain_post_connect) is False + + monkeypatch.setattr("inspect.getsource", lambda _fn: (_ for _ in ()).throw(OSError("no source"))) + assert _post_connect_needs_workaround(_plain_post_connect) is True + + +def test_databricks_detection_helpers_cover_non_patched_paths(): + assert _is_memtable_volume_error(RuntimeError("memtable volume failure")) is True + assert 
_is_memtable_volume_error(RuntimeError("different failure")) is False + + assert _backend_looks_like_databricks(type("DatabricksConn", (), {})()) is True + assert _backend_looks_like_databricks(type("Backend", (), {"name": "databricks"})()) is True + assert _backend_looks_like_databricks(type("Backend", (), {"name": "duckdb"})()) is False + + +def test_apply_databricks_workaround_returns_false_when_not_patchable(): + class NoPostConnectBackend: + pass + + class PlainBackend: + def _post_connect(self): + return None + + assert apply_databricks_post_connect_workaround(backend_cls=None) is False + assert apply_databricks_post_connect_workaround(backend_cls=NoPostConnectBackend) is False + assert apply_databricks_post_connect_workaround(backend_cls=PlainBackend) is False + assert maybe_apply_databricks_post_connect_workaround(object()) is False + + +def test_apply_databricks_workaround_is_idempotent(): + class FakeDatabricksBackend: + def _post_connect(self): + raise RuntimeError("CREATE VOLUME IF NOT EXISTS my_catalog.my_schema.memtable") + + assert apply_databricks_post_connect_workaround(backend_cls=FakeDatabricksBackend) is True + assert apply_databricks_post_connect_workaround(backend_cls=FakeDatabricksBackend) is True + assert maybe_apply_databricks_post_connect_workaround(FakeDatabricksBackend()) is True diff --git a/tests/execution/test_domain_filter_parity.py b/tests/execution/test_domain_filter_parity.py new file mode 100644 index 00000000..def5b8a7 --- /dev/null +++ b/tests/execution/test_domain_filter_parity.py @@ -0,0 +1,524 @@ +from __future__ import annotations + +import pytest + +from circe.api import build_cohort +from circe.cohortdefinition import ( + CohortExpression, + ConditionOccurrence, + Death, + DeviceExposure, + DrugExposure, + Measurement, + PrimaryCriteria, + Specimen, + VisitDetail, + VisitOccurrence, +) +from circe.cohortdefinition.core import ConceptSetSelection, DateAdjustment, NumericRange +from circe.vocabulary import Concept, 
ConceptSet, ConceptSetExpression, ConceptSetItem +from tests.execution.test_api_ibis import _seed_common_tables + + +def _make_concept_set(set_id: int, concept_id: int) -> ConceptSet: + return ConceptSet( + id=set_id, + expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=concept_id))]), + ) + + +def test_condition_occurrence_applies_related_filters_and_date_adjustment(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "provider", + obj=ibis.memtable( + { + "provider_id": [1, 2], + "specialty_concept_id": [8001, 8002], + } + ), + overwrite=True, + ) + conn.create_table( + "visit_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 2], + "visit_occurrence_id": [10, 11], + "visit_concept_id": [7001, 7002], + "visit_start_date": ["2020-01-01", "2020-01-01"], + "visit_end_date": ["2020-01-03", "2020-01-03"], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 2], + "condition_occurrence_id": [100, 101], + "condition_concept_id": [111, 111], + "condition_start_date": ["2020-01-01", "2020-01-01"], + "condition_end_date": [None, "2020-01-03"], + "visit_occurrence_id": [10, 11], + "provider_id": [1, 2], + "condition_type_concept_id": [9001, 9002], + "condition_status_concept_id": [9101, 9102], + "stop_reason": ["keep me", "drop me"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + ConditionOccurrence( + codeset_id=1, + condition_type=[Concept(conceptId=9001)], + condition_status=[Concept(conceptId=9101)], + stop_reason={"op": "contains", "text": "keep"}, + provider_specialty=[Concept(conceptId=8001)], + visit_type=[Concept(conceptId=7001)], + occurrence_start_date={"op": "gte", "value": "2020-01-02"}, + occurrence_end_date={"op": "gte", "value": 
"2020-01-04"}, + date_adjustment=DateAdjustment(start_offset=1, end_offset=2), + ) + ] + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert list(result.person_id) == [1] + assert result.iloc[0].start_date.date().isoformat() == "2020-01-02" + + +def test_drug_exposure_applies_domain_filters_and_end_date_fallback(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "provider", + obj=ibis.memtable({"provider_id": [1, 2], "specialty_concept_id": [8001, 8002]}), + overwrite=True, + ) + conn.create_table( + "visit_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 2], + "visit_occurrence_id": [10, 11], + "visit_concept_id": [7001, 7002], + "visit_start_date": ["2020-03-01", "2020-03-01"], + "visit_end_date": ["2020-03-02", "2020-03-02"], + } + ), + overwrite=True, + ) + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 2], + "drug_exposure_id": [200, 201], + "drug_concept_id": [222, 222], + "drug_exposure_start_date": ["2020-03-01", "2020-03-01"], + "drug_exposure_end_date": [None, "2020-03-02"], + "visit_occurrence_id": [10, 11], + "provider_id": [1, 2], + "drug_type_concept_id": [3001, 3002], + "route_concept_id": [4001, 4002], + "dose_unit_concept_id": [5001, 5002], + "lot_number": ["A-LOT", "B-LOT"], + "quantity": [10.0, 1.0], + "days_supply": [5, 1], + "refills": [2, 0], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(2, 222)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + DrugExposure( + codeset_id=2, + drug_type=[Concept(conceptId=3001)], + route_concept=[Concept(conceptId=4001)], + dose_unit=[Concept(conceptId=5001)], + lot_number={"op": "contains", "text": "A-"}, + quantity=NumericRange(op="gte", value=10), + days_supply=NumericRange(op="gte", value=5), + refills=NumericRange(op="gte", value=2), + 
provider_specialty=[Concept(conceptId=8001)], + visit_type=[Concept(conceptId=7001)], + occurrence_end_date={"op": "gte", "value": "2020-03-06"}, + ) + ] + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert list(result.person_id) == [1] + + +def test_visit_occurrence_applies_care_site_provider_location_and_duration_filters(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "provider", + obj=ibis.memtable({"provider_id": [1, 2], "specialty_concept_id": [8001, 8002]}), + overwrite=True, + ) + conn.create_table( + "care_site", + obj=ibis.memtable( + { + "care_site_id": [100, 101], + "place_of_service_concept_id": [9001, 9002], + } + ), + overwrite=True, + ) + conn.create_table( + "location_history", + obj=ibis.memtable( + { + "entity_id": [100, 101], + "domain_id": ["CARE_SITE", "CARE_SITE"], + "location_id": [500, 501], + "start_date": ["2020-01-01", "2020-01-01"], + "end_date": [None, "2020-12-31"], + } + ), + overwrite=True, + ) + conn.create_table( + "location", + obj=ibis.memtable({"location_id": [500, 501], "region_concept_id": [6001, 6002]}), + overwrite=True, + ) + conn.create_table( + "visit_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 2], + "visit_occurrence_id": [300, 301], + "visit_concept_id": [333, 333], + "visit_start_date": ["2020-05-01", "2020-05-01"], + "visit_end_date": ["2020-05-03", "2020-05-02"], + "visit_type_concept_id": [7001, 7002], + "provider_id": [1, 2], + "care_site_id": [100, 101], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(3, 333), _make_concept_set(31, 6001)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + VisitOccurrence( + codeset_id=3, + visit_type=[Concept(conceptId=7001)], + visit_length=NumericRange(op="gte", value=2), + provider_specialty=[Concept(conceptId=8001)], + 
place_of_service=[Concept(conceptId=9001)], + place_of_service_location=31, + ) + ] + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert list(result.person_id) == [1] + + +def test_device_exposure_applies_domain_filters_and_end_date_fallback(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "provider", + obj=ibis.memtable({"provider_id": [1, 2], "specialty_concept_id": [8001, 8002]}), + overwrite=True, + ) + conn.create_table( + "visit_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 2], + "visit_occurrence_id": [10, 11], + "visit_concept_id": [7001, 7002], + "visit_start_date": ["2020-10-01", "2020-10-01"], + "visit_end_date": ["2020-10-02", "2020-10-02"], + } + ), + overwrite=True, + ) + conn.create_table( + "device_exposure", + obj=ibis.memtable( + { + "person_id": [1, 2], + "device_exposure_id": [800, 801], + "device_concept_id": [888, 888], + "device_exposure_start_date": ["2020-10-01", "2020-10-01"], + "device_exposure_end_date": [None, "2020-10-02"], + "visit_occurrence_id": [10, 11], + "provider_id": [1, 2], + "device_type_concept_id": [3001, 3002], + "unique_device_id": ["abc-123", "xyz-999"], + "quantity": [5, 1], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(8, 888)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + DeviceExposure( + codeset_id=8, + device_type=[Concept(conceptId=3001)], + unique_device_id={"op": "contains", "text": "abc"}, + quantity=NumericRange(op="gte", value=5), + provider_specialty=[Concept(conceptId=8001)], + visit_type=[Concept(conceptId=7001)], + occurrence_end_date={"op": "gte", "value": "2020-10-02"}, + ) + ] + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert list(result.person_id) == [1] + + +def test_specimen_applies_domain_filters(): + ibis = 
pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "specimen", + obj=ibis.memtable( + { + "person_id": [1, 2], + "specimen_id": [900, 901], + "specimen_concept_id": [9990, 9990], + "specimen_date": ["2020-11-01", "2020-11-01"], + "visit_occurrence_id": [10, 11], + "specimen_type_concept_id": [1001, 1002], + "quantity": [5.0, 1.0], + "unit_concept_id": [2001, 2002], + "anatomic_site_concept_id": [3001, 3002], + "disease_status_concept_id": [4001, 4002], + "specimen_source_id": ["keep-source", "drop-source"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(9, 9990)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + Specimen( + codeset_id=9, + specimen_type=[Concept(conceptId=1001)], + quantity=NumericRange(op="gte", value=5), + unit=[Concept(conceptId=2001)], + anatomic_site=[Concept(conceptId=3001)], + disease_status=[Concept(conceptId=4001)], + source_id={"op": "contains", "text": "keep"}, + ) + ] + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert list(result.person_id) == [1] + + +def test_death_applies_death_type_and_derived_end_date(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "death", + obj=ibis.memtable( + { + "person_id": [1, 2], + "cause_concept_id": [10001, 10001], + "cause_source_concept_id": [20001, 20002], + "death_type_concept_id": [3001, 3002], + "death_date": ["2020-12-01", "2020-12-01"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(10, 10001)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + Death( + codeset_id=10, + death_type=[Concept(conceptId=3001)], + occurrence_end_date={"op": "gte", "value": "2020-12-02"}, + ) + ] + ), + ) + + result = 
build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert list(result.person_id) == [1] + + +def test_measurement_and_visit_detail_apply_shared_related_filters(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "provider", + obj=ibis.memtable({"provider_id": [1, 2], "specialty_concept_id": [8001, 8002]}), + overwrite=True, + ) + conn.create_table( + "visit_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 2], + "visit_occurrence_id": [10, 11], + "visit_concept_id": [7001, 7002], + "visit_start_date": ["2020-06-01", "2020-06-01"], + "visit_end_date": ["2020-06-02", "2020-06-02"], + } + ), + overwrite=True, + ) + conn.create_table( + "care_site", + obj=ibis.memtable( + { + "care_site_id": [100, 101], + "place_of_service_concept_id": [9001, 9002], + } + ), + overwrite=True, + ) + conn.create_table( + "location_history", + obj=ibis.memtable( + { + "entity_id": [100, 101], + "domain_id": ["CARE_SITE", "CARE_SITE"], + "location_id": [500, 501], + "start_date": ["2020-01-01", "2020-01-01"], + "end_date": [None, "2020-12-31"], + } + ), + overwrite=True, + ) + conn.create_table( + "location", + obj=ibis.memtable({"location_id": [500, 501], "region_concept_id": [6001, 6002]}), + overwrite=True, + ) + conn.create_table( + "measurement", + obj=ibis.memtable( + { + "person_id": [1, 2], + "measurement_id": [400, 401], + "measurement_concept_id": [444, 444], + "measurement_date": ["2020-06-01", "2020-06-01"], + "visit_occurrence_id": [10, 11], + "provider_id": [1, 2], + } + ), + overwrite=True, + ) + conn.create_table( + "visit_detail", + obj=ibis.memtable( + { + "person_id": [1, 2], + "visit_detail_id": [710, 711], + "visit_detail_concept_id": [777, 777], + "visit_detail_start_date": ["2020-09-01", "2020-09-01"], + "visit_detail_end_date": ["2020-09-03", "2020-09-02"], + "visit_occurrence_id": [10, 11], + "provider_id": [1, 2], + 
"care_site_id": [100, 101], + } + ), + overwrite=True, + ) + + measurement_expression = CohortExpression( + concept_sets=[_make_concept_set(4, 444)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + Measurement( + codeset_id=4, + provider_specialty=[Concept(conceptId=8001)], + visit_type=[Concept(conceptId=7001)], + ) + ] + ), + ) + measurement_result = build_cohort( + measurement_expression, + backend=conn, + cdm_schema="main", + ).execute() + assert list(measurement_result.person_id) == [1] + + visit_detail_expression = CohortExpression( + concept_sets=[ + _make_concept_set(7, 777), + _make_concept_set(21, 8001), + _make_concept_set(22, 9001), + _make_concept_set(23, 6001), + ], + primary_criteria=PrimaryCriteria( + criteria_list=[ + VisitDetail( + codeset_id=7, + provider_specialty_cs=ConceptSetSelection(codeset_id=21, is_exclusion=False), + place_of_service_cs=ConceptSetSelection(codeset_id=22, is_exclusion=False), + place_of_service_location=23, + visit_detail_length=NumericRange(op="gte", value=2), + ) + ] + ), + ) + visit_detail_result = build_cohort( + visit_detail_expression, + backend=conn, + cdm_schema="main", + ).execute() + assert list(visit_detail_result.person_id) == [1] diff --git a/tests/execution/test_end_strategy_censoring.py b/tests/execution/test_end_strategy_censoring.py new file mode 100644 index 00000000..67f7912e --- /dev/null +++ b/tests/execution/test_end_strategy_censoring.py @@ -0,0 +1,326 @@ +from __future__ import annotations + +from datetime import date +from types import SimpleNamespace + +import pytest + +from circe.api import build_cohort +from circe.cohortdefinition import CohortExpression, ConditionOccurrence, PrimaryCriteria +from circe.cohortdefinition.core import CollapseSettings, DateOffsetStrategy, Period +from circe.execution.engine.end_strategy import apply_end_strategy +from circe.execution.errors import UnsupportedFeatureError +from circe.execution.normalize.end_strategy import NormalizedEndStrategy +from 
circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem + + +def _make_concept_set(set_id: int, concept_id: int) -> ConceptSet: + return ConceptSet( + id=set_id, + expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=concept_id))]), + ) + + +def _seed_common_tables(conn, ibis): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1], + "year_of_birth": [1980], + "gender_concept_id": [8507], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": [1], + "observation_period_id": [10], + "observation_period_start_date": ["2019-01-01"], + "observation_period_end_date": ["2021-12-31"], + } + ), + overwrite=True, + ) + + +def test_date_offset_end_strategy_applies_to_end_date(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": ["2020-01-01"], + "condition_end_date": ["2020-01-01"], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + end_strategy=DateOffsetStrategy(offset=30, date_field="start_date"), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert str(result.iloc[0]["end_date"])[:10] == "2020-01-31" + + +def test_censoring_criteria_clips_end_date(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1], + "condition_occurrence_id": [100, 101], + "condition_concept_id": 
[111, 222], + "condition_start_date": ["2020-01-01", "2020-01-10"], + "condition_end_date": ["2020-01-01", "2020-01-10"], + "visit_occurrence_id": [10, 10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111), _make_concept_set(2, 222)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + censoring_criteria=[ConditionOccurrence(codeset_id=2)], + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert str(result.iloc[0]["end_date"])[:10] == "2020-01-10" + + +def test_censor_window_clips_start_and_end_dates(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": ["2020-01-01"], + "condition_end_date": ["2020-01-01"], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + end_strategy=DateOffsetStrategy(offset=40, date_field="start_date"), + censor_window=Period(start_date="2020-01-05", end_date="2020-01-20"), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-05" + assert str(result.iloc[0]["end_date"])[:10] == "2020-01-20" + + +def test_collapse_settings_era_merges_intervals(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1], + "condition_occurrence_id": [100, 101], + "condition_concept_id": [111, 111], + 
"condition_start_date": ["2020-01-01", "2020-01-03"], + "condition_end_date": ["2020-01-01", "2020-01-03"], + "visit_occurrence_id": [10, 10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + end_strategy=DateOffsetStrategy(offset=0, date_field="start_date"), + collapse_settings=CollapseSettings(era_pad=2), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.columns) == {"person_id", "start_date", "end_date"} + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-01" + assert str(result.iloc[0]["end_date"])[:10] == "2020-01-03" + + +def test_collapse_settings_era_does_not_merge_non_overlapping_intervals(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 1], + "condition_occurrence_id": [100, 101, 102], + "condition_concept_id": [111, 111, 111], + "condition_start_date": ["2020-01-01", "2020-03-01", "2020-06-01"], + "condition_end_date": ["2020-01-01", "2020-03-01", "2020-06-01"], + "visit_occurrence_id": [10, 10, 10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + end_strategy=DateOffsetStrategy(offset=0, date_field="start_date"), + collapse_settings=CollapseSettings(era_pad=1), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.columns) == {"person_id", "start_date", "end_date"} + assert len(result) == 3 + assert list(result.sort_values(["start_date", "end_date"]).start_date.astype(str)) == [ + "2020-01-01", + "2020-03-01", + "2020-06-01", + ] + + 
+def test_collapse_settings_era_deduplicates_identical_intervals(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1], + "condition_occurrence_id": [100, 101], + "condition_concept_id": [111, 111], + "condition_start_date": ["2020-01-01", "2020-01-01"], + "condition_end_date": ["2020-01-01", "2020-01-01"], + "visit_occurrence_id": [10, 10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + end_strategy=DateOffsetStrategy(offset=14, date_field="end_date"), + collapse_settings=CollapseSettings(era_pad=0), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-01" + assert str(result.iloc[0]["end_date"])[:10] == "2020-01-15" + + +def test_collapse_settings_era_merges_tied_start_dates_into_one_group(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1], + "condition_occurrence_id": [100, 101], + "condition_concept_id": [111, 111], + "condition_start_date": ["2020-01-01", "2020-01-01"], + "condition_end_date": ["2020-01-02", "2020-01-05"], + "visit_occurrence_id": [10, 10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + end_strategy=DateOffsetStrategy(offset=0, date_field="end_date"), + collapse_settings=CollapseSettings(era_pad=0), + ) + + result = build_cohort(expression, backend=conn, 
cdm_schema="main").execute() + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-01" + assert str(result.iloc[0]["end_date"])[:10] == "2020-01-05" + + +def test_apply_end_strategy_rejects_invalid_date_field_and_preserves_fallback_semantics(): + ibis_mod = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis_mod.duckdb.connect() + conn.create_table( + "events", + obj=ibis_mod.memtable( + { + "person_id": [1], + "event_id": [100], + "start_date": [date(2020, 1, 1)], + "end_date": [date(2020, 1, 5)], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis_mod.memtable( + { + "person_id": [1], + "observation_period_start_date": [date(2019, 1, 1)], + "observation_period_end_date": [date(2020, 1, 10)], + } + ), + overwrite=True, + ) + ctx = SimpleNamespace(table=lambda name: conn.table(name)) + events = conn.table("events") + + with pytest.raises(UnsupportedFeatureError, match="unsupported date_offset date field"): + apply_end_strategy( + events, + NormalizedEndStrategy(kind="date_offset", payload={"offset": 1, "date_field": "weird"}), + ctx, + ).execute() + + fallback = apply_end_strategy(events, NormalizedEndStrategy(kind="unknown", payload={}), ctx).execute() + assert str(fallback.iloc[0]["end_date"])[:10] == "2020-01-10" diff --git a/tests/execution/test_error_messages.py b/tests/execution/test_error_messages.py new file mode 100644 index 00000000..80133b45 --- /dev/null +++ b/tests/execution/test_error_messages.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import pytest + +from circe.api import build_cohort +from circe.cohortdefinition import ( + CohortExpression, + ConditionOccurrence, + CorelatedCriteria, + Criteria, + CriteriaGroup, + DemographicCriteria, + Measurement, + Occurrence, + PrimaryCriteria, +) +from circe.cohortdefinition.core import CustomEraStrategy, NumericRange +from circe.execution.errors import 
CompilationError, UnsupportedCriterionError, UnsupportedFeatureError +from circe.execution.normalize.criteria import normalize_criterion +from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem + + +def _seed_common_tables(conn, ibis): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1], + "year_of_birth": [1980], + "gender_concept_id": [8507], + "race_concept_id": [8527], + "ethnicity_concept_id": [38003564], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": [1], + "observation_period_id": [10], + "observation_period_start_date": ["2019-01-01"], + "observation_period_end_date": ["2022-12-31"], + } + ), + overwrite=True, + ) + + +def _concept_set(set_id: int, concept_id: int) -> ConceptSet: + return ConceptSet( + id=set_id, + expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=concept_id))]), + ) + + +def test_error_message_for_custom_era_end_strategy(): + expression = CohortExpression( + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence()]), + end_strategy=CustomEraStrategy(drug_codeset_id=1, gap_days=30, offset=0), + ) + + with pytest.raises(UnsupportedFeatureError, match="custom_era end strategy"): + _ = build_cohort(expression, backend=object(), cdm_schema="main") + + +def test_error_message_for_unsupported_criterion_type(): + with pytest.raises( + UnsupportedCriterionError, + match="normalization error: unsupported criterion type Criteria", + ): + _ = normalize_criterion(Criteria()) + + +def test_error_message_for_unsupported_numeric_op_during_compilation(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "measurement", + obj=ibis.memtable( + { + "person_id": [1], + "measurement_id": [100], + "measurement_concept_id": [444], + "measurement_date": ["2020-01-01"], + 
"visit_occurrence_id": [10], + "value_as_number": [5.0], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_concept_set(1, 444)], + primary_criteria=PrimaryCriteria( + criteria_list=[Measurement(codeset_id=1, value_as_number=NumericRange(op="nope", value=1))] + ), + ) + + with pytest.raises(CompilationError, match="compilation error: unsupported numeric range op"): + _ = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + +def test_error_message_for_unsupported_demographic_numeric_op(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1], + "condition_occurrence_id": [100, 101], + "condition_concept_id": [111, 222], + "condition_start_date": ["2020-01-01", "2020-01-03"], + "condition_end_date": ["2020-01-01", "2020-01-03"], + "visit_occurrence_id": [10, 10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_concept_set(1, 111), _concept_set(2, 222)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=2), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + ) + ], + demographic_criteria_list=[DemographicCriteria(age=NumericRange(op="invalid", value=18))], + ), + ) + + with pytest.raises( + UnsupportedFeatureError, + match="group evaluation error: unsupported demographic numeric range op", + ): + _ = build_cohort(expression, backend=conn, cdm_schema="main").execute() diff --git a/tests/execution/test_group_demographics.py b/tests/execution/test_group_demographics.py new file mode 100644 index 00000000..d11cc730 --- /dev/null +++ b/tests/execution/test_group_demographics.py @@ -0,0 +1,105 @@ +from __future__ import annotations + 
+import pytest + +from circe.execution.engine.group_demographics import ( + _apply_date_predicate, + _demographic_concept_ids, + demographic_match_keys, +) +from circe.execution.errors import UnsupportedFeatureError +from circe.execution.normalize.groups import NormalizedDemographicCriteria +from circe.execution.normalize.windows import NormalizedDateRange, NormalizedNumericRange + + +class _DemographicContext: + def __init__(self, conn, *, codesets: dict[int, tuple[int, ...]] | None = None): + self.conn = conn + self.codesets = codesets or {} + + def table(self, name: str): + return self.conn.table(name) + + def concept_ids_for_codeset(self, codeset_id: int) -> tuple[int, ...]: + return self.codesets.get(codeset_id, ()) + + +def _seed_demographic_tables(conn, ibis): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "year_of_birth": [1980, 1990, 1980], + "gender_concept_id": [8507, 8507, 8532], + "race_concept_id": [8527, 8516, 8527], + "ethnicity_concept_id": [38003564, 38003564, 38003563], + } + ), + overwrite=True, + ) + conn.create_table( + "index_events", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "event_id": [10, 20, 30], + "start_date": ["2020-01-05", "2020-02-05", "2020-01-10"], + "end_date": ["2020-01-20", "2020-02-20", "2020-01-15"], + } + ), + overwrite=True, + ) + + +def test_apply_date_predicate_rejects_invalid_between_and_op(): + ibis = pytest.importorskip("ibis") + + with pytest.raises(UnsupportedFeatureError, match="between' requires an extent value"): + _apply_date_predicate( + ibis.literal("2020-01-01"), + NormalizedDateRange(op="between", value="2020-01-01", extent=None), + ) + + with pytest.raises(UnsupportedFeatureError, match="unsupported demographic date range op"): + _apply_date_predicate( + ibis.literal("2020-01-01"), + NormalizedDateRange(op="invalid", value="2020-01-01", extent=None), + ) + + +def test_demographic_concept_ids_merge_codesets_without_duplicates(): + ctx = 
_DemographicContext(None, codesets={1: (8507, 8532)}) + + assert _demographic_concept_ids(explicit_ids=(8507,), codeset_id=1, ctx=ctx) == (8507, 8532) + + +def test_demographic_match_keys_applies_all_supported_filters(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_demographic_tables(conn, ibis) + ctx = _DemographicContext(conn, codesets={1: (8507,), 2: (38003564,)}) + + demographic = NormalizedDemographicCriteria( + age=NormalizedNumericRange(op="gte", value=30, extent=None), + gender_codeset_id=1, + race_concept_ids=(8527,), + ethnicity_codeset_id=2, + occurrence_start_date=NormalizedDateRange( + op="between", + value="2020-01-01", + extent="2020-01-31", + ), + occurrence_end_date=NormalizedDateRange( + op="lte", + value="2020-01-31", + extent=None, + ), + ) + + result = demographic_match_keys(conn.table("index_events"), demographic, ctx).execute() + + assert list(result.person_id) == [1] + assert list(result.event_id) == [10] diff --git a/tests/execution/test_groups.py b/tests/execution/test_groups.py new file mode 100644 index 00000000..016cec53 --- /dev/null +++ b/tests/execution/test_groups.py @@ -0,0 +1,926 @@ +from __future__ import annotations + +import pytest + +from circe.api import build_cohort +from circe.cohortdefinition import ( + CohortExpression, + ConditionOccurrence, + CorelatedCriteria, + CriteriaColumn, + CriteriaGroup, + DemographicCriteria, + DrugEra, + Occurrence, + PrimaryCriteria, + Window, + WindowBound, +) +from circe.cohortdefinition.core import DateRange, NumericRange +from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem + + +def _make_concept_set(set_id: int, concept_id: int) -> ConceptSet: + return ConceptSet( + id=set_id, + expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=concept_id))]), + ) + + +def _seed_common_tables(conn, ibis, *, persons=(1, 2, 3)): + conn.create_table( + "person", + 
obj=ibis.memtable( + { + "person_id": list(persons), + "year_of_birth": [1980 for _ in persons], + "gender_concept_id": [8507 for _ in persons], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": list(persons), + "observation_period_id": [10 + idx for idx, _ in enumerate(persons)], + "observation_period_start_date": ["2019-01-01" for _ in persons], + "observation_period_end_date": ["2022-12-31" for _ in persons], + } + ), + overwrite=True, + ) + + +def test_additional_criteria_all_filters_primary_events(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2)) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 2], + "condition_occurrence_id": [100, 101, 102], + "condition_concept_id": [111, 222, 111], + "condition_start_date": ["2020-01-01", "2020-01-02", "2020-01-01"], + "condition_end_date": ["2020-01-01", "2020-01-02", "2020-01-01"], + "visit_occurrence_id": [10, 10, 20], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111), _make_concept_set(2, 222)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=2), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + ) + ], + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1} + + +@pytest.mark.parametrize( + ("group_type", "count", "expected_persons"), + [ + ("ANY", None, {1, 2, 3}), + ("ALL", None, {3}), + ("AT_LEAST", 2, {3}), + ("AT_MOST", 1, {1, 2}), + ], +) +def test_additional_group_operators(group_type, count, expected_persons): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + 
conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2, 3)) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 2, 2, 3, 3, 3], + "condition_occurrence_id": [100, 101, 200, 201, 300, 301, 302], + "condition_concept_id": [111, 222, 111, 333, 111, 222, 333], + "condition_start_date": [ + "2020-01-01", + "2020-01-02", + "2020-01-01", + "2020-01-02", + "2020-01-01", + "2020-01-02", + "2020-01-03", + ], + "condition_end_date": [ + "2020-01-01", + "2020-01-02", + "2020-01-01", + "2020-01-02", + "2020-01-01", + "2020-01-02", + "2020-01-03", + ], + "visit_occurrence_id": [10, 10, 20, 20, 30, 30, 30], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + _make_concept_set(3, 333), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type=group_type, + count=count, + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=2), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + ), + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=3), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + ), + ], + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == expected_persons + + +def test_correlated_criteria_respects_restrict_visit_and_start_window(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2)) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 2, 2], + "condition_occurrence_id": [100, 101, 200, 201], + "condition_concept_id": [111, 222, 111, 222], + "condition_start_date": [ + "2020-01-01", + "2020-01-06", + "2020-01-01", + "2020-01-10", + ], + "condition_end_date": [ + "2020-01-01", 
+ "2020-01-06", + "2020-01-01", + "2020-01-10", + ], + "visit_occurrence_id": [10, 10, 20, 21], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111), _make_concept_set(2, 222)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=2), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + restrict_visit=True, + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=7), + use_event_end=False, + use_index_end=False, + ), + ) + ], + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + # Person 1 matches (same visit, +5 days). Person 2 fails (different visit and +9 days). + assert set(result.person_id) == {1} + + +def test_additional_demographic_criteria_groups_filter_primary_events(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2, 3)) + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "year_of_birth": [1980, 1980, 2010], + "gender_concept_id": [8507, 8507, 8507], + "race_concept_id": [8527, 8516, 8527], + "ethnicity_concept_id": [38003564, 38003564, 38003563], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "condition_occurrence_id": [100, 200, 300], + "condition_concept_id": [111, 111, 111], + "condition_start_date": ["2020-01-03", "2020-01-03", "2020-01-03"], + "condition_end_date": ["2020-01-03", "2020-01-03", "2020-01-03"], + "visit_occurrence_id": [10, 20, 30], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + 
primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type="ALL", + demographic_criteria_list=[ + DemographicCriteria( + age=NumericRange(op="gte", value=18), + gender=[Concept(conceptId=8507)], + race=[Concept(conceptId=8527)], + ethnicity=[Concept(conceptId=38003564)], + occurrence_start_date=DateRange(op="gte", value="2020-01-02"), + ) + ], + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1} + + +def test_nested_correlated_criteria_inside_group_are_applied(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2)) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 1, 2, 2], + "condition_occurrence_id": [100, 101, 102, 200, 201], + "condition_concept_id": [111, 222, 333, 111, 222], + "condition_start_date": [ + "2020-01-01", + "2020-01-10", + "2020-01-15", + "2020-01-01", + "2020-01-10", + ], + "condition_end_date": [ + "2020-01-01", + "2020-01-10", + "2020-01-15", + "2020-01-01", + "2020-01-10", + ], + "visit_occurrence_id": [10, 10, 10, 20, 20], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + _make_concept_set(3, 333), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence( + codeset_id=2, + correlated_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=3), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=10), + use_event_end=False, + use_index_end=False, + ), 
+ ) + ], + ), + ), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=20), + use_event_end=False, + use_index_end=False, + ), + ) + ], + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1} + + +def test_nested_correlated_inner_any_mode_with_multiple_children(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2, 3)) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 1, 2, 2, 3, 3], + "condition_occurrence_id": [100, 101, 102, 200, 201, 300, 301], + "condition_concept_id": [111, 222, 333, 111, 222, 111, 444], + "condition_start_date": [ + "2020-01-01", + "2020-01-10", + "2020-01-12", + "2020-01-01", + "2020-01-10", + "2020-01-01", + "2020-01-12", + ], + "condition_end_date": [ + "2020-01-01", + "2020-01-10", + "2020-01-12", + "2020-01-01", + "2020-01-10", + "2020-01-01", + "2020-01-12", + ], + "visit_occurrence_id": [10, 10, 10, 20, 20, 30, 30], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + _make_concept_set(3, 333), + _make_concept_set(4, 444), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence( + codeset_id=2, + correlated_criteria=CriteriaGroup( + type="ANY", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=3), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=5), + use_event_end=False, + use_index_end=False, + ), + ), + CorelatedCriteria( + 
criteria=ConditionOccurrence(codeset_id=4), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=5), + use_event_end=False, + use_index_end=False, + ), + ), + ], + ), + ), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=20), + use_event_end=False, + use_index_end=False, + ), + ) + ], + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1} + + +def test_nested_correlated_group_demographics_are_applied(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2)) + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1, 2], + "year_of_birth": [1980, 1980], + "gender_concept_id": [8507, 8507], + "race_concept_id": [8527, 8516], + "ethnicity_concept_id": [38003564, 38003564], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 2, 2], + "condition_occurrence_id": [100, 101, 200, 201], + "condition_concept_id": [111, 222, 111, 222], + "condition_start_date": ["2020-01-01", "2020-01-10", "2020-01-01", "2020-01-10"], + "condition_end_date": ["2020-01-01", "2020-01-10", "2020-01-01", "2020-01-10"], + "visit_occurrence_id": [10, 10, 20, 20], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111), _make_concept_set(2, 222)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence( + codeset_id=2, + correlated_criteria=CriteriaGroup( + type="ALL", + demographic_criteria_list=[ + DemographicCriteria( + 
age=NumericRange(op="gte", value=18), + race=[Concept(conceptId=8527)], + ) + ], + ), + ), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=20), + use_event_end=False, + use_index_end=False, + ), + ) + ], + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1} + + +def test_nested_correlated_distinct_count_is_applied(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2)) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 1, 1, 2, 2, 2, 2], + "condition_occurrence_id": [100, 101, 102, 103, 200, 201, 202, 203], + "condition_concept_id": [111, 222, 333, 333, 111, 222, 333, 333], + "condition_start_date": [ + "2020-01-01", + "2020-01-10", + "2020-01-11", + "2020-01-12", + "2020-01-01", + "2020-01-10", + "2020-01-11", + "2020-01-11", + ], + "condition_end_date": [ + "2020-01-01", + "2020-01-10", + "2020-01-11", + "2020-01-12", + "2020-01-01", + "2020-01-10", + "2020-01-11", + "2020-01-11", + ], + "visit_occurrence_id": [10, 10, 10, 10, 20, 20, 20, 20], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + _make_concept_set(3, 333), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence( + codeset_id=2, + correlated_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=3), + occurrence=Occurrence( + type=Occurrence._AT_LEAST, + count=2, + is_distinct=True, + count_column=CriteriaColumn.START_DATE, + ), + start_window=Window( + 
start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=5), + use_event_end=False, + use_index_end=False, + ), + ) + ], + ), + ), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=20), + use_event_end=False, + use_index_end=False, + ), + ) + ], + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1} + + +def test_nested_correlated_end_window_respects_index_end(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2)) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 1, 2, 2, 2], + "condition_occurrence_id": [100, 101, 102, 200, 201, 202], + "condition_concept_id": [111, 222, 333, 111, 222, 333], + "condition_start_date": [ + "2020-01-01", + "2020-01-10", + "2020-01-23", + "2020-01-01", + "2020-01-10", + "2020-01-27", + ], + "condition_end_date": [ + "2020-01-01", + "2020-01-20", + "2020-01-25", + "2020-01-01", + "2020-01-20", + "2020-01-29", + ], + "visit_occurrence_id": [10, 10, 10, 20, 20, 20], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + _make_concept_set(3, 333), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence( + codeset_id=2, + correlated_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=3), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + end_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=5), + use_event_end=False, + use_index_end=True, + ), + ) + ], + ), 
+ ), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=20), + use_event_end=False, + use_index_end=False, + ), + ) + ], + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1} + + +def test_nested_correlated_ignore_observation_period_changes_matching(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1,)) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": [1], + "observation_period_id": [10], + "observation_period_start_date": ["2019-01-01"], + "observation_period_end_date": ["2020-01-15"], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 1], + "condition_occurrence_id": [100, 101, 102], + "condition_concept_id": [111, 222, 333], + "condition_start_date": ["2020-01-01", "2020-01-10", "2020-01-20"], + "condition_end_date": ["2020-01-01", "2020-01-10", "2020-01-20"], + "visit_occurrence_id": [10, 10, 10], + } + ), + overwrite=True, + ) + + def _expression(ignore_observation_period: bool) -> CohortExpression: + return CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + _make_concept_set(3, 333), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence( + codeset_id=2, + correlated_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=3), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=15), + use_event_end=False, + 
use_index_end=False, + ), + ignore_observation_period=ignore_observation_period, + ) + ], + ), + ), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=20), + use_event_end=False, + use_index_end=False, + ), + ) + ], + ), + ) + + without_ignore = build_cohort(_expression(False), backend=conn, cdm_schema="main").execute() + with_ignore = build_cohort(_expression(True), backend=conn, cdm_schema="main").execute() + + assert len(without_ignore) == 0 + assert set(with_ignore.person_id) == {1} + + +def test_nested_correlated_multi_level_nesting_is_applied(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2)) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 1, 1, 2, 2, 2], + "condition_occurrence_id": [100, 101, 102, 103, 200, 201, 202], + "condition_concept_id": [111, 222, 333, 444, 111, 222, 333], + "condition_start_date": [ + "2020-01-01", + "2020-01-10", + "2020-01-12", + "2020-01-14", + "2020-01-01", + "2020-01-10", + "2020-01-12", + ], + "condition_end_date": [ + "2020-01-01", + "2020-01-10", + "2020-01-12", + "2020-01-14", + "2020-01-01", + "2020-01-10", + "2020-01-12", + ], + "visit_occurrence_id": [10, 10, 10, 10, 20, 20, 20], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + _make_concept_set(3, 333), + _make_concept_set(4, 444), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence( + codeset_id=2, + correlated_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence( + codeset_id=3, + 
correlated_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=4), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=5), + use_event_end=False, + use_index_end=False, + ), + ) + ], + ), + ), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=5), + use_event_end=False, + use_index_end=False, + ), + ) + ], + ), + ), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=20), + use_event_end=False, + use_index_end=False, + ), + ) + ], + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1} + + +def test_primary_drug_era_correlated_era_length_is_applied(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2)) + conn.create_table( + "drug_era", + obj=ibis.memtable( + { + "person_id": [1, 1, 2, 2], + "drug_era_id": [1300, 1301, 2300, 2301], + "drug_concept_id": [111, 222, 111, 222], + "drug_era_start_date": ["2020-01-01", "2020-03-20", "2020-01-01", "2020-03-20"], + "drug_era_end_date": ["2020-03-15", "2020-04-25", "2020-03-15", "2020-03-20"], + "drug_exposure_count": [3, 1, 3, 1], + "gap_days": [10, 0, 10, 0], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111), _make_concept_set(2, 222)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + DrugEra( + codeset_id=1, + era_length=NumericRange(op="gte", value=30), + correlated_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=DrugEra( + codeset_id=2, + era_length=NumericRange(op="gte", 
value=30), + ), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + start_window=Window( + start=WindowBound(coeff=1, days=0), + end=WindowBound(coeff=1, days=60), + use_event_end=False, + use_index_end=True, + ), + ) + ], + ), + ) + ] + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert set(result.person_id) == {1} diff --git a/tests/execution/test_ibis_compat.py b/tests/execution/test_ibis_compat.py new file mode 100644 index 00000000..8f71ee68 --- /dev/null +++ b/tests/execution/test_ibis_compat.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import pytest + +from circe.execution.ibis_compat import literal_column_relation, literal_rows_relation + + +def test_literal_column_relation_round_trips_values(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + relation = literal_column_relation( + [3, 1, 2], + column_name="value", + dtype="int64", + backend=conn, + ) + result = relation.execute() + + assert sorted(result["value"].tolist()) == [1, 2, 3] + + +def test_literal_column_relation_empty_preserves_schema(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + relation = literal_column_relation( + [], + column_name="value", + dtype="int64", + backend=conn, + ) + result = relation.execute() + + assert list(result.columns) == ["value"] + assert len(result) == 0 + + +def test_literal_rows_relation_round_trips_typed_rows(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + relation = literal_rows_relation( + [ + {"cohort_id": 1, "cohort_name": "A", "is_subset": False}, + {"cohort_id": 2, "cohort_name": None, "is_subset": True}, + ], + schema={ + "cohort_id": "int64", + "cohort_name": "string", + "is_subset": "boolean", + }, + backend=conn, + ) + result = relation.execute().sort_values("cohort_id").reset_index(drop=True) + + assert 
list(result["cohort_id"]) == [1, 2] + assert result.loc[0, "cohort_name"] == "A" + assert result.loc[1, "cohort_name"] is None or result["cohort_name"].isna().iloc[1] + assert list(result["is_subset"]) == [False, True] + + +def test_literal_rows_relation_empty_relation(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + relation = literal_rows_relation( + [], + schema={"cohort_id": "int64", "status": "string"}, + backend=conn, + ) + result = relation.execute() + + assert list(result.columns) == ["cohort_id", "status"] + assert len(result) == 0 diff --git a/tests/execution/test_inclusion.py b/tests/execution/test_inclusion.py new file mode 100644 index 00000000..46ce20ee --- /dev/null +++ b/tests/execution/test_inclusion.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import pytest + +from circe.api import build_cohort +from circe.cohortdefinition import ( + CohortExpression, + ConditionOccurrence, + CorelatedCriteria, + CriteriaGroup, + InclusionRule, + Occurrence, + PrimaryCriteria, +) +from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem + + +def _make_concept_set(set_id: int, concept_id: int) -> ConceptSet: + return ConceptSet( + id=set_id, + expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=concept_id))]), + ) + + +def _seed_common_tables(conn, ibis, *, persons=(1, 2, 3)): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": list(persons), + "year_of_birth": [1980 for _ in persons], + "gender_concept_id": [8507 for _ in persons], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": list(persons), + "observation_period_id": [10 + idx for idx, _ in enumerate(persons)], + "observation_period_start_date": ["2019-01-01" for _ in persons], + "observation_period_end_date": ["2022-12-31" for _ in persons], + } + ), + overwrite=True, + ) + + +def 
test_inclusion_rules_require_all_rules_to_match(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2, 3)) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 2, 2, 3, 3, 3], + "condition_occurrence_id": [100, 101, 200, 201, 300, 301, 302], + "condition_concept_id": [111, 222, 111, 333, 111, 222, 333], + "condition_start_date": [ + "2020-01-01", + "2020-01-02", + "2020-01-01", + "2020-01-02", + "2020-01-01", + "2020-01-02", + "2020-01-03", + ], + "condition_end_date": [ + "2020-01-01", + "2020-01-02", + "2020-01-01", + "2020-01-02", + "2020-01-01", + "2020-01-02", + "2020-01-03", + ], + "visit_occurrence_id": [10, 10, 20, 20, 30, 30, 30], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + _make_concept_set(3, 333), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + inclusion_rules=[ + InclusionRule( + name="rule-1", + expression=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=2), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + ) + ], + ), + ), + InclusionRule( + name="rule-2", + expression=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=3), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + ) + ], + ), + ), + ], + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {3} + + +def test_inclusion_rule_without_expression_is_noop(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2)) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 2], + 
"condition_occurrence_id": [100, 200], + "condition_concept_id": [111, 111], + "condition_start_date": ["2020-01-01", "2020-01-01"], + "condition_end_date": ["2020-01-01", "2020-01-01"], + "visit_occurrence_id": [10, 20], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + inclusion_rules=[InclusionRule(name="empty", expression=None)], + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1, 2} diff --git a/tests/execution/test_lower_contracts.py b/tests/execution/test_lower_contracts.py new file mode 100644 index 00000000..576de634 --- /dev/null +++ b/tests/execution/test_lower_contracts.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import pytest + +from circe.cohortdefinition import ConditionOccurrence +from circe.execution.lower.criteria import lower_criterion +from circe.execution.normalize.criteria import normalize_criterion +from circe.execution.plan.events import ( + FilterByCodeset, + FilterByPersonEthnicity, + FilterByPersonRace, + StandardizeEventShape, +) +from circe.vocabulary import Concept +from tests.execution._domain_cases import domain_criteria_cases + + +@pytest.mark.parametrize(("source_table", "factory", "concept_id"), domain_criteria_cases()) +def test_lower_contract_emits_source_and_standardization( + source_table, + factory, + concept_id, +): + criteria = factory() + normalized = normalize_criterion(criteria) + plan = lower_criterion(normalized, criterion_index=17) + + assert plan.source.table_name == source_table + assert plan.criterion_type == criteria.__class__.__name__ + assert any(isinstance(step, StandardizeEventShape) for step in plan.steps) + + has_codeset_step = any(isinstance(step, FilterByCodeset) for step in plan.steps) + assert has_codeset_step is (concept_id is not None) + + +def 
test_lower_contract_emits_person_race_and_ethnicity_steps_when_present(): + criteria = ConditionOccurrence(codeset_id=1) + criteria.__dict__["race"] = [Concept(conceptId=8527)] + criteria.__dict__["ethnicity"] = [Concept(conceptId=38003564)] + + plan = lower_criterion(normalize_criterion(criteria), criterion_index=18) + assert any(isinstance(step, FilterByPersonRace) for step in plan.steps) + assert any(isinstance(step, FilterByPersonEthnicity) for step in plan.steps) diff --git a/tests/execution/test_lowering.py b/tests/execution/test_lowering.py new file mode 100644 index 00000000..8e29a57a --- /dev/null +++ b/tests/execution/test_lowering.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import pytest + +from circe.cohortdefinition import ( + ConditionEra, + ConditionOccurrence, + Death, + DeviceExposure, + DoseEra, + DrugEra, + LocationRegion, + Measurement, + Observation, + ObservationPeriod, + PayerPlanPeriod, + ProcedureOccurrence, + Specimen, + VisitDetail, + VisitOccurrence, +) +from circe.execution.lower.criteria import lower_criterion +from circe.execution.normalize.criteria import normalize_criterion +from circe.execution.plan.events import ( + FilterByCareSite, + FilterByCareSiteLocationRegion, + FilterByCodeset, + FilterByConceptSet, + FilterByDateRange, + FilterByNumericRange, + FilterByPersonEthnicity, + FilterByPersonRace, + FilterByProviderSpecialty, + FilterByText, + FilterByVisit, + KeepFirstPerPerson, + StandardizeEventShape, +) +from circe.execution.plan.schema import DURATION, START_DATE +from circe.vocabulary import Concept + + +def test_lowering_condition_occurrence_emits_expected_steps(): + normalized = normalize_criterion( + ConditionOccurrence( + codeset_id=1, + first=True, + ) + ) + + plan = lower_criterion(normalized, criterion_index=3) + + assert plan.source.table_name == "condition_occurrence" + assert plan.source.concept_column == "condition_concept_id" + assert any(isinstance(step, FilterByCodeset) for step in plan.steps) 
+ assert any(isinstance(step, KeepFirstPerPerson) for step in plan.steps) + standardize = [step for step in plan.steps if isinstance(step, StandardizeEventShape)] + assert len(standardize) == 1 + assert standardize[0].criterion_index == 3 + + +def test_lowering_measurement_emits_domain_specific_filter_steps(): + normalized = normalize_criterion( + Measurement( + codeset_id=1, + value_as_number={"op": "gte", "value": 10}, + unit=[{"conceptId": 9002}], + value_as_concept=[{"conceptId": 7002}], + ) + ) + + plan = lower_criterion(normalized, criterion_index=5) + + assert any(isinstance(step, FilterByNumericRange) for step in plan.steps) + # unit + value_as_concept should emit concept filters in addition to codeset filter + concept_steps = [step for step in plan.steps if isinstance(step, FilterByConceptSet)] + assert len(concept_steps) >= 2 + + +def test_lowering_observation_procedure_visit_detail_emit_domain_filters(): + observation_plan = lower_criterion( + normalize_criterion( + Observation( + codeset_id=1, + observation_type=[Concept(conceptId=1001)], + value_as_number={"op": "gte", "value": 2}, + value_as_string={"op": "contains", "text": "abc"}, + ) + ), + criterion_index=6, + ) + assert any(isinstance(step, FilterByConceptSet) for step in observation_plan.steps) + assert any(isinstance(step, FilterByNumericRange) for step in observation_plan.steps) + assert any(isinstance(step, FilterByText) for step in observation_plan.steps) + + procedure_plan = lower_criterion( + normalize_criterion( + ProcedureOccurrence( + codeset_id=1, + procedure_type=[Concept(conceptId=2001)], + quantity={"op": "gte", "value": 1}, + ) + ), + criterion_index=7, + ) + assert any(isinstance(step, FilterByConceptSet) for step in procedure_plan.steps) + assert any(isinstance(step, FilterByNumericRange) for step in procedure_plan.steps) + + visit_detail_plan = lower_criterion( + normalize_criterion( + VisitDetail( + codeset_id=1, + visit_detail_type=[Concept(conceptId=3001)], + 
discharge_to=[Concept(conceptId=3002)], + ) + ), + criterion_index=8, + ) + concept_steps = [s for s in visit_detail_plan.steps if isinstance(s, FilterByConceptSet)] + assert len(concept_steps) >= 2 + + +@pytest.mark.parametrize( + ("criteria", "table_name", "concept_column", "expects_codeset_step"), + [ + (Measurement(codeset_id=1), "measurement", "measurement_concept_id", True), + ( + ProcedureOccurrence(codeset_id=1), + "procedure_occurrence", + "procedure_concept_id", + True, + ), + (Observation(codeset_id=1), "observation", "observation_concept_id", True), + (VisitDetail(codeset_id=1), "visit_detail", "visit_detail_concept_id", True), + (DeviceExposure(codeset_id=1), "device_exposure", "device_concept_id", True), + (Specimen(codeset_id=1), "specimen", "specimen_concept_id", True), + (Death(codeset_id=1), "death", "cause_concept_id", True), + (ObservationPeriod(), "observation_period", "period_type_concept_id", False), + (PayerPlanPeriod(), "payer_plan_period", "payer_concept_id", False), + (ConditionEra(codeset_id=1), "condition_era", "condition_concept_id", True), + (DrugEra(codeset_id=1), "drug_era", "drug_concept_id", True), + (DoseEra(codeset_id=1), "dose_era", "drug_concept_id", True), + (LocationRegion(codeset_id=1), "location_history", "region_concept_id", True), + ], +) +def test_lowering_new_domains_emit_standardized_plans( + criteria, + table_name, + concept_column, + expects_codeset_step, +): + normalized = normalize_criterion(criteria) + plan = lower_criterion(normalized, criterion_index=4) + + assert plan.source.table_name == table_name + assert plan.source.concept_column == concept_column + assert any(isinstance(step, FilterByCodeset) for step in plan.steps) is expects_codeset_step + standardize = [step for step in plan.steps if isinstance(step, StandardizeEventShape)] + assert len(standardize) == 1 + + +def test_lowering_emits_race_and_ethnicity_person_filters(): + criteria = ConditionOccurrence(codeset_id=1) + criteria.__dict__["race"] = 
[Concept(conceptId=8527)] + criteria.__dict__["ethnicity"] = [Concept(conceptId=38003564)] + + plan = lower_criterion(normalize_criterion(criteria), criterion_index=9) + assert any(isinstance(step, FilterByPersonRace) for step in plan.steps) + assert any(isinstance(step, FilterByPersonEthnicity) for step in plan.steps) + + +def test_lowering_condition_occurrence_emits_related_filters_and_post_standardized_dates(): + normalized = normalize_criterion( + ConditionOccurrence( + codeset_id=1, + occurrence_start_date={"op": "gte", "value": "2020-01-02"}, + condition_type=[{"conceptId": 1001}], + provider_specialty=[{"conceptId": 2001}], + visit_type=[{"conceptId": 3001}], + date_adjustment={ + "startOffset": 1, + "endOffset": 2, + }, + ) + ) + + plan = lower_criterion(normalized, criterion_index=10) + + assert any(isinstance(step, FilterByConceptSet) for step in plan.steps) + assert any(isinstance(step, FilterByProviderSpecialty) for step in plan.steps) + assert any(isinstance(step, FilterByVisit) for step in plan.steps) + date_steps = [step for step in plan.steps if isinstance(step, FilterByDateRange)] + assert len(date_steps) == 1 + assert date_steps[0].column == START_DATE + standardize = next(step for step in plan.steps if isinstance(step, StandardizeEventShape)) + assert standardize.start_offset_days == 1 + assert standardize.end_offset_days == 2 + + +def test_lowering_visit_occurrence_emits_care_site_and_duration_filters(): + normalized = normalize_criterion( + VisitOccurrence( + codeset_id=1, + visit_type=[{"conceptId": 1001}], + visit_length={"op": "gte", "value": 2}, + provider_specialty=[{"conceptId": 2001}], + place_of_service=[{"conceptId": 3001}], + place_of_service_location=4, + ) + ) + + plan = lower_criterion(normalized, criterion_index=11) + + assert any(isinstance(step, FilterByProviderSpecialty) for step in plan.steps) + assert any(isinstance(step, FilterByCareSite) for step in plan.steps) + assert any(isinstance(step, FilterByCareSiteLocationRegion) 
for step in plan.steps) + duration_steps = [ + step for step in plan.steps if isinstance(step, FilterByNumericRange) and step.column == DURATION + ] + assert len(duration_steps) == 1 diff --git a/tests/execution/test_normalize.py b/tests/execution/test_normalize.py new file mode 100644 index 00000000..bd69e13b --- /dev/null +++ b/tests/execution/test_normalize.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from circe.cohortdefinition import ( + CohortExpression, + ConditionEra, + ConditionOccurrence, + CorelatedCriteria, + CriteriaGroup, + Death, + DeviceExposure, + DoseEra, + DrugEra, + InclusionRule, + LocationRegion, + Measurement, + Observation, + ObservationPeriod, + Occurrence, + PayerPlanPeriod, + PrimaryCriteria, + ProcedureOccurrence, + Specimen, + VisitDetail, +) +from circe.cohortdefinition.core import ConceptSetSelection, NumericRange +from circe.execution.normalize.cohort import normalize_cohort +from circe.execution.normalize.criteria import normalize_criterion +from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem + + +def _concept_set(set_id: int, include: int, exclude: int) -> ConceptSet: + return ConceptSet( + id=set_id, + expression=ConceptSetExpression( + items=[ + ConceptSetItem(concept=Concept(conceptId=include), isExcluded=False), + ConceptSetItem(concept=Concept(conceptId=exclude), isExcluded=True), + ] + ), + ) + + +def test_normalize_cohort_extracts_codesets_and_keeps_expression_immutable(): + expression = CohortExpression( + title="Normalize Test", + concept_sets=[_concept_set(1, include=111, exclude=999)], + primary_criteria=PrimaryCriteria( + criteria_list=[ + ConditionOccurrence( + codeset_id=1, + first=True, + age=NumericRange(op="gte", value=18), + ) + ] + ), + ) + before = expression.model_dump_json(by_alias=True, exclude_none=False) + + normalized = normalize_cohort(expression) + + after = expression.model_dump_json(by_alias=True, exclude_none=False) + assert before == after + assert 
normalized.title == "Normalize Test" + assert 1 in normalized.concept_sets + assert tuple(item.concept_id for item in normalized.concept_sets[1].items) == ( + 111, + 999, + ) + assert len(normalized.primary.criteria) == 1 + criterion = normalized.primary.criteria[0] + assert criterion.criterion_type == "ConditionOccurrence" + assert criterion.codeset_id == 1 + assert criterion.first is True + assert criterion.person_filters.age is not None + + +def test_normalize_cohort_additional_criteria_group(): + expression = CohortExpression( + concept_sets=[_concept_set(1, include=111, exclude=999)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type="ANY", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=1), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + ) + ], + ), + ) + + normalized = normalize_cohort(expression) + assert normalized.additional_criteria is not None + assert normalized.additional_criteria.mode == "ANY" + assert len(normalized.additional_criteria.criteria) == 1 + + +def test_normalize_cohort_inclusion_rules(): + expression = CohortExpression( + concept_sets=[_concept_set(1, include=111, exclude=999)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + inclusion_rules=[ + InclusionRule( + name="rule-1", + expression=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=1), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + ) + ], + ), + ) + ], + ) + + normalized = normalize_cohort(expression) + assert len(normalized.inclusion_rules) == 1 + assert normalized.inclusion_rules[0].name == "rule-1" + assert normalized.inclusion_rules[0].expression is not None + + +def test_normalize_new_domains(): + cases = [ + (Measurement(codeset_id=1), "measurement"), + (ProcedureOccurrence(codeset_id=1), "procedure_occurrence"), + 
(Observation(codeset_id=1), "observation"), + (VisitDetail(codeset_id=1), "visit_detail"), + (DeviceExposure(codeset_id=1), "device_exposure"), + (Specimen(codeset_id=1), "specimen"), + (Death(codeset_id=1), "death"), + (ObservationPeriod(), "observation_period"), + (PayerPlanPeriod(), "payer_plan_period"), + (ConditionEra(codeset_id=1), "condition_era"), + (DrugEra(codeset_id=1), "drug_era"), + (DoseEra(codeset_id=1), "dose_era"), + (LocationRegion(codeset_id=1), "location_history"), + ] + for criteria, expected_table in cases: + normalized = normalize_criterion(criteria) + assert normalized.source_table == expected_table + + +def test_normalize_cohort_preserves_concept_set_item_expansion_flags(): + expression = CohortExpression( + concept_sets=[ + ConceptSet( + id=1, + expression=ConceptSetExpression( + items=[ + ConceptSetItem( + concept=Concept(conceptId=111), + includeDescendants=True, + ) + ] + ), + ) + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + ) + + normalized = normalize_cohort(expression) + assert normalized.concept_sets[1].items[0].include_descendants is True + + +def test_normalize_cohort_preserves_expression_level_concept_set_flags(): + expression = CohortExpression( + concept_sets=[ + ConceptSet( + id=1, + expression=ConceptSetExpression( + concept=Concept(conceptId=111), + includeMapped=True, + ), + ) + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + ) + + normalized = normalize_cohort(expression) + normalized_item = normalized.concept_sets[1].items[0] + assert normalized_item.concept_id == 111 + assert normalized_item.include_mapped is True + assert normalized_item.is_excluded is False + + +def test_normalize_criterion_preserves_criterion_local_correlated_criteria(): + criteria = ConditionOccurrence( + codeset_id=1, + correlated_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=1), + 
occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + ) + ], + ), + ) + + normalized = normalize_criterion(criteria) + assert normalized.correlated_criteria is not None + assert normalized.correlated_criteria.mode == "ALL" + assert len(normalized.correlated_criteria.criteria) == 1 + + +def test_normalize_criterion_includes_race_and_ethnicity_person_filters(): + criteria = ConditionOccurrence(codeset_id=1) + criteria.__dict__["race"] = [Concept(conceptId=8527)] + criteria.__dict__["race_cs"] = ConceptSetSelection(codeset_id=2, is_exclusion=False) + criteria.__dict__["ethnicity"] = [Concept(conceptId=38003564)] + criteria.__dict__["ethnicity_cs"] = ConceptSetSelection( + codeset_id=3, + is_exclusion=False, + ) + + normalized = normalize_criterion(criteria) + assert normalized.person_filters.race_concept_ids == (8527,) + assert normalized.person_filters.race_codeset_id == 2 + assert normalized.person_filters.ethnicity_concept_ids == (38003564,) + assert normalized.person_filters.ethnicity_codeset_id == 3 diff --git a/tests/execution/test_normalize_contracts.py b/tests/execution/test_normalize_contracts.py new file mode 100644 index 00000000..569c2a6e --- /dev/null +++ b/tests/execution/test_normalize_contracts.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import pytest + +from circe.cohortdefinition import CohortExpression, PrimaryCriteria +from circe.execution.normalize.cohort import normalize_cohort +from circe.execution.normalize.criteria import normalize_criterion +from circe.execution.normalize.groups import NormalizedCriteriaGroup +from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem +from tests.execution._domain_cases import domain_criteria_cases + + +@pytest.mark.parametrize(("source_table", "factory", "_"), domain_criteria_cases()) +def test_normalize_criterion_contract(source_table, factory, _): + criteria = factory() + normalized = normalize_criterion(criteria) + + assert normalized.criterion_type == 
criteria.__class__.__name__ + assert normalized.source_table == source_table + assert normalized.domain + assert normalized.event_id_column + assert normalized.start_date_column + assert normalized.end_date_column + + +@pytest.mark.parametrize(("source_table", "factory", "concept_id"), domain_criteria_cases()) +def test_normalize_cohort_does_not_mutate_public_expression( + source_table, + factory, + concept_id, +): + del source_table + + criteria = factory() + concept_sets = [] + if concept_id is not None: + concept_sets = [ + ConceptSet( + id=1, + expression=ConceptSetExpression( + items=[ConceptSetItem(concept=Concept(conceptId=concept_id))] + ), + ) + ] + + expression = CohortExpression( + concept_sets=concept_sets, + primary_criteria=PrimaryCriteria(criteria_list=[criteria]), + ) + before = expression.model_dump_json(by_alias=True, exclude_none=False) + + normalized = normalize_cohort(expression) + after = expression.model_dump_json(by_alias=True, exclude_none=False) + + assert before == after + assert len(normalized.primary.criteria) == 1 + assert isinstance(normalized.additional_criteria, (type(None), NormalizedCriteriaGroup)) diff --git a/tests/execution/test_operations.py b/tests/execution/test_operations.py new file mode 100644 index 00000000..ce42448a --- /dev/null +++ b/tests/execution/test_operations.py @@ -0,0 +1,328 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from circe.execution.errors import ExecutionError +from circe.execution.ibis.operations import ( + _catalog_db_tuple, + _run_transaction_control, + cohort_rows_exist, + create_table, + delete_cohort_rows, + exclude_cohort_rows, + insert_relation, + read_table, + replace_cohort_rows_transactionally, + supports_transactional_replace, + table_exists, +) + + +class _Backend: + name = "duckdb" + compiler = SimpleNamespace(quoted=False) + + def __init__(self, *, fail_insert: bool = False): + self.fail_insert = fail_insert + self.events: 
list[tuple[str, object]] = [] + + def raw_sql(self, query): + sql = query.sql("duckdb") if hasattr(query, "sql") else query + self.events.append(("sql", sql)) + + def insert(self, name, obj, *, database=None, overwrite=False): + self.events.append(("insert", name, database, overwrite)) + if self.fail_insert: + raise RuntimeError("boom") + + +class _SchemaFallbackBackend: + def __init__(self): + self.calls: list[tuple[str, object, object]] = [] + + def table(self, name, database=None): + self.calls.append(("table", name, database)) + if database is not None: + raise TypeError("database kwarg not supported") + return (name, None) + + def create_table(self, name, *, obj=None, database=None, overwrite=False, temp=False): + self.calls.append(("create_table", name, database)) + if database is not None: + raise TypeError("database kwarg not supported") + return None + + def insert(self, name, obj, *, database=None, overwrite=False): + self.calls.append(("insert", name, database)) + if database is not None: + raise TypeError("database kwarg not supported") + return None + + +class _ListTablesBackend: + def __init__(self, tables: list[str], *, reject_database: bool = False): + self.tables = tables + self.reject_database = reject_database + self.calls: list[str | None] = [] + + def list_tables(self, database=None): + self.calls.append(database) + if self.reject_database and database is not None: + raise TypeError("database kwarg not supported") + return self.tables + + +class _TableBackend: + def __init__(self, relation=None, *, fail: bool = False): + self.relation = relation + self.fail = fail + + def table(self, name, database=None): + if self.fail: + raise RuntimeError("boom") + return self.relation + + +class _CohortColumn: + def cast(self, _dtype): + return self + + def __eq__(self, other): + return ("eq", other) + + def __ne__(self, other): + return ("ne", other) + + +class _CohortRelation: + cohort_definition_id = _CohortColumn() + + def __init__(self, rows, *, 
fail_filter: bool = False): + self.rows = rows + self.fail_filter = fail_filter + + def filter(self, _predicate): + if self.fail_filter: + raise RuntimeError("boom") + return self + + def limit(self, _count): + return self + + def execute(self): + return self.rows + + +class _RawSqlBackend: + compiler = SimpleNamespace(quoted=False) + + def __init__(self, *, fail: bool = False): + self.fail = fail + self.calls: list[object] = [] + + def raw_sql(self, statement): + self.calls.append(statement) + if self.fail: + raise RuntimeError("boom") + + +class _CatalogBackend: + def _to_sqlglot_table(self, schema): + return f"table:{schema}" + + def _to_catalog_db_tuple(self, table): + assert table.startswith("table:") + return ("catalog", "database") + + +class _BrokenCatalogBackend: + def _to_sqlglot_table(self, _schema): + raise RuntimeError("boom") + + +def test_replace_cohort_rows_transactionally_commits_on_success(): + backend = _Backend() + + replace_cohort_rows_transactionally( + object(), + backend=backend, + cohort_table="cohort_out", + results_schema="main", + cohort_id=5, + ) + + assert backend.events[0] == ("sql", "BEGIN") + assert backend.events[1][0] == "sql" + assert "DELETE FROM main.cohort_out WHERE cohort_definition_id = 5" in backend.events[1][1] + assert backend.events[2] == ("insert", "cohort_out", "main", False) + assert backend.events[3] == ("sql", "COMMIT") + + +def test_replace_cohort_rows_transactionally_rolls_back_on_insert_failure(): + backend = _Backend(fail_insert=True) + + with pytest.raises(ExecutionError, match="failed inserting relation into table 'cohort_out'"): + replace_cohort_rows_transactionally( + object(), + backend=backend, + cohort_table="cohort_out", + results_schema="main", + cohort_id=5, + ) + + assert backend.events[0] == ("sql", "BEGIN") + assert backend.events[1][0] == "sql" + assert backend.events[2] == ("insert", "cohort_out", "main", False) + assert backend.events[3] == ("sql", "ROLLBACK") + + +def 
test_read_table_falls_back_when_backend_rejects_database_kwarg(): + backend = _SchemaFallbackBackend() + + result = read_table( + backend, + table_name="cohort_out", + schema="main", + ) + + assert result == ("cohort_out", None) + assert backend.calls == [ + ("table", "cohort_out", "main"), + ("table", "cohort_out", None), + ] + + +def test_create_table_falls_back_when_backend_rejects_database_kwarg(): + backend = _SchemaFallbackBackend() + + create_table( + backend, + table_name="cohort_out", + schema="main", + obj=object(), + overwrite=True, + ) + + assert backend.calls == [ + ("create_table", "cohort_out", "main"), + ("create_table", "cohort_out", None), + ] + + +def test_insert_relation_falls_back_when_backend_rejects_database_kwarg(): + backend = _SchemaFallbackBackend() + + insert_relation( + object(), + backend=backend, + target_table="cohort_out", + target_schema="main", + ) + + assert backend.calls == [ + ("insert", "cohort_out", "main"), + ("insert", "cohort_out", None), + ] + + +def test_table_exists_uses_list_tables_with_database_fallback(): + backend = _ListTablesBackend(["cohort_out"], reject_database=True) + + assert table_exists(backend, table_name="cohort_out", schema="main") is True + assert backend.calls == ["main", None] + + +def test_table_exists_falls_back_to_read_table_when_list_tables_is_unavailable(): + assert table_exists(_TableBackend(object()), table_name="cohort_out", schema="main") is True + assert table_exists(_TableBackend(fail=True), table_name="cohort_out", schema="main") is False + + +def test_cohort_rows_exist_returns_true_and_false_from_relation(): + assert ( + cohort_rows_exist( + _TableBackend(_CohortRelation([{"cohort_definition_id": 5}])), + cohort_table="cohort_out", + results_schema="main", + cohort_id=5, + ) + is True + ) + assert ( + cohort_rows_exist( + _TableBackend(_CohortRelation([])), + cohort_table="cohort_out", + results_schema="main", + cohort_id=5, + ) + is False + ) + + +def 
test_cohort_rows_exist_wraps_relation_errors(): + with pytest.raises(ExecutionError, match="failed checking existing rows for cohort_id=5"): + cohort_rows_exist( + _TableBackend(_CohortRelation([], fail_filter=True)), + cohort_table="cohort_out", + results_schema="main", + cohort_id=5, + ) + + +def test_delete_cohort_rows_requires_raw_sql_support(): + with pytest.raises(ExecutionError, match="does not support raw_sql for cohort-table deletes"): + delete_cohort_rows( + object(), + cohort_table="cohort_out", + results_schema="main", + cohort_id=5, + ) + + +def test_delete_cohort_rows_wraps_backend_failures(): + with pytest.raises(ExecutionError, match="failed deleting existing cohort rows"): + delete_cohort_rows( + _RawSqlBackend(fail=True), + cohort_table="cohort_out", + results_schema="main", + cohort_id=5, + ) + + +def test_supports_transactional_replace_only_for_supported_backends(): + assert supports_transactional_replace(SimpleNamespace(name="duckdb")) is True + assert supports_transactional_replace(SimpleNamespace(name="postgres")) is True + assert supports_transactional_replace(SimpleNamespace(name="sqlite")) is False + + +def test_replace_cohort_rows_transactionally_rejects_unsupported_backends(): + with pytest.raises(ExecutionError, match="does not support transactional cohort-table replace"): + replace_cohort_rows_transactionally( + object(), + backend=SimpleNamespace(name="sqlite"), + cohort_table="cohort_out", + results_schema="main", + cohort_id=5, + ) + + +def test_exclude_cohort_rows_wraps_filter_errors(): + with pytest.raises(ExecutionError, match="failed removing existing rows for cohort_id=5"): + exclude_cohort_rows(_CohortRelation([], fail_filter=True), cohort_id=5) + + +def test_run_transaction_control_requires_raw_sql_support(): + with pytest.raises(ExecutionError, match="does not support raw_sql for transactional cohort writes"): + _run_transaction_control(object(), "BEGIN") + + +def test_run_transaction_control_wraps_backend_errors(): + with 
pytest.raises(ExecutionError, match="failed executing transaction statement 'BEGIN'"): + _run_transaction_control(_RawSqlBackend(fail=True), "BEGIN") + + +def test_catalog_db_tuple_uses_backend_helpers_and_falls_back_cleanly(): + assert _catalog_db_tuple(_CatalogBackend(), "results") == ("catalog", "database") + assert _catalog_db_tuple(_BrokenCatalogBackend(), "results") == (None, "results") + assert _catalog_db_tuple(object(), None) == (None, None) diff --git a/tests/execution/test_parity_regressions.py b/tests/execution/test_parity_regressions.py new file mode 100644 index 00000000..4dec4d5f --- /dev/null +++ b/tests/execution/test_parity_regressions.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import pytest + +from circe.api import build_cohort +from circe.cohortdefinition import ( + CohortExpression, + ConditionOccurrence, + CorelatedCriteria, + CriteriaGroup, + DemographicCriteria, + Occurrence, + PrimaryCriteria, +) +from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem + + +def _seed_common_tables(conn, ibis, *, persons): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": list(persons), + "year_of_birth": [1980 for _ in persons], + "gender_concept_id": [8507 for _ in persons], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": list(persons), + "observation_period_id": [10 + idx for idx, _ in enumerate(persons)], + "observation_period_start_date": ["2019-01-01" for _ in persons], + "observation_period_end_date": ["2022-12-31" for _ in persons], + } + ), + overwrite=True, + ) + + +def test_parity_concept_set_expansion_with_exclusions(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2)) + conn.create_table( + "concept", + obj=ibis.memtable( + { + "concept_id": [100, 101, 102, 200, 201], + "invalid_reason": [None, None, "D", 
None, None], + } + ), + overwrite=True, + ) + conn.create_table( + "concept_ancestor", + obj=ibis.memtable( + { + "ancestor_concept_id": [100, 100], + "descendant_concept_id": [101, 102], + } + ), + overwrite=True, + ) + conn.create_table( + "concept_relationship", + obj=ibis.memtable( + { + "concept_id_1": [200, 201], + "concept_id_2": [100, 101], + "relationship_id": ["Maps to", "Maps to"], + "invalid_reason": [None, "D"], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 1, 2], + "condition_occurrence_id": [1000, 1001, 1002, 1003], + "condition_concept_id": [100, 101, 200, 999], + "condition_start_date": [ + "2020-01-01", + "2020-01-02", + "2020-01-03", + "2020-01-01", + ], + "condition_end_date": [ + "2020-01-01", + "2020-01-02", + "2020-01-03", + "2020-01-01", + ], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + ConceptSet( + id=1, + expression=ConceptSetExpression( + items=[ + ConceptSetItem( + concept=Concept(conceptId=100), + includeDescendants=True, + includeMapped=True, + ), + ConceptSetItem( + concept=Concept(conceptId=101), + isExcluded=True, + includeMapped=True, + ), + ] + ), + ) + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1} + assert set(result.concept_id) == {100, 200} + + +def test_parity_primary_correlated_and_demographic_group_combination(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis, persons=(1, 2, 3)) + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "year_of_birth": [1980, 1980, 1980], + "gender_concept_id": [8507, 8507, 8507], + "race_concept_id": [8527, 8527, 8516], + "ethnicity_concept_id": [38003564, 38003564, 38003564], + } + ), + 
overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 2, 3, 3], + "condition_occurrence_id": [10, 11, 20, 30, 31], + "condition_concept_id": [111, 222, 111, 111, 222], + "condition_start_date": [ + "2020-01-01", + "2020-01-03", + "2020-01-01", + "2020-01-01", + "2020-01-03", + ], + "condition_end_date": [ + "2020-01-01", + "2020-01-03", + "2020-01-01", + "2020-01-01", + "2020-01-03", + ], + "visit_occurrence_id": [10, 10, 20, 30, 30], + } + ), + overwrite=True, + ) + + primary = ConditionOccurrence( + codeset_id=1, + correlated_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=2), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + ) + ], + ), + ) + expression = CohortExpression( + concept_sets=[ + ConceptSet( + id=1, + expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=111))]), + ), + ConceptSet( + id=2, + expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=222))]), + ), + ], + primary_criteria=PrimaryCriteria(criteria_list=[primary]), + additional_criteria=CriteriaGroup( + type="ALL", + demographic_criteria_list=[ + DemographicCriteria( + race=[Concept(conceptId=8527)], + ethnicity=[Concept(conceptId=38003564)], + ) + ], + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert set(result.person_id) == {1} diff --git a/tests/execution/test_person_filters.py b/tests/execution/test_person_filters.py new file mode 100644 index 00000000..2bb40cdb --- /dev/null +++ b/tests/execution/test_person_filters.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import pytest + +from circe.execution.errors import CompilationError +from circe.execution.ibis.person_filters import ( + _apply_numeric_predicate, + apply_person_age_filter, + apply_person_ethnicity_filter, + apply_person_gender_filter, + apply_person_race_filter, +) +from 
circe.execution.plan.predicates import NumericRangePredicate + + +class _PersonFilterContext: + def __init__(self, conn, *, codesets: dict[int, tuple[int, ...]] | None = None): + self.conn = conn + self.codesets = codesets or {} + + def table(self, name: str): + return self.conn.table(name) + + def concept_ids_for_codeset(self, codeset_id: int) -> tuple[int, ...]: + return self.codesets.get(codeset_id, ()) + + +def _seed_person_tables(conn, ibis): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "year_of_birth": [1980, 1995, 2005], + "gender_concept_id": [8507, 8532, 8507], + "race_concept_id": [8527, 8516, 8527], + "ethnicity_concept_id": [38003564, 38003563, 38003564], + } + ), + overwrite=True, + ) + conn.create_table( + "events", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "start_date": ["2020-01-01", "2020-01-01", "2020-01-01"], + } + ), + overwrite=True, + ) + + +def test_apply_person_age_filter_supports_between_predicate(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_person_tables(conn, ibis) + ctx = _PersonFilterContext(conn) + events = conn.table("events") + + result = apply_person_age_filter( + events, + ctx, + date_column="start_date", + predicate=NumericRangePredicate(op="between", value=20, extent=40), + ).execute() + + assert set(result.person_id) == {1, 2} + + +def test_apply_person_numeric_predicate_rejects_invalid_between_and_op(): + with pytest.raises(CompilationError, match="between' requires an extent value"): + _apply_numeric_predicate( + 5, + NumericRangePredicate(op="between", value=1, extent=None), + ) + + with pytest.raises(CompilationError, match="unsupported person numeric range op"): + _apply_numeric_predicate( + 5, + NumericRangePredicate(op="invalid", value=1, extent=None), + ) + + +def test_apply_person_gender_filter_returns_original_table_when_no_ids(): + ibis = pytest.importorskip("ibis") + _ = 
pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_person_tables(conn, ibis) + ctx = _PersonFilterContext(conn) + events = conn.table("events") + + assert apply_person_gender_filter(events, ctx, concept_ids=(), codeset_id=None) is events + + +def test_apply_person_gender_filter_merges_explicit_and_codeset_ids(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_person_tables(conn, ibis) + ctx = _PersonFilterContext(conn, codesets={1: (8507, 8532)}) + events = conn.table("events") + + result = apply_person_gender_filter(events, ctx, concept_ids=(8507,), codeset_id=1).execute() + + assert set(result.person_id) == {1, 2, 3} + + +def test_apply_person_race_and_ethnicity_filters_use_codeset_expansion(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_person_tables(conn, ibis) + ctx = _PersonFilterContext(conn, codesets={2: (8527,), 3: (38003564,)}) + events = conn.table("events") + + race_result = apply_person_race_filter(events, ctx, concept_ids=(), codeset_id=2).execute() + ethnicity_result = apply_person_ethnicity_filter(events, ctx, concept_ids=(), codeset_id=3).execute() + + assert set(race_result.person_id) == {1, 3} + assert set(ethnicity_result.person_id) == {1, 3} diff --git a/tests/execution/test_result_limits.py b/tests/execution/test_result_limits.py new file mode 100644 index 00000000..15d835b9 --- /dev/null +++ b/tests/execution/test_result_limits.py @@ -0,0 +1,305 @@ +from __future__ import annotations + +import pytest + +from circe.api import build_cohort +from circe.cohortdefinition import ( + CohortExpression, + ConditionOccurrence, + CorelatedCriteria, + CriteriaColumn, + CriteriaGroup, + Occurrence, + PrimaryCriteria, + VisitDetail, +) +from circe.cohortdefinition.core import ResultLimit +from circe.execution import api as execution_api +from circe.execution.engine.group_operators import 
resolve_distinct_count_column +from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem + + +def _make_concept_set(set_id: int, concept_id: int) -> ConceptSet: + return ConceptSet( + id=set_id, + expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=concept_id))]), + ) + + +def _seed_common_tables(conn, ibis): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1], + "year_of_birth": [1980], + "gender_concept_id": [8507], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": [1], + "observation_period_id": [10], + "observation_period_start_date": ["2019-01-01"], + "observation_period_end_date": ["2022-12-31"], + } + ), + overwrite=True, + ) + + +def test_primary_limit_last_keeps_latest_primary_event(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1], + "condition_occurrence_id": [100, 101], + "condition_concept_id": [111, 111], + "condition_start_date": ["2020-01-01", "2020-02-01"], + "condition_end_date": ["2020-01-01", "2020-02-01"], + "visit_occurrence_id": [10, 10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + primary_criteria=PrimaryCriteria( + criteria_list=[ConditionOccurrence(codeset_id=1)], + primary_limit=ResultLimit(type="LAST"), + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-02-01" + + +def test_expression_limit_last_keeps_latest_qualified_event(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + 
obj=ibis.memtable( + { + "person_id": [1, 1], + "condition_occurrence_id": [100, 101], + "condition_concept_id": [111, 111], + "condition_start_date": ["2020-01-01", "2020-02-01"], + "condition_end_date": ["2020-01-01", "2020-02-01"], + "visit_occurrence_id": [10, 10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111)], + primary_criteria=PrimaryCriteria( + criteria_list=[ConditionOccurrence(codeset_id=1)], + primary_limit=ResultLimit(type="ALL"), + ), + expression_limit=ResultLimit(type="LAST"), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-02-01" + + +def test_qualified_limit_last_applies_after_additional_criteria(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1, 1, 1, 1], + "condition_occurrence_id": [100, 101, 102, 103], + "condition_concept_id": [111, 111, 222, 222], + "condition_start_date": [ + "2020-01-01", + "2020-02-01", + "2020-01-02", + "2020-02-02", + ], + "condition_end_date": [ + "2020-01-01", + "2020-02-01", + "2020-01-02", + "2020-02-02", + ], + "visit_occurrence_id": [10, 10, 10, 10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111), _make_concept_set(2, 222)], + primary_criteria=PrimaryCriteria( + criteria_list=[ConditionOccurrence(codeset_id=1)], + primary_limit=ResultLimit(type="ALL"), + ), + additional_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=ConditionOccurrence(codeset_id=2), + occurrence=Occurrence(type=Occurrence._AT_LEAST, count=1), + restrict_visit=True, + ) + ], + ), + qualified_limit=ResultLimit(type="LAST"), + ) + + result = build_cohort(expression, backend=conn, 
cdm_schema="main").execute() + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-02-01" + + +def test_write_cohort_without_results_schema_uses_backend_default(monkeypatch): + captured: dict[str, object] = {} + + def _fake_build_cohort(*args, **kwargs): + return object() + + def _fake_project_to_ohdsi_cohort_table(relation, *, cohort_id): + return relation + + def _fake_table_exists(*args, **kwargs): + return False + + def _fake_write_relation( + relation, *, backend, target_table, target_schema=None, if_exists="fail", temporary=False + ): + backend.create_table(target_table, obj=relation, overwrite=(if_exists == "replace")) + + class _Backend: + def create_table(self, name, **kwargs): + captured["name"] = name + captured["kwargs"] = kwargs + + monkeypatch.setattr(execution_api, "build_cohort", _fake_build_cohort) + monkeypatch.setattr( + execution_api, + "project_to_ohdsi_cohort_table", + _fake_project_to_ohdsi_cohort_table, + ) + monkeypatch.setattr(execution_api, "table_exists", _fake_table_exists) + monkeypatch.setattr(execution_api, "write_relation", _fake_write_relation) + + execution_api.write_cohort( + CohortExpression(primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence()])), + backend=_Backend(), + cdm_schema="cdm", + cohort_table="cohort_out", + cohort_id=1, + ) + + assert captured["name"] == "cohort_out" + assert "database" not in captured["kwargs"] + + +@pytest.mark.parametrize( + "count_column", + [ + None, + CriteriaColumn.DOMAIN_CONCEPT, + CriteriaColumn.DOMAIN_SOURCE_CONCEPT, + CriteriaColumn.VISIT_ID, + CriteriaColumn.VISIT_DETAIL_ID, + CriteriaColumn.START_DATE, + CriteriaColumn.END_DATE, + CriteriaColumn.DURATION, + CriteriaColumn.QUANTITY, + CriteriaColumn.DAYS_SUPPLY, + CriteriaColumn.REFILLS, + CriteriaColumn.RANGE_LOW, + CriteriaColumn.RANGE_HIGH, + CriteriaColumn.VALUE_AS_NUMBER, + CriteriaColumn.UNIT, + CriteriaColumn.ERA_OCCURRENCES, + CriteriaColumn.GAP_DAYS, + ], +) +def 
test_resolve_distinct_count_column_supports_public_count_columns(count_column): + resolved = resolve_distinct_count_column(None if count_column is None else count_column.value) + assert resolved.startswith("a_") + + +def test_distinct_count_by_visit_detail_id_matches_sql_semantics(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": ["2020-01-01"], + "condition_end_date": ["2020-01-01"], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + conn.create_table( + "visit_detail", + obj=ibis.memtable( + { + "person_id": [1, 1], + "visit_detail_id": [200, 201], + "visit_detail_concept_id": [222, 222], + "visit_detail_start_date": ["2020-01-01", "2020-01-01"], + "visit_detail_end_date": ["2020-01-02", "2020-01-02"], + "visit_occurrence_id": [10, 10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 111), _make_concept_set(2, 222)], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + additional_criteria=CriteriaGroup( + type="ALL", + criteria_list=[ + CorelatedCriteria( + criteria=VisitDetail(codeset_id=2), + occurrence=Occurrence( + type=Occurrence._AT_LEAST, + count=2, + is_distinct=True, + count_column=CriteriaColumn.VISIT_DETAIL_ID, + ), + restrict_visit=True, + ) + ], + ), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert len(result) == 1 diff --git a/tests/execution/test_scaffolding.py b/tests/execution/test_scaffolding.py new file mode 100644 index 00000000..c9cdecfd --- /dev/null +++ b/tests/execution/test_scaffolding.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from dataclasses import FrozenInstanceError + +import pytest + +from 
circe.execution.normalize.windows import NormalizedDateRange + + +def test_execution_package_imports(): + import circe.execution + import circe.execution.api + import circe.execution.engine + import circe.execution.ibis + import circe.execution.lower + import circe.execution.normalize + import circe.execution.plan + + assert hasattr(circe.execution, "build_cohort") + assert hasattr(circe.execution, "write_cohort") + + +def test_normalized_dataclasses_are_frozen(): + value = NormalizedDateRange(op="gte", value="2020-01-01", extent=None) + with pytest.raises(FrozenInstanceError): + value.op = "lt" diff --git a/tests/execution/test_standard_schema_contracts.py b/tests/execution/test_standard_schema_contracts.py new file mode 100644 index 00000000..7191c0b5 --- /dev/null +++ b/tests/execution/test_standard_schema_contracts.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import pytest + +from circe.api import build_cohort +from circe.cohortdefinition import CohortExpression, ConditionOccurrence, PrimaryCriteria +from circe.execution.plan.schema import STANDARD_EVENT_COLUMNS +from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem +from tests.execution._assertions import assert_standard_event_columns + + +def test_standard_schema_constants_define_expected_column_order(): + assert STANDARD_EVENT_COLUMNS == ( + "person_id", + "event_id", + "start_date", + "end_date", + "domain", + "concept_id", + "source_concept_id", + "visit_occurrence_id", + "visit_detail_id", + "quantity", + "days_supply", + "refills", + "range_low", + "range_high", + "value_as_number", + "unit_concept_id", + "occurrence_count", + "gap_days", + "duration", + "criterion_index", + "criterion_type", + "source_table", + ) + + +def test_standard_schema_contract_for_compiled_primary_events(): + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": 
[1], + "year_of_birth": [1980], + "gender_concept_id": [8507], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": [1], + "observation_period_id": [10], + "observation_period_start_date": ["2019-01-01"], + "observation_period_end_date": ["2022-12-31"], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": ["2020-01-01"], + "condition_end_date": ["2020-01-01"], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + ConceptSet( + id=1, + expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=111))]), + ) + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + assert_standard_event_columns(result.columns) diff --git a/tests/test_execution_api.py b/tests/test_execution_api.py deleted file mode 100644 index 8e6feeec..00000000 --- a/tests/test_execution_api.py +++ /dev/null @@ -1,238 +0,0 @@ -"""Tests for experimental execution API surface.""" - -from __future__ import annotations - -import json -from pathlib import Path - -import pytest - -from circe import CohortExpression -from circe.cohortdefinition import ( - ConditionOccurrence, - CustomEraStrategy, - DateOffsetStrategy, - DrugExposure, - PayerPlanPeriod, - PrimaryCriteria, - VisitDetail, -) -from circe.execution import ExecutionOptions, IbisExecutor -from circe.execution.criteria_compat import parse_single_criteria -from circe.execution.ibis import write_cohort -from circe.execution.options import schema_to_str -from circe.io import load_expression -from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem - - -def test_execution_options_defaults(): - options = ExecutionOptions() - - 
assert options.cdm_schema is None - assert options.vocabulary_schema is None - assert options.result_schema is None - assert options.cohort_id is None - assert options.materialize_stages is False - assert options.materialize_codesets is True - assert options.temp_emulation_schema is None - assert options.capture_sql is False - assert options.profile_dir is None - - -def test_schema_to_str_with_tuple_schema(): - assert schema_to_str(("catalog", "schema")) == "catalog.schema" - - -def test_load_expression_from_mapping(): - expression = load_expression({"Title": "Mapping Input"}) - assert isinstance(expression, CohortExpression) - assert expression.title == "Mapping Input" - - -def test_load_expression_from_path(tmp_path: Path): - payload = {"Title": "File Input"} - path = tmp_path / "cohort.json" - path.write_text(json.dumps(payload), encoding="utf-8") - - expression = load_expression(path) - assert isinstance(expression, CohortExpression) - assert expression.title == "File Input" - - -def test_ibis_executor_missing_optional_dependencies(monkeypatch): - class DummyConn: - pass - - import builtins - import sys - - real_import = builtins.__import__ - sys.modules.pop("circe.execution.build_context", None) - - def _import(name, *args, **kwargs): - if name.endswith("build_context"): - raise ModuleNotFoundError("No module named 'ibis'") - return real_import(name, *args, **kwargs) - - monkeypatch.setattr(builtins, "__import__", _import) - - executor = IbisExecutor(DummyConn(), ExecutionOptions()) - - with pytest.raises(RuntimeError, match="requires optional dependencies"): - executor.build({"Title": "No Backend"}) - - -def test_criteria_compat_methods_available(): - criteria = DrugExposure() - - assert criteria.get_primary_key_column() == "drug_exposure_id" - assert criteria.get_start_date_column() == "drug_exposure_start_date" - assert criteria.get_end_date_column() == "drug_exposure_end_date" - assert criteria.get_concept_id_column() == "drug_concept_id" - - -def 
test_parse_single_criteria_wrapper(): - parsed = parse_single_criteria({"ConditionOccurrence": {"CodesetId": 10}}) - - assert isinstance(parsed, ConditionOccurrence) - assert parsed.codeset_id == 10 - - -def test_parse_single_criteria_wrapper_case_insensitive(): - parsed = parse_single_criteria({"conditionoccurrence": {"CodesetId": 11}}) - - assert isinstance(parsed, ConditionOccurrence) - assert parsed.codeset_id == 11 - - -def test_pipeline_registers_visit_detail_and_payer_plan_period_builders(): - from circe.execution.builders import pipeline as _pipeline # noqa: F401 - from circe.execution.builders.registry import get_builder - - assert callable(get_builder(VisitDetail())) - assert callable(get_builder(PayerPlanPeriod())) - - -def test_coerce_concept_set_selection_rejects_invalid_value(): - from circe.execution.builders.common import coerce_concept_set_selection - - with pytest.raises(ValueError, match="Unsupported concept set selection value"): - coerce_concept_set_selection(object()) - - -def test_write_rejects_append_and_overwrite_together(): - class DummyConn: - pass - - executor = IbisExecutor(DummyConn(), ExecutionOptions()) - - with pytest.raises(ValueError, match="cannot be used together"): - executor.write( - {"Title": "Invalid write options"}, - table="cohort", - append=True, - overwrite=True, - ) - - -def test_write_cohort_rejects_append_and_overwrite_together(): - class DummyConn: - pass - - with pytest.raises(ValueError, match="cannot be used together"): - write_cohort( - {"Title": "Invalid write options"}, - DummyConn(), - table="cohort", - append=True, - overwrite=True, - ) - - -def test_has_end_strategy_handles_polymorphic_models(): - from circe.execution.builders.common import has_end_strategy - - assert has_end_strategy(None) is False - assert has_end_strategy(DateOffsetStrategy(offset=7, date_field="StartDate")) is True - assert has_end_strategy(CustomEraStrategy(drug_codeset_id=123)) is True - - -@pytest.mark.filterwarnings( - 
"ignore:fetch_arrow_table\\(\\) is deprecated, use to_arrow_table\\(\\) instead\\.:DeprecationWarning" -) -def test_ibis_executor_build_smoke_duckdb(): - ibis = pytest.importorskip("ibis") - _ = pytest.importorskip("duckdb") - - conn = ibis.duckdb.connect() - - conn.create_table( - "concept", - obj=ibis.memtable( - { - "concept_id": [111, 999], - "invalid_reason": [None, "D"], - } - ), - overwrite=True, - ) - conn.create_table( - "concept_ancestor", - obj=ibis.memtable( - { - "ancestor_concept_id": [111], - "descendant_concept_id": [111], - } - ), - overwrite=True, - ) - conn.create_table( - "concept_relationship", - obj=ibis.memtable( - { - "concept_id_1": [111], - "concept_id_2": [111], - "relationship_id": ["Maps to"], - "invalid_reason": [""], - } - ), - overwrite=True, - ) - conn.create_table( - "condition_occurrence", - obj=ibis.memtable( - { - "person_id": [1], - "condition_occurrence_id": [1001], - "condition_concept_id": [111], - "condition_start_date": ["2020-01-01"], - "condition_end_date": ["2020-01-02"], - } - ), - overwrite=True, - ) - - cohort = CohortExpression( - concept_sets=[ - ConceptSet( - id=1, - expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=111))]), - ) - ], - primary_criteria=PrimaryCriteria( - criteria_list=[ConditionOccurrence(codeset_id=1)], - ), - ) - - with IbisExecutor(conn, ExecutionOptions(materialize_stages=False)) as executor: - events = executor.build(cohort) - result = events.execute() - - assert len(result) == 1 - assert set(result.columns) == { - "person_id", - "event_id", - "start_date", - "end_date", - "visit_occurrence_id", - } diff --git a/uv.lock b/uv.lock index 5b22cf81..25038575 100644 --- a/uv.lock +++ b/uv.lock @@ -1869,8 +1869,6 @@ ibis-databricks = [ ibis-duckdb = [ { name = "ibis-framework", version = "11.0.0", source = { registry = "https://pypi.org/simple" }, extra = ["duckdb"], marker = "python_full_version < '3.10'" }, { name = "ibis-framework", version = "12.0.0", source = { 
registry = "https://pypi.org/simple" }, extra = ["duckdb"], marker = "python_full_version >= '3.10'" }, - { name = "polars", version = "1.36.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, - { name = "polars", version = "1.39.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, ] ibis-postgres = [ { name = "ibis-framework", version = "11.0.0", source = { registry = "https://pypi.org/simple" }, extra = ["postgres"], marker = "python_full_version < '3.10'" }, @@ -1893,7 +1891,6 @@ requires-dist = [ { name = "jinja2", specifier = ">=3.1.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.0.0" }, { name = "myst-parser", marker = "extra == 'docs'", specifier = ">=0.18.0" }, - { name = "polars", marker = "python_full_version >= '3.9' and extra == 'ibis-duckdb'", specifier = ">=0.20.0" }, { name = "polars", marker = "extra == 'dev'", specifier = ">=0.20.0" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=4.0.0" }, { name = "pydantic", specifier = ">=2.0.0" },