Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datamimic_ce/constants/attribute_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
ATTR_DEFAULT_VARIABLE_SUFFIX: Final = "defaultVariableSuffix"
ATTR_VARIABLE_PREFIX: Final = "variablePrefix"
ATTR_VARIABLE_SUFFIX: Final = "variableSuffix"
ATTR_DIR: Final = "dir"
ATTR_STRING: Final = "string"
ATTR_BUCKET: Final = "bucket"
ATTR_MP_PLATFORM: Final = "mpPlatform"
Expand Down
1 change: 1 addition & 0 deletions datamimic_ce/constants/element_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
EL_CONDITION = "condition"
EL_ELSE_IF = "else-if"
EL_ELSE = "else"
EL_DEMOGRAPHICS = "demographics"
20 changes: 20 additions & 0 deletions datamimic_ce/contexts/demographic_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Context wiring for demographic samplers."""

from __future__ import annotations

from dataclasses import dataclass
from random import Random

from datamimic_ce.domains.common.demographics.profile import DemographicProfileId
from datamimic_ce.domains.common.demographics.sampler import DemographicSampler
from datamimic_ce.domains.common.models.demographic_config import DemographicConfig


@dataclass(frozen=True)
class DemographicContext:
"""Immutable container for the active demographic profile."""

profile_id: DemographicProfileId
sampler: DemographicSampler
overrides: DemographicConfig | None
rng: Random
12 changes: 12 additions & 0 deletions datamimic_ce/contexts/setup_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from datamimic_ce.clients.database_client import Client
from datamimic_ce.contexts.context import Context
from datamimic_ce.contexts.demographic_context import DemographicContext
from datamimic_ce.converter.converter import Converter
from datamimic_ce.converter.custom_converter import CustomConverter
from datamimic_ce.domains.domain_core.base_literal_generator import BaseLiteralGenerator
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(
generators: dict | None = None,
default_source_scripted: bool | None = None,
report_logging: bool = True,
demographic_context: DemographicContext | None = None,
):
# SetupContext is always its root_context
super().__init__(self)
Expand Down Expand Up @@ -86,6 +88,7 @@ def __init__(
self._report_logging = report_logging
self._current_seed = current_seed
self._task_exporters: dict[str, dict[str, Any]] = {}
self._demographic_context = demographic_context

def __deepcopy__(self, memo):
"""
Expand Down Expand Up @@ -125,6 +128,7 @@ def __deepcopy__(self, memo):
default_source_scripted=self._default_source_scripted,
report_logging=copy.deepcopy(self._report_logging),
current_seed=self._current_seed,
demographic_context=copy.deepcopy(self._demographic_context, memo),
)

def _deepcopy_clients(self, memo):
Expand Down Expand Up @@ -226,6 +230,14 @@ def update_with_stmt(self, stmt: SetupStatement):
if value is not None:
setattr(self, key, value)

@property
def demographic_context(self) -> DemographicContext | None:
return self._demographic_context

def set_demographic_context(self, context: DemographicContext) -> None:
# Keep demographics explicit on the root context instead of mutable module globals.
self._demographic_context = context

@property
def clients(self) -> dict:
return self._clients
Expand Down
15 changes: 15 additions & 0 deletions datamimic_ce/domains/common/demographics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Demographic profile domain package."""

from .loader import DemographicProfileError, load_demographic_profile
from .profile import DemographicProfile, DemographicProfileId, normalize_sex
from .sampler import DemographicSample, DemographicSampler

__all__ = [
"DemographicProfile",
"DemographicProfileId",
"DemographicProfileError",
"DemographicSampler",
"DemographicSample",
"load_demographic_profile",
"normalize_sex",
]
197 changes: 197 additions & 0 deletions datamimic_ce/domains/common/demographics/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""CSV loader for demographic profiles."""

from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
from pathlib import Path

from datamimic_ce.logger import logger
from datamimic_ce.utils.file_util import FileUtil

from .profile import (
DemographicAgeBand,
DemographicConditionRate,
DemographicProfile,
DemographicProfileId,
SexKey,
normalize_sex,
)

_REQUIRED_FILES = {
"age_pyramid.dmgrp.csv",
"condition_rates.dmgrp.csv",
}


class DemographicProfileError(ValueError):
"""Raised when demographic CSV files are invalid."""


def load_demographic_profile(directory: Path, dataset: str, version: str) -> DemographicProfile:
"""Load a demographic profile from the given directory."""

