diff --git a/circe/execution/databricks_compat.py b/circe/execution/databricks_compat.py index fa8375f7..2d9405ec 100644 --- a/circe/execution/databricks_compat.py +++ b/circe/execution/databricks_compat.py @@ -56,6 +56,13 @@ def apply_databricks_post_connect_workaround( This helper should be applied lazily by the execution path when a Databricks backend is actually used. """ + import ibis + from packaging.version import Version + + # If the installed Ibis version is late enough to contain the fix, skip patch + if Version(ibis.__version__) >= Version("10.0.0"): + return False + backend_cls = _databricks_backend_class() if backend_cls is None else backend_cls if backend_cls is None: return False @@ -70,6 +77,15 @@ def apply_databricks_post_connect_workaround( if not _post_connect_needs_workaround(post_connect): return False + import warnings + + warnings.warn( + "The Databricks workaround for Ibis issue #11598 is active. " + "This will be removed in a future release once older Ibis versions are deprecated.", + DeprecationWarning, + stacklevel=2, + ) + @functools.wraps(post_connect) def _patched_post_connect(self: Any, *args: Any, **kwargs: Any) -> Any: try: diff --git a/circe/execution/lower/__init__.py b/circe/execution/lower/__init__.py index 95d11a79..60bdce84 100644 --- a/circe/execution/lower/__init__.py +++ b/circe/execution/lower/__init__.py @@ -1,3 +1,3 @@ -from .criteria import LOWERERS, lower_criterion +from .criteria import LowerFn, lower_criterion -__all__ = ["LOWERERS", "lower_criterion"] +__all__ = ["LowerFn", "lower_criterion"] diff --git a/circe/execution/lower/condition_era.py b/circe/execution/lower/condition_era.py index 9162e510..afbdc1fe 100644 --- a/circe/execution/lower/condition_era.py +++ b/circe/execution/lower/condition_era.py @@ -1,5 +1,8 @@ from __future__ import annotations +from circe.cohortdefinition.criteria import ConditionEra +from circe.extensions import lowerer + from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan from ..plan.schema import OCCURRENCE_COUNT @@ -11,6 +14,7 @@ ) +@lowerer(ConditionEra) def lower_condition_era( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/condition_occurrence.py b/circe/execution/lower/condition_occurrence.py index 8d11c2dc..294677cf 100644 --- a/circe/execution/lower/condition_occurrence.py +++ b/circe/execution/lower/condition_occurrence.py @@ -1,5 +1,7 @@ from __future__ import annotations +from circe.extensions import lowerer + from ...cohortdefinition.criteria import ConditionOccurrence from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan @@ -13,6 +15,7 @@ ) +@lowerer(ConditionOccurrence) def lower_condition_occurrence( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/criteria.py b/circe/execution/lower/criteria.py index 90171f78..4a17e90c 100644 --- a/circe/execution/lower/criteria.py +++ b/circe/execution/lower/criteria.py @@ -2,44 +2,11 @@ from typing import Protocol -from ...cohortdefinition.criteria import ( - ConditionEra, - ConditionOccurrence, - Criteria, - Death, - DeviceExposure, - DoseEra, - DrugEra, - DrugExposure, - LocationRegion, - Measurement, - Observation, - ObservationPeriod, - PayerPlanPeriod, - ProcedureOccurrence, - Specimen, - VisitDetail, - VisitOccurrence, -) +from circe.extensions import get_registry + from ..errors import UnsupportedCriterionError from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan -from .condition_era import lower_condition_era -from .condition_occurrence import lower_condition_occurrence -from .death import lower_death -from .device_exposure import lower_device_exposure -from .dose_era import lower_dose_era -from .drug_era import lower_drug_era -from .drug_exposure import lower_drug_exposure -from .location_region import lower_location_region -from .measurement import lower_measurement -from .observation import lower_observation -from .observation_period import lower_observation_period -from .payer_plan_period import lower_payer_plan_period -from .procedure_occurrence import lower_procedure_occurrence -from .specimen import lower_specimen -from .visit_detail import lower_visit_detail -from .visit_occurrence import lower_visit_occurrence class LowerFn(Protocol): @@ -51,34 +18,18 @@ def __call__( ) -> EventPlan: ... -LOWERERS: dict[type[Criteria], LowerFn] = { - ConditionOccurrence: lower_condition_occurrence, - DrugExposure: lower_drug_exposure, - VisitOccurrence: lower_visit_occurrence, - Measurement: lower_measurement, - ProcedureOccurrence: lower_procedure_occurrence, - Observation: lower_observation, - VisitDetail: lower_visit_detail, - DeviceExposure: lower_device_exposure, - Specimen: lower_specimen, - Death: lower_death, - ObservationPeriod: lower_observation_period, - PayerPlanPeriod: lower_payer_plan_period, - ConditionEra: lower_condition_era, - DrugEra: lower_drug_era, - DoseEra: lower_dose_era, - LocationRegion: lower_location_region, -} - - def lower_criterion( criterion: NormalizedCriterion, *, criterion_index: int, ) -> EventPlan: - lowerer = LOWERERS.get(type(criterion.raw_criteria)) + registry = get_registry() + criteria_cls = type(criterion.raw_criteria) + lowerer = registry.get_lowerer(criteria_cls) + if lowerer is not None: return lowerer(criterion, criterion_index=criterion_index) + raise UnsupportedCriterionError( f"Ibis executor lowering error: no lowerer registered for {criterion.criterion_type}." ) diff --git a/circe/execution/lower/death.py b/circe/execution/lower/death.py index 0c95c21d..2b5f44f7 100644 --- a/circe/execution/lower/death.py +++ b/circe/execution/lower/death.py @@ -1,11 +1,14 @@ from __future__ import annotations +from circe.extensions import lowerer + from ...cohortdefinition.criteria import Death from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan from .common import append_concept_filters, build_standard_domain_plan, lower_common_steps +@lowerer(Death) def lower_death( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/device_exposure.py b/circe/execution/lower/device_exposure.py index c5c2e19d..23f44654 100644 --- a/circe/execution/lower/device_exposure.py +++ b/circe/execution/lower/device_exposure.py @@ -1,5 +1,7 @@ from __future__ import annotations +from circe.extensions import lowerer + from ...cohortdefinition.criteria import DeviceExposure from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan @@ -14,6 +16,7 @@ ) +@lowerer(DeviceExposure) def lower_device_exposure( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/dose_era.py b/circe/execution/lower/dose_era.py index 503a5a8a..8ae286a2 100644 --- a/circe/execution/lower/dose_era.py +++ b/circe/execution/lower/dose_era.py @@ -1,10 +1,14 @@ from __future__ import annotations +from circe.cohortdefinition.criteria import DoseEra +from circe.extensions import lowerer + from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan from .common import append_duration_filter, build_standard_domain_plan, lower_common_steps +@lowerer(DoseEra) def lower_dose_era( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/drug_era.py b/circe/execution/lower/drug_era.py index 02afe74c..8d322217 100644 --- a/circe/execution/lower/drug_era.py +++ b/circe/execution/lower/drug_era.py @@ -1,5 +1,8 @@ from __future__ import annotations +from circe.cohortdefinition.criteria import DrugEra +from circe.extensions import lowerer + from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan from ..plan.schema import GAP_DAYS, OCCURRENCE_COUNT @@ -11,6 +14,7 @@ ) +@lowerer(DrugEra) def lower_drug_era( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/drug_exposure.py b/circe/execution/lower/drug_exposure.py index cdbbf93b..d7100bf5 100644 --- a/circe/execution/lower/drug_exposure.py +++ b/circe/execution/lower/drug_exposure.py @@ -1,5 +1,7 @@ from __future__ import annotations +from circe.extensions import lowerer + from ...cohortdefinition.criteria import DrugExposure from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan @@ -14,6 +16,7 @@ ) +@lowerer(DrugExposure) def lower_drug_exposure( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/location_region.py b/circe/execution/lower/location_region.py index 21afdfae..c25722a3 100644 --- a/circe/execution/lower/location_region.py +++ b/circe/execution/lower/location_region.py @@ -1,5 +1,8 @@ from __future__ import annotations +from circe.cohortdefinition.criteria import LocationRegion +from circe.extensions import lowerer + from ..normalize.criteria import NormalizedCriterion from ..plan.events import ( EventPlan, @@ -11,6 +14,7 @@ ) +@lowerer(LocationRegion) def lower_location_region( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/measurement.py b/circe/execution/lower/measurement.py index d2491945..4fadb112 100644 --- a/circe/execution/lower/measurement.py +++ b/circe/execution/lower/measurement.py @@ -1,5 +1,7 @@ from __future__ import annotations +from circe.extensions import lowerer + from ...cohortdefinition.criteria import Measurement from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan @@ -14,6 +16,7 @@ ) +@lowerer(Measurement) def lower_measurement( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/observation.py b/circe/execution/lower/observation.py index 7ff85c21..fa991894 100644 --- a/circe/execution/lower/observation.py +++ b/circe/execution/lower/observation.py @@ -1,5 +1,7 @@ from __future__ import annotations +from circe.extensions import lowerer + from ...cohortdefinition.criteria import Observation from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan @@ -14,6 +16,7 @@ ) +@lowerer(Observation) def lower_observation( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/observation_period.py b/circe/execution/lower/observation_period.py index ef2b8a90..d1bccbfd 100644 --- a/circe/execution/lower/observation_period.py +++ b/circe/execution/lower/observation_period.py @@ -1,10 +1,14 @@ from __future__ import annotations +from circe.cohortdefinition.criteria import ObservationPeriod +from circe.extensions import lowerer + from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan from .common import lower_standard_domain_plan +@lowerer(ObservationPeriod) def lower_observation_period( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/payer_plan_period.py b/circe/execution/lower/payer_plan_period.py index 3a08d1ea..f5a1221d 100644 --- a/circe/execution/lower/payer_plan_period.py +++ b/circe/execution/lower/payer_plan_period.py @@ -1,10 +1,14 @@ from __future__ import annotations +from circe.cohortdefinition.criteria import PayerPlanPeriod +from circe.extensions import lowerer + from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan from .common import lower_standard_domain_plan +@lowerer(PayerPlanPeriod) def lower_payer_plan_period( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/procedure_occurrence.py b/circe/execution/lower/procedure_occurrence.py index caded791..13846f3a 100644 --- a/circe/execution/lower/procedure_occurrence.py +++ b/circe/execution/lower/procedure_occurrence.py @@ -1,5 +1,7 @@ from __future__ import annotations +from circe.extensions import lowerer + from ...cohortdefinition.criteria import ProcedureOccurrence from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan @@ -13,6 +15,7 @@ ) +@lowerer(ProcedureOccurrence) def lower_procedure_occurrence( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/specimen.py b/circe/execution/lower/specimen.py index 02710630..10f1cdbb 100644 --- a/circe/execution/lower/specimen.py +++ b/circe/execution/lower/specimen.py @@ -1,5 +1,7 @@ from __future__ import annotations +from circe.extensions import lowerer + from ...cohortdefinition.criteria import Specimen from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan @@ -12,6 +14,7 @@ ) +@lowerer(Specimen) def lower_specimen( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/visit_detail.py b/circe/execution/lower/visit_detail.py index 73985026..f85c3354 100644 --- a/circe/execution/lower/visit_detail.py +++ b/circe/execution/lower/visit_detail.py @@ -1,5 +1,7 @@ from __future__ import annotations +from circe.extensions import lowerer + from ...cohortdefinition.criteria import VisitDetail from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan, PlanStep @@ -14,6 +16,7 @@ ) +@lowerer(VisitDetail) def lower_visit_detail( criterion: NormalizedCriterion, *, diff --git a/circe/execution/lower/visit_occurrence.py b/circe/execution/lower/visit_occurrence.py index ef7e9d9e..871287bb 100644 --- a/circe/execution/lower/visit_occurrence.py +++ b/circe/execution/lower/visit_occurrence.py @@ -1,5 +1,7 @@ from __future__ import annotations +from circe.extensions import lowerer + from ...cohortdefinition.criteria import VisitOccurrence from ..normalize.criteria import NormalizedCriterion from ..plan.events import EventPlan, PlanStep @@ -14,6 +16,7 @@ ) +@lowerer(VisitOccurrence) def lower_visit_occurrence( criterion: NormalizedCriterion, *, diff --git a/circe/execution/normalize/criteria.py b/circe/execution/normalize/criteria.py index f0af5979..c3fb6eb3 100644 --- a/circe/execution/normalize/criteria.py +++ b/circe/execution/normalize/criteria.py @@ -3,6 +3,8 @@ from dataclasses import replace from typing import TYPE_CHECKING +from circe.extensions import get_registry, normalizer + from ...cohortdefinition.criteria import ( ConditionEra, ConditionOccurrence, @@ -140,6 +142,7 @@ def _build_normalized_criterion( ) +@normalizer(ConditionOccurrence) def _normalize_condition_occurrence(criteria: ConditionOccurrence) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -159,6 +162,7 @@ def _normalize_condition_occurrence(criteria: ConditionOccurrence) -> Normalized ) +@normalizer(DrugExposure) def _normalize_drug_exposure(criteria: DrugExposure) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -178,6 +182,7 @@ def _normalize_drug_exposure(criteria: DrugExposure) -> NormalizedCriterion: ) +@normalizer(VisitOccurrence) def _normalize_visit_occurrence(criteria: VisitOccurrence) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -197,6 +202,7 @@ def _normalize_visit_occurrence(criteria: VisitOccurrence) -> NormalizedCriterio ) +@normalizer(Measurement) def _normalize_measurement(criteria: Measurement) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -216,6 +222,7 @@ def _normalize_measurement(criteria: Measurement) -> NormalizedCriterion: ) +@normalizer(ProcedureOccurrence) def _normalize_procedure_occurrence( criteria: ProcedureOccurrence, ) -> NormalizedCriterion: @@ -237,6 +244,7 @@ def _normalize_procedure_occurrence( ) +@normalizer(Observation) def _normalize_observation(criteria: Observation) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -256,6 +264,7 @@ def _normalize_observation(criteria: Observation) -> NormalizedCriterion: ) +@normalizer(VisitDetail) def _normalize_visit_detail(criteria: VisitDetail) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -275,6 +284,7 @@ def _normalize_visit_detail(criteria: VisitDetail) -> NormalizedCriterion: ) +@normalizer(DeviceExposure) def _normalize_device_exposure(criteria: DeviceExposure) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -294,6 +304,7 @@ def _normalize_device_exposure(criteria: DeviceExposure) -> NormalizedCriterion: ) +@normalizer(Specimen) def _normalize_specimen(criteria: Specimen) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -313,6 +324,7 @@ def _normalize_specimen(criteria: Specimen) -> NormalizedCriterion: ) +@normalizer(Death) def _normalize_death(criteria: Death) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -332,6 +344,7 @@ def _normalize_death(criteria: Death) -> NormalizedCriterion: ) +@normalizer(ObservationPeriod) def _normalize_observation_period(criteria: ObservationPeriod) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -351,6 +364,7 @@ def _normalize_observation_period(criteria: ObservationPeriod) -> NormalizedCrit ) +@normalizer(PayerPlanPeriod) def _normalize_payer_plan_period(criteria: PayerPlanPeriod) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -370,6 +384,7 @@ def _normalize_payer_plan_period(criteria: PayerPlanPeriod) -> NormalizedCriteri ) +@normalizer(ConditionEra) def _normalize_condition_era(criteria: ConditionEra) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -389,6 +404,7 @@ def _normalize_condition_era(criteria: ConditionEra) -> NormalizedCriterion: ) +@normalizer(DrugEra) def _normalize_drug_era(criteria: DrugEra) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -408,6 +424,7 @@ def _normalize_drug_era(criteria: DrugEra) -> NormalizedCriterion: ) +@normalizer(DoseEra) def _normalize_dose_era(criteria: DoseEra) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -427,6 +444,7 @@ def _normalize_dose_era(criteria: DoseEra) -> NormalizedCriterion: ) +@normalizer(LocationRegion) def _normalize_location_region(criteria: LocationRegion) -> NormalizedCriterion: return _build_normalized_criterion( criteria=criteria, @@ -447,43 +465,16 @@ def _normalize_location_region(criteria: LocationRegion) -> NormalizedCriterion: def normalize_criterion(criteria: Criteria) -> NormalizedCriterion: - if isinstance(criteria, ConditionOccurrence): - normalized = _normalize_condition_occurrence(criteria) - elif isinstance(criteria, DrugExposure): - normalized = _normalize_drug_exposure(criteria) - elif isinstance(criteria, VisitOccurrence): - normalized = _normalize_visit_occurrence(criteria) - elif isinstance(criteria, Measurement): - normalized = _normalize_measurement(criteria) - elif isinstance(criteria, ProcedureOccurrence): - normalized = _normalize_procedure_occurrence(criteria) - elif isinstance(criteria, Observation): - normalized = _normalize_observation(criteria) - elif isinstance(criteria, VisitDetail): - normalized = _normalize_visit_detail(criteria) - elif isinstance(criteria, DeviceExposure): - normalized = _normalize_device_exposure(criteria) - elif isinstance(criteria, Specimen): - normalized = _normalize_specimen(criteria) - elif isinstance(criteria, Death): - normalized = _normalize_death(criteria) - elif isinstance(criteria, ObservationPeriod): - normalized = _normalize_observation_period(criteria) - elif isinstance(criteria, PayerPlanPeriod): - normalized = _normalize_payer_plan_period(criteria) - elif isinstance(criteria, ConditionEra): - normalized = _normalize_condition_era(criteria) - elif isinstance(criteria, DrugEra): - normalized = _normalize_drug_era(criteria) - elif isinstance(criteria, DoseEra): - normalized = _normalize_dose_era(criteria) - elif isinstance(criteria, LocationRegion): - normalized = _normalize_location_region(criteria) - else: + registry = get_registry() + normalizer_fn = registry.get_normalizer(type(criteria)) + + if normalizer_fn is None: raise UnsupportedCriterionError( f"Ibis executor normalization error: unsupported criterion type {criteria.__class__.__name__}." ) + normalized = normalizer_fn(criteria) + if criteria.correlated_criteria is not None and not criteria.correlated_criteria.is_empty(): from .groups import normalize_criteria_group diff --git a/circe/extensions/__init__.py b/circe/extensions/__init__.py index 5ba69400..b91a94aa 100644 --- a/circe/extensions/__init__.py +++ b/circe/extensions/__init__.py @@ -34,6 +34,10 @@ class WaveformOccurrenceMarkdownRenderer: if TYPE_CHECKING: from .cohortdefinition.builders.base import CriteriaSqlBuilder from .cohortdefinition.criteria import Criteria + from .execution.lower.criteria import LowerFn + from .execution.normalize.criteria import NormalizedCriterion + +NormalizerFn = Callable[["Criteria"], "NormalizedCriterion"] class ExtensionRegistry: @@ -46,6 +50,12 @@ def __init__(self): # Maps criteria types to SQL builder classes self._sql_builders: dict[type[Criteria], type[CriteriaSqlBuilder]] = {} + # Maps criteria types to lower functions + self._lowerers: dict[type[Criteria], LowerFn] = {} + + # Maps criteria types to normalizer functions + self._normalizers: dict[type[Criteria], NormalizerFn] = {} + # Maps criteria types to markdown template names self._markdown_templates: dict[type[Criteria], str] = {} @@ -74,6 +84,24 @@ def register_sql_builder( """ self._sql_builders[criteria_cls] = builder_cls + def register_lowerer(self, criteria_cls: type["Criteria"], lowerer: "LowerFn") -> None: + """Register a lower function for a criteria type. + + Args: + criteria_cls: The Criteria subclass + lowerer: The LowerFn to execute for this criteria + """ + self._lowerers[criteria_cls] = lowerer + + def register_normalizer(self, criteria_cls: type["Criteria"], normalizer: NormalizerFn) -> None: + """Register a normalizer function for a criteria type. + + Args: + criteria_cls: The Criteria subclass + normalizer: The NormalizerFn to execute for this criteria + """ + self._normalizers[criteria_cls] = normalizer + def register_markdown_template(self, criteria_cls: type["Criteria"], template_name: str) -> None: """Register a Jinja2 template for markdown rendering. @@ -104,6 +132,28 @@ def get_builder(self, criteria: "Criteria") -> Optional["CriteriaSqlBuilder"]: builder_cls = self._sql_builders.get(type(criteria)) return builder_cls() if builder_cls else None + def get_lowerer(self, criteria_cls: type["Criteria"]) -> Optional["LowerFn"]: + """Get the lower function for a criteria type. + + Args: + criteria_cls: The Criteria subclass + + Returns: + The LowerFn, or None if not found + """ + return self._lowerers.get(criteria_cls) + + def get_normalizer(self, criteria_cls: type["Criteria"]) -> Optional[NormalizerFn]: + """Get the normalizer function for a criteria type. + + Args: + criteria_cls: The Criteria subclass + + Returns: + The normalizer function, or None if not found + """ + return self._normalizers.get(criteria_cls) + def get_template(self, criteria: "Criteria") -> Optional[str]: """Get the markdown template name for a criteria instance. @@ -189,6 +239,46 @@ def decorator(builder_cls: "type['CriteriaSqlBuilder']") -> "type['CriteriaSqlBu return decorator # type: ignore[return-value] +def lowerer(criteria_cls: "type['Criteria']") -> Callable[["LowerFn"], "LowerFn"]: + """Decorator that registers an execution lower function for a Criteria type. + + Args: + criteria_cls: The Criteria subclass this function lowers. + + Example:: + + @lowerer(WaveformOccurrence) + def lower_waveform_occurrence(criterion, *, criterion_index): + ... + """ + + def decorator(fn: "LowerFn") -> "LowerFn": + _registry.register_lowerer(criteria_cls, fn) # type: ignore[arg-type] + return fn + + return decorator + + +def normalizer(criteria_cls: "type['Criteria']") -> Callable[[NormalizerFn], NormalizerFn]: + """Decorator that registers a normalizer function for a Criteria type. + + Args: + criteria_cls: The Criteria subclass this function normalizes. + + Example:: + + @normalizer(WaveformOccurrence) + def normalize_waveform_occurrence(criteria): + ... + """ + + def decorator(fn: NormalizerFn) -> NormalizerFn: + _registry.register_normalizer(criteria_cls, fn) # type: ignore[arg-type] + return fn + + return decorator + + def markdown_template(criteria_cls: "type['Criteria']", template_name: str) -> "Callable[[type], type]": """Class decorator that registers a Jinja2 markdown template for a Criteria type. diff --git a/circe/extensions/waveform/__init__.py b/circe/extensions/waveform/__init__.py index bd410196..37576469 100644 --- a/circe/extensions/waveform/__init__.py +++ b/circe/extensions/waveform/__init__.py @@ -2,6 +2,8 @@ from circe.extensions import template_path +# Import lowers and normalizers to trigger decorators +from . import lower, normalizer from .builders.waveform_channel_metadata import WaveformChannelMetadataSqlBuilder from .builders.waveform_feature import WaveformFeatureSqlBuilder @@ -16,6 +18,8 @@ template_path(Path(__file__).parent / "templates") __all__ = [ + "lower", + "normalizer", "WaveformChannelMetadata", "WaveformFeature", "WaveformOccurrence", diff --git a/circe/extensions/waveform/lower.py b/circe/extensions/waveform/lower.py new file mode 100644 index 00000000..b2f6633d --- /dev/null +++ b/circe/extensions/waveform/lower.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from circe.extensions import lowerer + +from ...execution.lower.common import ( + append_concept_filters, + append_numeric_filter, + append_text_filter, + build_standard_domain_plan, + lower_common_steps, +) +from ...execution.normalize.criteria import NormalizedCriterion +from ...execution.plan.events import EventPlan +from .criteria import WaveformOccurrence + + +@lowerer(WaveformOccurrence) +def lower_waveform_occurrence( + criterion: NormalizedCriterion, + *, + criterion_index: int, +) -> EventPlan: + raw = criterion.raw_criteria + if not isinstance(raw, WaveformOccurrence): + raise TypeError("lower_waveform_occurrence requires WaveformOccurrence criteria") + + steps = lower_common_steps(criterion) + + append_concept_filters( + steps, + column="waveform_occurrence_concept_id", + concepts=raw.waveform_occurrence_concept_id, + # codeset_selection does not exist on WaveformOccurrence based on the pydantic logic, we just pass concepts + ) + + append_text_filter( + steps, column="waveform_occurrence_source_value", value=raw.waveform_occurrence_source_value + ) + + append_numeric_filter(steps, column="visit_occurrence_id", value=raw.visit_occurrence_id) + + append_numeric_filter(steps, column="visit_detail_id", value=raw.visit_detail_id) + + append_numeric_filter(steps, column="num_of_files", value=raw.num_of_files) + + append_numeric_filter( + steps, column="preceding_waveform_occurrence_id", value=raw.preceding_waveform_occurrence_id + ) + + return build_standard_domain_plan( + criterion, + criterion_index=criterion_index, + steps=steps, + ) diff --git a/circe/extensions/waveform/normalizer.py b/circe/extensions/waveform/normalizer.py new file mode 100644 index 00000000..5e733aff --- /dev/null +++ b/circe/extensions/waveform/normalizer.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from circe.extensions import normalizer + +from ...execution.normalize.criteria import NormalizedCriterion, _build_normalized_criterion +from ...execution.normalize.windows import normalize_date_range +from .criteria import WaveformOccurrence + + +@normalizer(WaveformOccurrence) +def normalize_waveform_occurrence(criteria: WaveformOccurrence) -> NormalizedCriterion: + return _build_normalized_criterion( + criteria=criteria, + criterion_type="WaveformOccurrence", + domain="waveform_occurrence", + source_table="waveform_occurrence", + event_id_column="waveform_occurrence_id", + start_date_column="waveform_occurrence_start_datetime", + end_date_column="waveform_occurrence_end_datetime", + concept_column="waveform_occurrence_concept_id", + source_concept_column=None, + visit_occurrence_column="visit_occurrence_id", + codeset_id=None, + first=False, + occurrence_start_date=normalize_date_range(criteria.occurrence_start_datetime), + occurrence_end_date=normalize_date_range(criteria.occurrence_end_datetime), + ) diff --git a/tests/execution/test_registry_dispatch.py b/tests/execution/test_registry_dispatch.py new file mode 100644 index 00000000..7d53701a --- /dev/null +++ b/tests/execution/test_registry_dispatch.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import pytest + +from circe.cohortdefinition.criteria import Criteria, CriteriaGroup +from circe.execution.errors import UnsupportedCriterionError +from circe.execution.lower.criteria import lower_criterion +from circe.execution.normalize.criteria import NormalizedCriterion, normalize_criterion +from circe.extensions import get_registry, lowerer, normalizer +from circe.extensions.waveform import WaveformOccurrence + + +class FakeCriteria(Criteria): + pass + + +FakeCriteria.model_rebuild(_types_namespace={"CriteriaGroup": CriteriaGroup}) + + +@pytest.fixture(autouse=True) +def cleanup_registry(): + """Ensure the fake criteria gets cleaned up after the test.""" + registry = get_registry() + yield + registry._lowerers.pop(FakeCriteria, None) + registry._normalizers.pop(FakeCriteria, None) + + +def test_registry_dispatch_round_trip(): + fake_criterion = FakeCriteria() + fake_normalized = NormalizedCriterion( + raw_criteria=fake_criterion, + criterion_type="Fake", + domain="fake", + source_table="fake", + event_id_column="fake_id", + start_date_column="fake_start", + end_date_column="fake_end", + concept_column=None, + source_concept_column=None, + visit_occurrence_column=None, + codeset_id=None, + first=False, + occurrence_start_date=None, + occurrence_end_date=None, + person_filters=NormalizedCriterion._person_filters_from_criterion(fake_criterion) + if hasattr(NormalizedCriterion, "_person_filters_from_criterion") + else None, + ) + + @normalizer(FakeCriteria) + def fake_normalizer(criteria): + return fake_normalized + + @lowerer(FakeCriteria) + def fake_lowerer(criterion, *, criterion_index): + return "fake_plan_result" + + # Test normalize dispatch + normalized = normalize_criterion(fake_criterion) + assert normalized is fake_normalized + + # Test lower dispatch + plan = lower_criterion(normalized, criterion_index=1) + assert plan == "fake_plan_result" + + +def test_unknown_criteria_raises(): + class UnknownCriteria(Criteria): + pass + + UnknownCriteria.model_rebuild(_types_namespace={"CriteriaGroup": CriteriaGroup}) + + with pytest.raises(UnsupportedCriterionError, match="unsupported criterion type UnknownCriteria"): + normalize_criterion(UnknownCriteria()) + + # Create a dummy normalized criterion containing the unknown criteria to test lower_criterion + fake_normalized = NormalizedCriterion( + raw_criteria=UnknownCriteria(), + criterion_type="UnknownCriteria", + domain="unknown", + source_table="unknown", + event_id_column="id", + start_date_column="start", + end_date_column="end", + concept_column=None, + source_concept_column=None, + visit_occurrence_column=None, + codeset_id=None, + first=False, + occurrence_start_date=None, + occurrence_end_date=None, + person_filters=None, + ) + + with pytest.raises(UnsupportedCriterionError, match="no lowerer registered for UnknownCriteria"): + lower_criterion(fake_normalized, criterion_index=1) + + +def test_waveform_extension_dispatch(): + # The extension should have pre-registered its normalizer and lowerer + waveform = WaveformOccurrence() + + # Check that normalizer is found + normalized = normalize_criterion(waveform) + assert normalized.domain == "waveform_occurrence" + + # Check that lowerer is found + plan = lower_criterion(normalized, criterion_index=1) + + assert plan.source.table_name == "waveform_occurrence" + assert plan.criterion_type == "WaveformOccurrence"