base_dir = Path(directory)
if not base_dir.exists():
raise DemographicProfileError(f"Demographic directory '{base_dir}' does not exist")

files = {path.name: path for path in base_dir.glob("*.dmgrp.csv")}
missing = _REQUIRED_FILES.difference(files)
if missing:
raise DemographicProfileError(
f"Missing demographic files {sorted(missing)} in '{base_dir}'. Ensure required CSVs exist."
)

age_bands = _load_age_bands(files["age_pyramid.dmgrp.csv"], dataset, version)
condition_rates = _load_condition_rates(files["condition_rates.dmgrp.csv"], dataset, version)

profile = DemographicProfile(
profile_id=DemographicProfileId(dataset=dataset, version=version),
age_bands=age_bands,
condition_rates=condition_rates,
)
return profile


def _load_age_bands(file_path: Path, dataset: str, version: str) -> dict[SexKey, tuple[DemographicAgeBand, ...]]:
rows = FileUtil.read_csv_to_dict_list(file_path, separator=",")
grouped: defaultdict[SexKey, list[DemographicAgeBand]] = defaultdict(list)
for idx, row in enumerate(rows, start=2):
_ensure_dataset_version(row, dataset, version, file_path, idx)
sex = normalize_sex(row.get("sex"))
try:
age_min = int(row["age_min"])
age_max = int(row["age_max"])
weight = float(row["weight"])
except (TypeError, ValueError) as exc:
raise DemographicProfileError(
f"Invalid numeric value in '{file_path}' line {idx}: {exc}."
" Expected integers for age_min/age_max and float for weight."
) from exc
if age_min > age_max:
raise DemographicProfileError(
f"age_min must be <= age_max in '{file_path}' line {idx}: got {age_min}>{age_max}."
)
if weight < 0:
raise DemographicProfileError(f"weight must be non-negative in '{file_path}' line {idx}: got {weight}.")
grouped[sex].append(
DemographicAgeBand(
sex=sex,
age_min=age_min,
age_max=age_max,
weight=weight,
)
)

normalized: dict[SexKey, tuple[DemographicAgeBand, ...]] = {}
for sex, bands in grouped.items():
sorted_bands = sorted(bands, key=lambda b: (b.age_min, b.age_max))
_validate_band_coverage(sorted_bands, file_path, sex)
total_weight = sum(b.weight for b in sorted_bands)
if not _is_close(total_weight, 1.0):
raise DemographicProfileError(
f"Weights must sum to 1.0 per sex in '{file_path}' for sex='{sex or ''}' (sum={total_weight:.6f})."
)
normalized[sex] = tuple(sorted_bands)
if not normalized:
raise DemographicProfileError(f"No rows parsed from '{file_path}'.")
return normalized


def _load_condition_rates(
file_path: Path, dataset: str, version: str
) -> dict[str, tuple[DemographicConditionRate, ...]]:
rows = FileUtil.read_csv_to_dict_list(file_path, separator=",")
grouped: defaultdict[str, list[DemographicConditionRate]] = defaultdict(list)
for idx, row in enumerate(rows, start=2):
_ensure_dataset_version(row, dataset, version, file_path, idx)
condition = (row.get("condition") or "").strip()
if not condition:
raise DemographicProfileError(
f"Condition name missing in '{file_path}' line {idx}. Provide canonical condition labels."
)
sex = normalize_sex(row.get("sex"))
try:
age_min = int(row["age_min"])
age_max = int(row["age_max"])
prevalence = float(row["prevalence"])
except (TypeError, ValueError) as exc:
raise DemographicProfileError(
f"Invalid numeric value in '{file_path}' line {idx}: {exc}."
" Expected integers for age_min/age_max and float for prevalence."
) from exc
if age_min > age_max:
raise DemographicProfileError(
f"age_min must be <= age_max in '{file_path}' line {idx}: got {age_min}>{age_max}."
)
if not 0.0 <= prevalence <= 1.0:
raise DemographicProfileError(
f"prevalence must be within [0,1] in '{file_path}' line {idx}: got {prevalence}."
)
grouped[condition].append(
DemographicConditionRate(
condition=condition,
sex=sex,
age_min=age_min,
age_max=age_max,
prevalence=prevalence,
)
)

normalized: dict[str, tuple[DemographicConditionRate, ...]] = {}
for condition, rates in grouped.items():
normalized[condition] = tuple(sorted(rates, key=_condition_sort_key))
return normalized


def _condition_sort_key(rate: DemographicConditionRate) -> tuple[int, int, int]:
# Stable ordering ensures deterministic sampling and makes tests reproducible.
return (0 if rate.sex is not None else 1, rate.age_min, rate.age_max)


def _ensure_dataset_version(
row: dict,
dataset: str,
version: str,
file_path: Path,
line_number: int,
) -> None:
if (row.get("dataset") or "").strip() != dataset:
raise DemographicProfileError(f"Dataset mismatch in '{file_path}' line {line_number}: expected '{dataset}'.")
if (row.get("version") or "").strip() != version:
raise DemographicProfileError(f"Version mismatch in '{file_path}' line {line_number}: expected '{version}'.")


def _validate_band_coverage(bands: Iterable[DemographicAgeBand], file_path: Path, sex: SexKey) -> None:
sorted_bands = list(bands)
previous = None
for band in sorted_bands:
if previous and band.age_min <= previous.age_max:
raise DemographicProfileError(
f"Overlapping age bands for sex='{sex or ''}' in '{file_path}':"
f" [{previous.age_min},{previous.age_max}] overlaps [{band.age_min},{band.age_max}]."
)
if previous and band.age_min > previous.age_max + 1:
logger.warning(
"Gap detected between age bands for sex='%s' in '%s': [%s,%s] -> [%s,%s]", # type: ignore[str-format]
sex or "",
file_path,
previous.age_min,
previous.age_max,
band.age_min,
band.age_max,
)
previous = band
if sorted_bands:
if sorted_bands[0].age_min > 0:
logger.warning(
"Age coverage for sex='%s' in '%s' starts at %s (>0).", sex or "", file_path, sorted_bands[0].age_min
)
if sorted_bands[-1].age_max < 100:
logger.warning(
"Age coverage for sex='%s' in '%s' ends at %s (<100).", sex or "", file_path, sorted_bands[-1].age_max
)


def _is_close(value: float, target: float, *, tolerance: float = 1e-6) -> bool:
return abs(value - target) <= tolerance
80 changes: 80 additions & 0 deletions datamimic_ce/domains/common/demographics/profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Demographic profile domain objects."""

from __future__ import annotations

from collections.abc import Mapping, Sequence
from dataclasses import dataclass

SexKey = str | None


def normalize_sex(sex: str | None) -> SexKey:
"""Normalize raw sex codes coming from CSV files."""

if sex is None:
return None
value = sex.strip().upper()
return value or None


@dataclass(frozen=True)
class DemographicProfileId:
"""Stable identifier for a demographic profile dataset."""

dataset: str
version: str


@dataclass(frozen=True)
class DemographicAgeBand:
"""Age distribution entry scoped to a sex bucket."""

sex: SexKey
age_min: int
age_max: int
weight: float

def contains(self, age: int) -> bool:
return self.age_min <= age <= self.age_max


@dataclass(frozen=True)
class DemographicConditionRate:
"""Condition prevalence entry used for Bernoulli sampling."""

condition: str
sex: SexKey
age_min: int
age_max: int
prevalence: float

def matches(self, *, age: int, sex: SexKey) -> bool:
return self.age_min <= age <= self.age_max and (self.sex is None or self.sex == sex)


@dataclass(frozen=True)
class DemographicProfile:
"""Collection of demographic priors used by samplers."""

profile_id: DemographicProfileId
age_bands: Mapping[SexKey, tuple[DemographicAgeBand, ...]]
condition_rates: Mapping[str, tuple[DemographicConditionRate, ...]]

def bands_for_sex(self, sex: SexKey) -> tuple[DemographicAgeBand, ...]:
"""Return ordered bands for a given sex, falling back to combined data."""

normalized = normalize_sex(sex)
if normalized in self.age_bands:
return self.age_bands[normalized]
# Combined (sex-less) priors should backstop missing sex specific data without failing generation.
return self.age_bands.get(None, ())

def conditions_for(self, condition: str) -> tuple[DemographicConditionRate, ...]:
"""Return ordered prevalence rows for a condition."""

return self.condition_rates.get(condition, ())

def sexes(self) -> Sequence[SexKey]:
"""Expose the known sex buckets for downstream logic."""

return tuple(self.age_bands.keys())
Loading