diff --git a/datamimic_ce/constants/attribute_constants.py b/datamimic_ce/constants/attribute_constants.py index 597dbf97..4dd4a6b8 100644 --- a/datamimic_ce/constants/attribute_constants.py +++ b/datamimic_ce/constants/attribute_constants.py @@ -70,6 +70,7 @@ ATTR_DEFAULT_VARIABLE_SUFFIX: Final = "defaultVariableSuffix" ATTR_VARIABLE_PREFIX: Final = "variablePrefix" ATTR_VARIABLE_SUFFIX: Final = "variableSuffix" +ATTR_DIR: Final = "dir" ATTR_STRING: Final = "string" ATTR_BUCKET: Final = "bucket" ATTR_MP_PLATFORM: Final = "mpPlatform" diff --git a/datamimic_ce/constants/element_constants.py b/datamimic_ce/constants/element_constants.py index ee06e00e..2d8fbbfa 100644 --- a/datamimic_ce/constants/element_constants.py +++ b/datamimic_ce/constants/element_constants.py @@ -25,3 +25,4 @@ EL_CONDITION = "condition" EL_ELSE_IF = "else-if" EL_ELSE = "else" +EL_DEMOGRAPHICS = "demographics" diff --git a/datamimic_ce/contexts/demographic_context.py b/datamimic_ce/contexts/demographic_context.py new file mode 100644 index 00000000..aec6f3ae --- /dev/null +++ b/datamimic_ce/contexts/demographic_context.py @@ -0,0 +1,20 @@ +"""Context wiring for demographic samplers.""" + +from __future__ import annotations + +from dataclasses import dataclass +from random import Random + +from datamimic_ce.domains.common.demographics.profile import DemographicProfileId +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler +from datamimic_ce.domains.common.models.demographic_config import DemographicConfig + + +@dataclass(frozen=True) +class DemographicContext: + """Immutable container for the active demographic profile.""" + + profile_id: DemographicProfileId + sampler: DemographicSampler + overrides: DemographicConfig | None + rng: Random diff --git a/datamimic_ce/contexts/setup_context.py b/datamimic_ce/contexts/setup_context.py index 9bd8895e..5c9cf7b7 100644 --- a/datamimic_ce/contexts/setup_context.py +++ b/datamimic_ce/contexts/setup_context.py @@ -11,6 +11,7 @@ from datamimic_ce.clients.database_client import Client from datamimic_ce.contexts.context import Context +from datamimic_ce.contexts.demographic_context import DemographicContext from datamimic_ce.converter.converter import Converter from datamimic_ce.converter.custom_converter import CustomConverter from datamimic_ce.domains.domain_core.base_literal_generator import BaseLiteralGenerator @@ -49,6 +50,7 @@ def __init__( generators: dict | None = None, default_source_scripted: bool | None = None, report_logging: bool = True, + demographic_context: DemographicContext | None = None, ): # SetupContext is always its root_context super().__init__(self) @@ -86,6 +88,7 @@ def __init__( self._report_logging = report_logging self._current_seed = current_seed self._task_exporters: dict[str, dict[str, Any]] = {} + self._demographic_context = demographic_context def __deepcopy__(self, memo): """ @@ -125,6 +128,7 @@ def __deepcopy__(self, memo): default_source_scripted=self._default_source_scripted, report_logging=copy.deepcopy(self._report_logging), current_seed=self._current_seed, + demographic_context=copy.deepcopy(self._demographic_context, memo), ) def _deepcopy_clients(self, memo): @@ -226,6 +230,14 @@ def update_with_stmt(self, stmt: SetupStatement): if value is not None: setattr(self, key, value) + @property + def demographic_context(self) -> DemographicContext | None: + return self._demographic_context + + def set_demographic_context(self, context: DemographicContext) -> None: + # Keep demographics explicit on the root context instead of mutable module globals. + self._demographic_context = context + @property def clients(self) -> dict: return self._clients diff --git a/datamimic_ce/domains/common/demographics/__init__.py b/datamimic_ce/domains/common/demographics/__init__.py new file mode 100644 index 00000000..6958999a --- /dev/null +++ b/datamimic_ce/domains/common/demographics/__init__.py @@ -0,0 +1,15 @@ +"""Demographic profile domain package.""" + +from .loader import DemographicProfileError, load_demographic_profile +from .profile import DemographicProfile, DemographicProfileId, normalize_sex +from .sampler import DemographicSample, DemographicSampler + +__all__ = [ + "DemographicProfile", + "DemographicProfileId", + "DemographicProfileError", + "DemographicSampler", + "DemographicSample", + "load_demographic_profile", + "normalize_sex", +] diff --git a/datamimic_ce/domains/common/demographics/loader.py b/datamimic_ce/domains/common/demographics/loader.py new file mode 100644 index 00000000..38294890 --- /dev/null +++ b/datamimic_ce/domains/common/demographics/loader.py @@ -0,0 +1,197 @@ +"""CSV loader for demographic profiles.""" + +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Iterable +from pathlib import Path + +from datamimic_ce.logger import logger +from datamimic_ce.utils.file_util import FileUtil + +from .profile import ( + DemographicAgeBand, + DemographicConditionRate, + DemographicProfile, + DemographicProfileId, + SexKey, + normalize_sex, +) + +_REQUIRED_FILES = { + "age_pyramid.dmgrp.csv", + "condition_rates.dmgrp.csv", +} + + +class DemographicProfileError(ValueError): + """Raised when demographic CSV files are invalid.""" + + +def load_demographic_profile(directory: Path, dataset: str, version: str) -> DemographicProfile: + """Load a demographic profile from the given directory.""" + + base_dir = Path(directory) + if not base_dir.exists(): + raise DemographicProfileError(f"Demographic directory '{base_dir}' does not exist") + + files = {path.name: path for path in base_dir.glob("*.dmgrp.csv")} + missing = _REQUIRED_FILES.difference(files) + if missing: + raise DemographicProfileError( + f"Missing demographic files {sorted(missing)} in '{base_dir}'. Ensure required CSVs exist." + ) + + age_bands = _load_age_bands(files["age_pyramid.dmgrp.csv"], dataset, version) + condition_rates = _load_condition_rates(files["condition_rates.dmgrp.csv"], dataset, version) + + profile = DemographicProfile( + profile_id=DemographicProfileId(dataset=dataset, version=version), + age_bands=age_bands, + condition_rates=condition_rates, + ) + return profile + + +def _load_age_bands(file_path: Path, dataset: str, version: str) -> dict[SexKey, tuple[DemographicAgeBand, ...]]: + rows = FileUtil.read_csv_to_dict_list(file_path, separator=",") + grouped: defaultdict[SexKey, list[DemographicAgeBand]] = defaultdict(list) + for idx, row in enumerate(rows, start=2): + _ensure_dataset_version(row, dataset, version, file_path, idx) + sex = normalize_sex(row.get("sex")) + try: + age_min = int(row["age_min"]) + age_max = int(row["age_max"]) + weight = float(row["weight"]) + except (TypeError, ValueError) as exc: + raise DemographicProfileError( + f"Invalid numeric value in '{file_path}' line {idx}: {exc}." + " Expected integers for age_min/age_max and float for weight." + ) from exc + if age_min > age_max: + raise DemographicProfileError( + f"age_min must be <= age_max in '{file_path}' line {idx}: got {age_min}>{age_max}." + ) + if weight < 0: + raise DemographicProfileError(f"weight must be non-negative in '{file_path}' line {idx}: got {weight}.") + grouped[sex].append( + DemographicAgeBand( + sex=sex, + age_min=age_min, + age_max=age_max, + weight=weight, + ) + ) + + normalized: dict[SexKey, tuple[DemographicAgeBand, ...]] = {} + for sex, bands in grouped.items(): + sorted_bands = sorted(bands, key=lambda b: (b.age_min, b.age_max)) + _validate_band_coverage(sorted_bands, file_path, sex) + total_weight = sum(b.weight for b in sorted_bands) + if not _is_close(total_weight, 1.0): + raise DemographicProfileError( + f"Weights must sum to 1.0 per sex in '{file_path}' for sex='{sex or ''}' (sum={total_weight:.6f})." + ) + normalized[sex] = tuple(sorted_bands) + if not normalized: + raise DemographicProfileError(f"No rows parsed from '{file_path}'.") + return normalized + + +def _load_condition_rates( + file_path: Path, dataset: str, version: str +) -> dict[str, tuple[DemographicConditionRate, ...]]: + rows = FileUtil.read_csv_to_dict_list(file_path, separator=",") + grouped: defaultdict[str, list[DemographicConditionRate]] = defaultdict(list) + for idx, row in enumerate(rows, start=2): + _ensure_dataset_version(row, dataset, version, file_path, idx) + condition = (row.get("condition") or "").strip() + if not condition: + raise DemographicProfileError( + f"Condition name missing in '{file_path}' line {idx}. Provide canonical condition labels." + ) + sex = normalize_sex(row.get("sex")) + try: + age_min = int(row["age_min"]) + age_max = int(row["age_max"]) + prevalence = float(row["prevalence"]) + except (TypeError, ValueError) as exc: + raise DemographicProfileError( + f"Invalid numeric value in '{file_path}' line {idx}: {exc}." + " Expected integers for age_min/age_max and float for prevalence." + ) from exc + if age_min > age_max: + raise DemographicProfileError( + f"age_min must be <= age_max in '{file_path}' line {idx}: got {age_min}>{age_max}." + ) + if not 0.0 <= prevalence <= 1.0: + raise DemographicProfileError( + f"prevalence must be within [0,1] in '{file_path}' line {idx}: got {prevalence}." + ) + grouped[condition].append( + DemographicConditionRate( + condition=condition, + sex=sex, + age_min=age_min, + age_max=age_max, + prevalence=prevalence, + ) + ) + + normalized: dict[str, tuple[DemographicConditionRate, ...]] = {} + for condition, rates in grouped.items(): + normalized[condition] = tuple(sorted(rates, key=_condition_sort_key)) + return normalized + + +def _condition_sort_key(rate: DemographicConditionRate) -> tuple[int, int, int]: + # Stable ordering ensures deterministic sampling and makes tests reproducible. + return (0 if rate.sex is not None else 1, rate.age_min, rate.age_max) + + +def _ensure_dataset_version( + row: dict, + dataset: str, + version: str, + file_path: Path, + line_number: int, +) -> None: + if (row.get("dataset") or "").strip() != dataset: + raise DemographicProfileError(f"Dataset mismatch in '{file_path}' line {line_number}: expected '{dataset}'.") + if (row.get("version") or "").strip() != version: + raise DemographicProfileError(f"Version mismatch in '{file_path}' line {line_number}: expected '{version}'.") + + +def _validate_band_coverage(bands: Iterable[DemographicAgeBand], file_path: Path, sex: SexKey) -> None: + sorted_bands = list(bands) + previous = None + for band in sorted_bands: + if previous and band.age_min <= previous.age_max: + raise DemographicProfileError( + f"Overlapping age bands for sex='{sex or ''}' in '{file_path}':" + f" [{previous.age_min},{previous.age_max}] overlaps [{band.age_min},{band.age_max}]." + ) + if previous and band.age_min > previous.age_max + 1: + logger.warning( + "Gap detected between age bands for sex='%s' in '%s': [%s,%s] -> [%s,%s]", # type: ignore[str-format] + sex or "", + file_path, + previous.age_min, + previous.age_max, + band.age_min, + band.age_max, + ) + previous = band + if sorted_bands: + if sorted_bands[0].age_min > 0: + logger.warning( + "Age coverage for sex='%s' in '%s' starts at %s (>0).", sex or "", file_path, sorted_bands[0].age_min + ) + if sorted_bands[-1].age_max < 100: + logger.warning( + "Age coverage for sex='%s' in '%s' ends at %s (<100).", sex or "", file_path, sorted_bands[-1].age_max + ) + + +def _is_close(value: float, target: float, *, tolerance: float = 1e-6) -> bool: + return abs(value - target) <= tolerance diff --git a/datamimic_ce/domains/common/demographics/profile.py b/datamimic_ce/domains/common/demographics/profile.py new file mode 100644 index 00000000..c0f7d0b2 --- /dev/null +++ b/datamimic_ce/domains/common/demographics/profile.py @@ -0,0 +1,80 @@ +"""Demographic profile domain objects.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass + +SexKey = str | None + + +def normalize_sex(sex: str | None) -> SexKey: + """Normalize raw sex codes coming from CSV files.""" + + if sex is None: + return None + value = sex.strip().upper() + return value or None + + +@dataclass(frozen=True) +class DemographicProfileId: + """Stable identifier for a demographic profile dataset.""" + + dataset: str + version: str + + +@dataclass(frozen=True) +class DemographicAgeBand: + """Age distribution entry scoped to a sex bucket.""" + + sex: SexKey + age_min: int + age_max: int + weight: float + + def contains(self, age: int) -> bool: + return self.age_min <= age <= self.age_max + + +@dataclass(frozen=True) +class DemographicConditionRate: + """Condition prevalence entry used for Bernoulli sampling.""" + + condition: str + sex: SexKey + age_min: int + age_max: int + prevalence: float + + def matches(self, *, age: int, sex: SexKey) -> bool: + return self.age_min <= age <= self.age_max and (self.sex is None or self.sex == sex) + + +@dataclass(frozen=True) +class DemographicProfile: + """Collection of demographic priors used by samplers.""" + + profile_id: DemographicProfileId + age_bands: Mapping[SexKey, tuple[DemographicAgeBand, ...]] + condition_rates: Mapping[str, tuple[DemographicConditionRate, ...]] + + def bands_for_sex(self, sex: SexKey) -> tuple[DemographicAgeBand, ...]: + """Return ordered bands for a given sex, falling back to combined data.""" + + normalized = normalize_sex(sex) + if normalized in self.age_bands: + return self.age_bands[normalized] + # Combined (sex-less) priors should backstop missing sex specific data without failing generation. + return self.age_bands.get(None, ()) + + def conditions_for(self, condition: str) -> tuple[DemographicConditionRate, ...]: + """Return ordered prevalence rows for a condition.""" + + return self.condition_rates.get(condition, ()) + + def sexes(self) -> Sequence[SexKey]: + """Expose the known sex buckets for downstream logic.""" + + return tuple(self.age_bands.keys()) diff --git a/datamimic_ce/domains/common/demographics/sampler.py b/datamimic_ce/domains/common/demographics/sampler.py new file mode 100644 index 00000000..4ad014a0 --- /dev/null +++ b/datamimic_ce/domains/common/demographics/sampler.py @@ -0,0 +1,118 @@ +"""Pure demographic sampler built on top of demographic profiles.""" + +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import dataclass +from random import Random + +from .profile import ( + DemographicAgeBand, + DemographicConditionRate, + DemographicProfile, + DemographicProfileId, + SexKey, + normalize_sex, +) + + +@dataclass(frozen=True) +class DemographicSample: + """Sampled demographic attributes for a single entity.""" + + age: int | None + sex: SexKey + conditions: frozenset[str] + + +class DemographicSampler: + """Pure sampler that draws ages, sexes and conditions from a profile.""" + + def __init__(self, profile: DemographicProfile): + self._profile = profile + self._age_cdf = self._build_age_cdfs(profile) + self._sex_choices = self._build_sex_choices(profile) + + @property + def profile_id(self) -> DemographicProfileId: + return self._profile.profile_id + + def sample_age_sex(self, rng: Random) -> tuple[int, SexKey]: + """Sample an age band and sex using cumulative weights.""" + + sex = self._choose_sex(rng) + bands = self._age_cdf.get(sex) or self._age_cdf.get(None) + if not bands: + raise ValueError("No age bands available for demographic sampling") + pick = rng.random() + for threshold, band in bands: + if pick <= threshold: + age = rng.randint(band.age_min, band.age_max) + return age, sex + # Numerical precision might leave pick slightly above the last threshold. + last_band = bands[-1][1] + # Guarantee a return even when float rounding pushes us past the final threshold. + return rng.randint(last_band.age_min, last_band.age_max), sex + + def sample_conditions(self, age: int, sex: SexKey, rng: Random) -> frozenset[str]: + """Sample independent conditions based on prevalence tables.""" + + selected: set[str] = set() + normalized_sex = normalize_sex(sex) + for condition, rates in self._profile.condition_rates.items(): + rate = self._select_rate(rates, age, normalized_sex) + if rate is None: + continue + if rng.random() <= rate.prevalence: + selected.add(condition) + return frozenset(selected) + + def _select_rate( + self, + rates: Iterable[DemographicConditionRate], + age: int, + sex: SexKey, + ) -> DemographicConditionRate | None: + # Prefer exact sex match, otherwise fall back to combined entries. + exact_match = None + fallback = None + for rate in rates: + if not rate.matches(age=age, sex=sex): + continue + if rate.sex is None: + fallback = rate + else: + exact_match = rate + break + return exact_match or fallback + + def _build_age_cdfs(self, profile: DemographicProfile) -> dict[SexKey, list[tuple[float, DemographicAgeBand]]]: + cdfs: dict[SexKey, list[tuple[float, DemographicAgeBand]]] = {} + for sex, bands in profile.age_bands.items(): + cumulative = 0.0 + cdf: list[tuple[float, DemographicAgeBand]] = [] + for band in bands: + cumulative += band.weight + cdf.append((cumulative, band)) + if not cdf: + continue + # Normalize trailing rounding differences so cumulative ends at 1.0. + final_threshold, last_band = cdf[-1] + if abs(final_threshold - 1.0) > 1e-9: + adjust = final_threshold + cdf = [(threshold / adjust, band) for threshold, band in cdf] + cdfs[sex] = cdf + return cdfs + + def _build_sex_choices(self, profile: DemographicProfile) -> tuple[SexKey, ...]: + # Use deterministic ordering so seeded RNGs replay the same sex sequence across runs. + sexes = sorted(sex for sex in profile.sexes() if sex is not None) + if not sexes: + return (None,) + return tuple(sexes) + + def _choose_sex(self, rng: Random) -> SexKey: + if not self._sex_choices: + return None + # Treat provided sex buckets as equally likely until we collect richer priors (future roadmap). + return rng.choice(self._sex_choices) diff --git a/datamimic_ce/domains/common/generators/person_generator.py b/datamimic_ce/domains/common/generators/person_generator.py index 47540828..cfdad00b 100644 --- a/datamimic_ce/domains/common/generators/person_generator.py +++ b/datamimic_ce/domains/common/generators/person_generator.py @@ -7,9 +7,11 @@ from __future__ import annotations +from datetime import datetime from pathlib import Path from random import Random +from datamimic_ce.domains.common.demographics.sampler import DemographicSample, DemographicSampler from datamimic_ce.domains.common.generators.address_generator import AddressGenerator from datamimic_ce.domains.common.literal_generators.academic_title_generator import AcademicTitleGenerator from datamimic_ce.domains.common.literal_generators.birthdate_generator import BirthdateGenerator @@ -42,10 +44,12 @@ def __init__( noble_quota: float = 0.001, academic_title_quota: float = 0.5, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, rng: Random | None = None, ): self._dataset = dataset or "US" self._rng: Random = rng or Random() + self._demographic_sampler = demographic_sampler # Normalize demographic overrides once to keep SPOT and reuse downstream. resolved_config = (demographic_config or DemographicConfig()).with_defaults( default_age_min=min_age, @@ -84,12 +88,12 @@ def __init__( rng=self._derive_rng() if rng is not None else None, ) self._demographic_config = resolved_config - birth_min = self._demographic_config.age_min if self._demographic_config.age_min is not None else min_age - birth_max = self._demographic_config.age_max if self._demographic_config.age_max is not None else max_age + self._birth_min = self._demographic_config.age_min if self._demographic_config.age_min is not None else min_age + self._birth_max = self._demographic_config.age_max if self._demographic_config.age_max is not None else max_age # Clamp birthdate sampling to caller-provided bounds without scattering defaults. self._birthdate_generator = BirthdateGenerator( - min_age=birth_min, - max_age=birth_max, + min_age=self._birth_min, + max_age=self._birth_max, rng=self._derive_rng() if rng is not None else None, ) self._academic_title_generator = AcademicTitleGenerator( @@ -102,11 +106,26 @@ def __init__( noble_quota=noble_quota, rng=self._derive_rng() if rng is not None else None, ) + self._demographic_rng = self._derive_rng() if demographic_sampler is not None and rng is not None else Random() def _derive_rng(self) -> Random: # Spawn child RNGs from the base seed so seeded descriptors replay without entangling independent draws. return Random(self._rng.randrange(2**63)) if isinstance(self._rng, Random) else Random() + def reserve_demographic_sample(self) -> DemographicSample: + if self._demographic_sampler is None: + return DemographicSample(age=None, sex=None, conditions=frozenset()) + age, sex = self._demographic_sampler.sample_age_sex(self._demographic_rng) + clamped_age = max(self._birth_min, min(self._birth_max, age)) + conditions = self._demographic_sampler.sample_conditions(clamped_age, sex, self._demographic_rng) + # Sample once per entity so downstream services share consistent demographics. + return DemographicSample(age=clamped_age, sex=sex, conditions=conditions) + + def generate_birthdate_for_age(self, age: int) -> datetime: + # Dedicated generator keeps demographic birthdates independent from other literal draws. + generator = BirthdateGenerator(min_age=age, max_age=age, rng=self._derive_rng()) + return generator.generate() + @property def gender_generator(self) -> GenderGenerator: return self._gender_generator diff --git a/datamimic_ce/domains/common/literal_generators/data_faker_generator.py b/datamimic_ce/domains/common/literal_generators/data_faker_generator.py index de67c78e..c2ded6d4 100644 --- a/datamimic_ce/domains/common/literal_generators/data_faker_generator.py +++ b/datamimic_ce/domains/common/literal_generators/data_faker_generator.py @@ -36,7 +36,7 @@ def __init__( raise ValueError(f"Faker method '{method}' is not supported") self._faker = Faker(locale) if rng is not None: - # WHY: faker.Faker exposes a dynamic `random` attribute; cast to a protocol so mypy accepts the assignment. + # faker.Faker exposes a dynamic `random` attribute; cast to a protocol so mypy accepts the assignment. faker_with_random = cast(_SupportsRandom, self._faker) faker_with_random.random = rng self._method = method diff --git a/datamimic_ce/domains/common/literal_generators/nobility_title_generator.py b/datamimic_ce/domains/common/literal_generators/nobility_title_generator.py index 1fab655b..ecc40d0c 100644 --- a/datamimic_ce/domains/common/literal_generators/nobility_title_generator.py +++ b/datamimic_ce/domains/common/literal_generators/nobility_title_generator.py @@ -71,7 +71,7 @@ def generate_with_gender(self, gender: str) -> str: values = list(self._female_values) weights = list(self._female_weights) else: - # WHY: Merge available titles for non-binary genders instead of returning None when quota triggers. + # Merge available titles for non-binary genders instead of returning None when quota triggers. values = list(self._male_values) + list(self._female_values) weights = list(self._male_weights) + list(self._female_weights) diff --git a/datamimic_ce/domains/common/models/person.py b/datamimic_ce/domains/common/models/person.py index caa59d97..36d89a6d 100644 --- a/datamimic_ce/domains/common/models/person.py +++ b/datamimic_ce/domains/common/models/person.py @@ -14,6 +14,7 @@ from datetime import datetime from typing import Any +from datamimic_ce.domains.common.demographics.sampler import DemographicSample from datamimic_ce.domains.common.generators.person_generator import PersonGenerator from datamimic_ce.domains.common.models.address import Address from datamimic_ce.domains.domain_core import BaseEntity @@ -30,6 +31,7 @@ class Person(BaseEntity): def __init__(self, person_generator: PersonGenerator): super().__init__() self._person_generator = person_generator + self._demographic_sample: DemographicSample = person_generator.reserve_demographic_sample() @property @property_cache @@ -39,6 +41,13 @@ def gender(self) -> str: Returns: The gender of the person. """ + sample_sex = self._demographic_sample.sex + if sample_sex is not None: + normalized = sample_sex.upper() + if normalized == "F": + return "female" + if normalized == "M": + return "male" return self._person_generator.gender_generator.generate() @property @@ -175,6 +184,9 @@ def birthdate(self) -> datetime: Returns: The birthdate of the person. """ + if self._demographic_sample.age is not None: + # Align birthdate with demographic priors so age property matches sampled intent. + return self._person_generator.generate_birthdate_for_age(self._demographic_sample.age) return self._person_generator.birthdate_generator.generate() @property @@ -185,6 +197,10 @@ def transaction_profile(self) -> str | Mapping[str, float] | None: # Allow downstream finance generators to reuse demographic intent. return self._person_generator.demographic_config.transaction_profile + @property + def demographic_sample(self) -> DemographicSample: + return self._demographic_sample + @property @property_cache def academic_title(self) -> str | None: diff --git a/datamimic_ce/domains/common/services/person_service.py b/datamimic_ce/domains/common/services/person_service.py index e8d4094d..c0fed018 100644 --- a/datamimic_ce/domains/common/services/person_service.py +++ b/datamimic_ce/domains/common/services/person_service.py @@ -7,6 +7,7 @@ from random import Random +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler from datamimic_ce.domains.common.generators.person_generator import PersonGenerator from datamimic_ce.domains.common.models.demographic_config import DemographicConfig from datamimic_ce.domains.common.models.person import Person @@ -27,6 +28,7 @@ def __init__( female_quota: float = 0.5, other_gender_quota: float = 0.0, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, rng: Random | None = None, noble_quota: float = 0.001, academic_title_quota: float = 0.5, @@ -46,6 +48,8 @@ def __init__( female_quota=female_quota, other_gender_quota=other_gender_quota, demographic_config=resolved_config, + # Pass sampler through the service so entity generation can honour population priors. + demographic_sampler=demographic_sampler, rng=rng, # Thread descriptor-level overrides for noble/title quotas into the generator for determinism. noble_quota=noble_quota, diff --git a/datamimic_ce/domains/healthcare/generators/doctor_generator.py b/datamimic_ce/domains/healthcare/generators/doctor_generator.py index 677b4099..3f24709b 100644 --- a/datamimic_ce/domains/healthcare/generators/doctor_generator.py +++ b/datamimic_ce/domains/healthcare/generators/doctor_generator.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from datamimic_ce.domains.common.demographics.sampler import DemographicSampler from datamimic_ce.domains.common.models.demographic_config import DemographicConfig import random @@ -35,6 +36,7 @@ def __init__( dataset: str | None = None, rng: random.Random | None = None, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, ): # normalize dataset to ISO-3166 alpha-2 and keep lookup consistent self._dataset = (dataset or "US").upper() @@ -43,7 +45,11 @@ def __init__( demo = demographic_config if demographic_config is not None else _DC() self._person_generator = PersonGenerator( - dataset=self._dataset, demographic_config=demo, rng=self._rng, min_age=25 + dataset=self._dataset, + demographic_config=demo, + demographic_sampler=demographic_sampler, + rng=self._rng, + min_age=25, ) self._hospital_generator = HospitalGenerator(dataset=self._dataset, rng=self._rng) self._last_specialty: str | None = None diff --git a/datamimic_ce/domains/healthcare/generators/patient_generator.py b/datamimic_ce/domains/healthcare/generators/patient_generator.py index 7fb758ea..0687c89f 100644 --- a/datamimic_ce/domains/healthcare/generators/patient_generator.py +++ b/datamimic_ce/domains/healthcare/generators/patient_generator.py @@ -14,6 +14,7 @@ from pathlib import Path from random import Random +from datamimic_ce.domains.common.demographics.sampler import DemographicSample, DemographicSampler from datamimic_ce.domains.common.generators.person_generator import PersonGenerator from datamimic_ce.domains.common.literal_generators.family_name_generator import FamilyNameGenerator from datamimic_ce.domains.common.literal_generators.given_name_generator import GivenNameGenerator @@ -37,6 +38,7 @@ def __init__( self, dataset: str | None = None, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, rng: Random | None = None, ): self._dataset = dataset or "US" @@ -45,8 +47,10 @@ def __init__( self._person_generator = PersonGenerator( dataset=self._dataset, demographic_config=self._demographic_config, + demographic_sampler=demographic_sampler, rng=self._rng, ) + self._demographic_sampler = demographic_sampler # Fan out deterministic RNG so seeded patient cohorts remain reproducible across dependent literals. self._family_name_generator = FamilyNameGenerator( dataset=self._dataset, @@ -103,12 +107,27 @@ def pick_blood_type(self, *, start_path: str | None = None) -> str: self._last_blood_type = choice return choice - def generate_age_appropriate_conditions(self, age: int) -> list[str]: + def generate_age_appropriate_conditions( + self, age: int, demographic_sample: DemographicSample | None = None + ) -> list[str]: """Generate weighted medical conditions for the given age.""" base_weights = _load_condition_base_weights(self._dataset) include = self._demographic_config.normalized_includes() exclude = self._demographic_config.normalized_excludes() + if demographic_sample is not None and demographic_sample.conditions: + selections = [ + condition + for condition in demographic_sample.conditions + if condition.lower() not in {name.lower() for name in exclude} + ] + normalized_includes = {name.lower(): name for name in include} + present = {name.lower() for name in selections} + for key, original in normalized_includes.items(): + if key not in present and key not in {name.lower() for name in exclude}: + selections.append(original) + # Deterministic ordering keeps tests stable when sampler emits frozensets. + return sorted(selections) if not base_weights: return [] # Apply age multipliers while keeping CSV as the single source of names and base weights. diff --git a/datamimic_ce/domains/healthcare/models/patient.py b/datamimic_ce/domains/healthcare/models/patient.py index ab58975f..f4271c78 100644 --- a/datamimic_ce/domains/healthcare/models/patient.py +++ b/datamimic_ce/domains/healthcare/models/patient.py @@ -260,7 +260,9 @@ def conditions(self) -> list[str]: Returns: A list of medical conditions. """ - return self._patient_generator.generate_age_appropriate_conditions(self.age) + sample = self.person_data.demographic_sample + # Feed sampler-backed conditions first so overrides only adjust instead of resampling. + return self._patient_generator.generate_age_appropriate_conditions(self.age, sample) @property @property_cache diff --git a/datamimic_ce/domains/healthcare/services/doctor_service.py b/datamimic_ce/domains/healthcare/services/doctor_service.py index 0529ca0d..4566a8a8 100644 --- a/datamimic_ce/domains/healthcare/services/doctor_service.py +++ b/datamimic_ce/domains/healthcare/services/doctor_service.py @@ -12,6 +12,7 @@ from random import Random +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler from datamimic_ce.domains.common.models.demographic_config import DemographicConfig from datamimic_ce.domains.domain_core import BaseDomainService from datamimic_ce.domains.healthcare.generators.doctor_generator import DoctorGenerator @@ -29,12 +30,19 @@ def __init__( self, dataset: str | None = None, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, rng: Random | None = None, ) -> None: import random as _r super().__init__( - DoctorGenerator(dataset=dataset, rng=rng or _r.Random(), demographic_config=demographic_config), Doctor + DoctorGenerator( + dataset=dataset, + rng=rng or _r.Random(), + demographic_config=demographic_config, + demographic_sampler=demographic_sampler, + ), + Doctor, ) @staticmethod diff --git a/datamimic_ce/domains/healthcare/services/patient_service.py b/datamimic_ce/domains/healthcare/services/patient_service.py index 68645e95..c722d474 100644 --- a/datamimic_ce/domains/healthcare/services/patient_service.py +++ b/datamimic_ce/domains/healthcare/services/patient_service.py @@ -12,6 +12,7 @@ from random import Random +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler from datamimic_ce.domains.common.models.demographic_config import DemographicConfig from datamimic_ce.domains.domain_core import BaseDomainService from datamimic_ce.domains.healthcare.generators.patient_generator import PatientGenerator @@ -29,11 +30,17 @@ def __init__( self, dataset: str | None = None, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, rng: Random | None = None, ): # Thread demographic and RNG overrides through the service layer. super().__init__( - PatientGenerator(dataset=dataset, demographic_config=demographic_config, rng=rng), + PatientGenerator( + dataset=dataset, + demographic_config=demographic_config, + demographic_sampler=demographic_sampler, + rng=rng, + ), Patient, ) diff --git a/datamimic_ce/domains/public_sector/generators/police_officer_generator.py b/datamimic_ce/domains/public_sector/generators/police_officer_generator.py index d8aebfd4..bb013714 100644 --- a/datamimic_ce/domains/public_sector/generators/police_officer_generator.py +++ b/datamimic_ce/domains/public_sector/generators/police_officer_generator.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from datamimic_ce.domains.common.demographics.sampler import DemographicSampler from datamimic_ce.domains.common.models.demographic_config import DemographicConfig from pathlib import Path @@ -29,6 +30,7 @@ def __init__( dataset: str | None = None, rng: random.Random | None = None, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, ): """Initialize the police officer generator. @@ -41,7 +43,11 @@ def __init__( demo = demographic_config if demographic_config is not None else _DC() self._person_generator = PersonGenerator( - dataset=self._dataset, demographic_config=demo, rng=self._rng, min_age=21 + dataset=self._dataset, + demographic_config=demo, + demographic_sampler=demographic_sampler, + rng=self._rng, + min_age=21, ) # Hand child generators a derived RNG so seeded officers replay consistently across runs. self._address_generator = AddressGenerator( diff --git a/datamimic_ce/domains/public_sector/services/police_officer_service.py b/datamimic_ce/domains/public_sector/services/police_officer_service.py index c31a4b0d..9b94711e 100644 --- a/datamimic_ce/domains/public_sector/services/police_officer_service.py +++ b/datamimic_ce/domains/public_sector/services/police_officer_service.py @@ -12,6 +12,7 @@ from random import Random +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler from datamimic_ce.domains.common.models.demographic_config import DemographicConfig from datamimic_ce.domains.domain_core import BaseDomainService from datamimic_ce.domains.public_sector.generators.police_officer_generator import PoliceOfficerGenerator @@ -29,12 +30,18 @@ def __init__( self, dataset: str | None = None, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, rng: Random | None = None, ): import random as _r super().__init__( - PoliceOfficerGenerator(dataset=dataset, rng=rng or _r.Random(), demographic_config=demographic_config), + PoliceOfficerGenerator( + dataset=dataset, + rng=rng or _r.Random(), + demographic_config=demographic_config, + demographic_sampler=demographic_sampler, + ), PoliceOfficer, ) diff --git a/datamimic_ce/domains/utils/rng_uuid.py b/datamimic_ce/domains/utils/rng_uuid.py index 98212266..81db6b6b 100644 --- a/datamimic_ce/domains/utils/rng_uuid.py +++ b/datamimic_ce/domains/utils/rng_uuid.py @@ -21,7 +21,7 @@ def uuid4_from_random(rng: Random) -> str: """ value = rng.getrandbits(128) - # WHY: Enforce UUID version/variant bits so downstream validators accept the value. + # Enforce UUID version/variant bits so downstream validators accept the value. value &= ~(0xF << 76) value |= 0x4 << 76 # Set version 4 value &= ~(0x3 << 62) diff --git a/datamimic_ce/model/demographics_model.py b/datamimic_ce/model/demographics_model.py new file mode 100644 index 00000000..c2d19810 --- /dev/null +++ b/datamimic_ce/model/demographics_model.py @@ -0,0 +1,23 @@ +"""Pydantic model for XML node.""" + +from __future__ import annotations + +from pydantic import BaseModel, Field, model_validator + +from datamimic_ce.constants.attribute_constants import ATTR_DATASET, ATTR_DIR +from datamimic_ce.model.model_util import ModelUtil + + +class DemographicsModel(BaseModel): + dataset: str = Field(..., alias=ATTR_DATASET) + version: str + directory: str = Field(..., alias=ATTR_DIR) + rng_seed: int | None = Field(None, alias="rngSeed") + + @model_validator(mode="before") + @classmethod + def check_valid_attributes(cls, values: dict): + return ModelUtil.check_valid_attributes( + values=values, + valid_attributes={ATTR_DATASET, "version", ATTR_DIR, "rngSeed"}, + ) diff --git a/datamimic_ce/parsers/demographics_parser.py b/datamimic_ce/parsers/demographics_parser.py new file mode 100644 index 00000000..63fb9e00 --- /dev/null +++ b/datamimic_ce/parsers/demographics_parser.py @@ -0,0 +1,25 @@ +"""Parser for elements.""" + +from __future__ import annotations + +from pathlib import Path +from xml.etree.ElementTree import Element + +from datamimic_ce.constants.element_constants import EL_DEMOGRAPHICS +from datamimic_ce.model.demographics_model import DemographicsModel +from datamimic_ce.parsers.statement_parser import StatementParser +from datamimic_ce.statements.demographics_statement import DemographicsStatement + + +class DemographicsParser(StatementParser): + def __init__(self, element: Element, properties: dict | None): + super().__init__(element, properties, valid_element_tag=EL_DEMOGRAPHICS) + + def parse(self, descriptor_dir: Path) -> DemographicsStatement: + model = self.validate_attributes(DemographicsModel) + directory = Path(model.directory) + if not directory.is_absolute(): + # Resolve relative demographic directories against the descriptor to keep XML portable. + directory = (descriptor_dir / directory).resolve() + model.directory = str(directory) + return DemographicsStatement(model) diff --git a/datamimic_ce/parsers/parser_util.py b/datamimic_ce/parsers/parser_util.py index b34a6468..aabd1ebd 100644 --- a/datamimic_ce/parsers/parser_util.py +++ b/datamimic_ce/parsers/parser_util.py @@ -16,6 +16,7 @@ EL_ARRAY, EL_CONDITION, EL_DATABASE, + EL_DEMOGRAPHICS, EL_ECHO, EL_ELEMENT, EL_ELSE, @@ -97,6 +98,7 @@ def get_valid_sub_elements_set_by_tag(ele_tag: str) -> set | None: EL_ECHO, EL_VARIABLE, EL_GENERATOR, + EL_DEMOGRAPHICS, }, EL_NESTED_KEY: { EL_KEY, @@ -188,6 +190,10 @@ def _get_parser_by_element(element: Element, properties: dict): return ElementParser(element, properties) elif tag == EL_GENERATOR: return GeneratorParser(element, properties) + elif tag == EL_DEMOGRAPHICS: + from datamimic_ce.parsers.demographics_parser import DemographicsParser + + return DemographicsParser(element, properties) else: raise ValueError(f"Cannot get parser for element <{tag}>") diff --git a/datamimic_ce/statements/demographics_statement.py b/datamimic_ce/statements/demographics_statement.py new file mode 100644 index 00000000..fa42f3cc --- /dev/null +++ b/datamimic_ce/statements/demographics_statement.py @@ -0,0 +1,31 @@ +"""Statement object representing the node.""" + +from __future__ import annotations + +from datamimic_ce.model.demographics_model import DemographicsModel +from datamimic_ce.statements.statement import Statement + + +class DemographicsStatement(Statement): + def __init__(self, model: DemographicsModel): + super().__init__(None, None) + self._dataset = model.dataset + self._version = model.version + self._directory = model.directory + self._rng_seed = model.rng_seed + + @property + def dataset(self) -> str: + return self._dataset + + @property + def version(self) -> str: + return self._version + + @property + def directory(self) -> str: + return self._directory + + @property + def rng_seed(self) -> int | None: + return self._rng_seed diff --git a/datamimic_ce/tasks/demographics_task.py b/datamimic_ce/tasks/demographics_task.py new file mode 100644 index 00000000..633e6a67 --- /dev/null +++ b/datamimic_ce/tasks/demographics_task.py @@ -0,0 +1,37 @@ +"""Task that installs demographic profiles into the setup context.""" + +from __future__ import annotations + +from pathlib import Path +from random import Random + +from datamimic_ce.contexts.demographic_context import DemographicContext +from datamimic_ce.contexts.setup_context import SetupContext +from datamimic_ce.domains.common.demographics.loader import load_demographic_profile +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler +from datamimic_ce.domains.common.models.demographic_config import DemographicConfig +from datamimic_ce.statements.demographics_statement import DemographicsStatement +from datamimic_ce.tasks.task import SetupSubTask + + +class DemographicsTask(SetupSubTask): + def __init__(self, statement: DemographicsStatement): + self._statement = statement + + def execute(self, ctx: SetupContext) -> None: + directory = Path(self._statement.directory) + profile = load_demographic_profile(directory, self._statement.dataset, self._statement.version) + sampler = DemographicSampler(profile) + rng = Random(self._statement.rng_seed) if self._statement.rng_seed is not None else Random() + demographic_context = DemographicContext( + profile_id=profile.profile_id, + sampler=sampler, + overrides=DemographicConfig(), + rng=rng, + ) + # Store the context once so every entity derives deterministic child RNGs without hidden globals. + ctx.set_demographic_context(demographic_context) + + @property + def statement(self) -> DemographicsStatement: + return self._statement diff --git a/datamimic_ce/tasks/task_util.py b/datamimic_ce/tasks/task_util.py index 7d1ae1eb..4b570635 100644 --- a/datamimic_ce/tasks/task_util.py +++ b/datamimic_ce/tasks/task_util.py @@ -41,6 +41,7 @@ from datamimic_ce.statements.array_statement import ArrayStatement from datamimic_ce.statements.condition_statement import ConditionStatement from datamimic_ce.statements.database_statement import DatabaseStatement +from datamimic_ce.statements.demographics_statement import DemographicsStatement from datamimic_ce.statements.echo_statement import EchoStatement from datamimic_ce.statements.element_statement import ElementStatement from datamimic_ce.statements.else_if_statement import ElseIfStatement @@ -127,6 +128,10 @@ def get_task_by_statement( from datamimic_ce.tasks.condition_task import ConditionTask return ConditionTask(stmt) + elif isinstance(stmt, DemographicsStatement): + from datamimic_ce.tasks.demographics_task import DemographicsTask + + return DemographicsTask(stmt) elif isinstance(stmt, ElseIfStatement): from datamimic_ce.tasks.else_if_task import ElseIfTask diff --git a/datamimic_ce/tasks/variable_task.py b/datamimic_ce/tasks/variable_task.py index 2260b0a8..a35ff32e 100644 --- a/datamimic_ce/tasks/variable_task.py +++ b/datamimic_ce/tasks/variable_task.py @@ -249,6 +249,7 @@ def _get_entity_generator( entity_class_name, kwargs = StringUtil.parse_constructor_string(entity_name) # Inject dataset if not explicitly provided in constructor kwargs.setdefault("dataset", dataset) + demographic_context = getattr(ctx.root, "demographic_context", None) # Build demographic config + rng from statement attributes when present demo_cfg = None if any( @@ -272,6 +273,13 @@ def _get_entity_generator( conditions_exclude=excludes, ) rng_obj = Random(statement.rng_seed) if statement.rng_seed is not None else None + if demo_cfg is None and demographic_context is not None and demographic_context.overrides is not None: + # Share profile-level defaults when no per-variable overrides are provided. + demo_cfg = demographic_context.overrides + if rng_obj is None and demographic_context is not None: + # Derive entity-level RNGs from the demographics root seed to keep sampling reproducible. + rng_obj = Random(demographic_context.rng.randrange(2**63)) + demographic_sampler = demographic_context.sampler if demographic_context is not None else None # Build from the last parsed VariableTask (self is not accessible in staticmethod); use closure via locals() # Check if entity_class_name contains dots indicating a domain path @@ -318,16 +326,22 @@ def _get_entity_generator( if entity_class_name in entity_mappings: domain_entity_path = entity_mappings[entity_class_name] # Only attach demographic_config/rng for supported services - if demo_cfg is not None and entity_class_name in { + supported_demographic_entities = { "Patient", "Doctor", "PoliceOfficer", "Person", + } + config_enabled_entities = supported_demographic_entities | { "InsurancePolicy", "MedicalDevice", "CreditCard", - }: - kwargs["demographic_config"] = demo_cfg + } + if demo_cfg is not None and entity_class_name in config_enabled_entities: + kwargs.setdefault("demographic_config", demo_cfg) + if demographic_sampler is not None and entity_class_name in supported_demographic_entities: + # Thread pure sampler through to entities that know how to consume demographics. + kwargs.setdefault("demographic_sampler", demographic_sampler) if rng_obj is not None and entity_class_name in { "Patient", "Doctor", diff --git a/docs/demographics.md b/docs/demographics.md new file mode 100644 index 00000000..348fe5fa --- /dev/null +++ b/docs/demographics.md @@ -0,0 +1,69 @@ +# Demographic Profiles + +Demographic profiles provide reusable population priors for age/sex distributions and +condition prevalence. Profiles live as versioned CSV bundles that are loaded at runtime +and exposed through an explicit `DemographicContext` so generators can honour priors +without relying on globals. + +## `.dmgrp.csv` contract + +All demographic files share the `.dmgrp.csv` suffix to simplify discovery and +versioning. A complete profile currently consists of: + +- `age_pyramid.dmgrp.csv` + - Columns: `dataset,version,sex,age_min,age_max,weight` + - Weights normalise to 1.0 per sex bucket (or overall when `sex` is empty). +- `condition_rates.dmgrp.csv` + - Columns: `dataset,version,condition,sex,age_min,age_max,prevalence` + - `prevalence` is expressed as a probability in `[0,1]`. +- `profile_meta.dmgrp.csv` (optional metadata placeholder) + - Columns: `dataset,version,source,notes,checksum` + +Additional `.dmgrp.csv` files are ignored until the corresponding loader support is +implemented, which keeps the bundle forward compatible. + +## XML usage + +Declare demographics inside `` to make the profile and seeded sampler available +to all entities generated within the descriptor: + +```xml + + + + + + + + +``` + +- The loader validates dataset/version columns, ensures age buckets do not overlap and + raises descriptive errors for malformed rows. +- Each entity supported by demographics (`Person`, `Patient`, `Doctor`, + `PoliceOfficer`) receives the sampler together with an entity-scoped RNG derived from + the `` `rngSeed`. +- `DemographicConfig` overrides (`ageMin`, `ageMax`, `conditionsInclude`, + `conditionsExclude`) remain available per variable and win over sampler priors. + +## Determinism and seeds + +The `` node seeds a root RNG. Variable-level `rngSeed` attributes still +work; when omitted, the task derives child RNGs from the demographic context so every +entity consumes disjoint deterministic streams. Loader validation is pure, and the +sampler has no hidden state, which keeps test runs reproducible. + +## Limitations + +- Sex selection defaults to a uniform choice across available sex buckets until richer + marginal distributions are supplied. +- Condition prevalence is modelled as independent Bernoulli trials; comorbidity pairs + are not sampled yet. +- Birthdates outside `[0,100]` trigger warnings so unrealistic priors can be spotted + early. + +## Roadmap + +Pairwise odds ratios will eventually be supported behind a feature flag. Until then, +set `DEMOGRAPHY_PAIRS=false` in the environment to document the current limitation. diff --git a/tests_ce/integration_tests/test_demographics/DE/2023Q4/age_pyramid.dmgrp.csv b/tests_ce/integration_tests/test_demographics/DE/2023Q4/age_pyramid.dmgrp.csv new file mode 100644 index 00000000..5f2a2768 --- /dev/null +++ b/tests_ce/integration_tests/test_demographics/DE/2023Q4/age_pyramid.dmgrp.csv @@ -0,0 +1,9 @@ +dataset,version,sex,age_min,age_max,weight +DE,2023Q4,F,0,17,0.18 +DE,2023Q4,F,18,39,0.28 +DE,2023Q4,F,40,64,0.34 +DE,2023Q4,F,65,100,0.20 +DE,2023Q4,M,0,17,0.19 +DE,2023Q4,M,18,39,0.29 +DE,2023Q4,M,40,64,0.33 +DE,2023Q4,M,65,100,0.19 diff --git a/tests_ce/integration_tests/test_demographics/DE/2023Q4/condition_rates.dmgrp.csv b/tests_ce/integration_tests/test_demographics/DE/2023Q4/condition_rates.dmgrp.csv new file mode 100644 index 00000000..e645f3b5 --- /dev/null +++ b/tests_ce/integration_tests/test_demographics/DE/2023Q4/condition_rates.dmgrp.csv @@ -0,0 +1,5 @@ +dataset,version,condition,sex,age_min,age_max,prevalence +DE,2023Q4,Diabetes,F,65,100,0.19 +DE,2023Q4,Diabetes,M,65,100,0.22 +DE,2023Q4,Hypertension,,40,64,0.27 +DE,2023Q4,Hypertension,,65,100,0.48 diff --git a/tests_ce/integration_tests/test_demographics/DE/2023Q4/profile_meta.dmgrp.csv b/tests_ce/integration_tests/test_demographics/DE/2023Q4/profile_meta.dmgrp.csv new file mode 100644 index 00000000..749a1a4c --- /dev/null +++ b/tests_ce/integration_tests/test_demographics/DE/2023Q4/profile_meta.dmgrp.csv @@ -0,0 +1,2 @@ +dataset,version,source,notes,checksum +DE,2023Q4,example,illustrative-only,abc123 diff --git a/tests_ce/integration_tests/test_demographics/demo.xml b/tests_ce/integration_tests/test_demographics/demo.xml new file mode 100644 index 00000000..748a81bb --- /dev/null +++ b/tests_ce/integration_tests/test_demographics/demo.xml @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/tests_ce/integration_tests/test_demographics/test_demographics.py b/tests_ce/integration_tests/test_demographics/test_demographics.py new file mode 100644 index 00000000..28444ce7 --- /dev/null +++ b/tests_ce/integration_tests/test_demographics/test_demographics.py @@ -0,0 +1,18 @@ +# DATAMIMIC +# Copyright (c) 2023-2025 Rapiddweller Asia Co., Ltd. +# This software is licensed under the MIT License. +# See LICENSE file for the full text of the license. +# For questions and support, contact: info@rapiddweller.com + + +from pathlib import Path + +from datamimic_ce.data_mimic_test import DataMimicTest + + +class TestDemographics: + _test_dir = Path(__file__).resolve().parent + + def test_simple(self): + test_engine = DataMimicTest(test_dir=self._test_dir, filename="demo.xml") + test_engine.test_with_timer() diff --git a/tests_ce/unit_tests/test_generator/test_nobility_title_generator.py b/tests_ce/unit_tests/test_generator/test_nobility_title_generator.py index 11098104..53119c44 100644 --- a/tests_ce/unit_tests/test_generator/test_nobility_title_generator.py +++ b/tests_ce/unit_tests/test_generator/test_nobility_title_generator.py @@ -10,7 +10,7 @@ def test_other_gender_returns_string_when_quota_hits() -> None: generator = NobilityTitleGenerator(dataset="US", noble_quota=1.0, rng=Random(1337)) - # WHY: Ensure non-binary genders still receive a string title when quota triggers. + # Ensure non-binary genders still receive a string title when quota triggers. title = generator.generate_with_gender("other") assert isinstance(title, str) diff --git a/tests_ce/unit_tests/tests_demographics/data/age_pyramid.dmgrp.csv b/tests_ce/unit_tests/tests_demographics/data/age_pyramid.dmgrp.csv new file mode 100644 index 00000000..03111ea4 --- /dev/null +++ b/tests_ce/unit_tests/tests_demographics/data/age_pyramid.dmgrp.csv @@ -0,0 +1,7 @@ +dataset,version,sex,age_min,age_max,weight +TEST,v1,F,0,17,0.2 +TEST,v1,F,18,35,0.3 +TEST,v1,F,36,80,0.5 +TEST,v1,M,0,17,0.25 +TEST,v1,M,18,35,0.35 +TEST,v1,M,36,80,0.4 diff --git a/tests_ce/unit_tests/tests_demographics/data/condition_rates.dmgrp.csv b/tests_ce/unit_tests/tests_demographics/data/condition_rates.dmgrp.csv new file mode 100644 index 00000000..a1c92578 --- /dev/null +++ b/tests_ce/unit_tests/tests_demographics/data/condition_rates.dmgrp.csv @@ -0,0 +1,6 @@ +dataset,version,condition,sex,age_min,age_max,prevalence +TEST,v1,Asthma,,0,17,0.1 +TEST,v1,Asthma,,18,80,0.05 +TEST,v1,Diabetes,F,36,80,0.15 +TEST,v1,Diabetes,M,36,80,0.12 +TEST,v1,Hypertension,,36,80,0.3 diff --git a/tests_ce/unit_tests/tests_demographics/test_distribution_fit.py b/tests_ce/unit_tests/tests_demographics/test_distribution_fit.py new file mode 100644 index 00000000..74b71597 --- /dev/null +++ b/tests_ce/unit_tests/tests_demographics/test_distribution_fit.py @@ -0,0 +1,54 @@ +"""Statistical checks for demographic sampling.""" + +from __future__ import annotations + +import shutil +import sys +from collections import Counter +from pathlib import Path +from random import Random + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +import pytest + +from datamimic_ce.domains.common.demographics.loader import load_demographic_profile +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler + +_test_dir = Path(__file__).resolve().parent + +@pytest.fixture() +def profile_and_sampler(tmp_path: Path): + fixture_dir = Path(_test_dir / "data") + for name in ("age_pyramid.dmgrp.csv", "condition_rates.dmgrp.csv"): + shutil.copy(fixture_dir / name, tmp_path / name) + profile = load_demographic_profile(tmp_path, "TEST", "v1") + return profile, DemographicSampler(profile) + + +def test_age_distribution_matches_profile(profile_and_sampler) -> None: + profile, sampler = profile_and_sampler + rng = Random(2024) + sexes = [sex for sex in profile.sexes() if sex is not None] + counts = {sex: Counter() for sex in sexes} + samples = 50000 + for _ in range(samples): + age, sex = sampler.sample_age_sex(rng) + bands = profile.bands_for_sex(sex) + for band in bands: + if band.contains(age): + counts[sex][(band.age_min, band.age_max)] += 1 + break + chi_square = 0.0 + degrees_of_freedom = 0 + expected_per_sex = samples / max(len(sexes), 1) + for sex in sexes: + bands = profile.bands_for_sex(sex) + degrees_of_freedom += max(len(bands) - 1, 0) + for band in bands: + expected = expected_per_sex * band.weight + observed = counts[sex][(band.age_min, band.age_max)] + chi_square += (observed - expected) ** 2 / expected + # Critical value for df=4, alpha=0.05 is ~9.488 + assert degrees_of_freedom == 4 + assert chi_square <= 9.488 diff --git a/tests_ce/unit_tests/tests_demographics/test_loader_profile.py b/tests_ce/unit_tests/tests_demographics/test_loader_profile.py new file mode 100644 index 00000000..3ec57c5c --- /dev/null +++ b/tests_ce/unit_tests/tests_demographics/test_loader_profile.py @@ -0,0 +1,34 @@ +"""Loader tests for demographic profiles.""" + +from __future__ import annotations + +import shutil +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +import pytest + +from datamimic_ce.domains.common.demographics.loader import load_demographic_profile + +_test_dir = Path(__file__).resolve().parent + +@pytest.fixture() +def profile_dir(tmp_path: Path) -> Path: + fixture_dir = Path(_test_dir/ "data") + for name in ("age_pyramid.dmgrp.csv", "condition_rates.dmgrp.csv"): + shutil.copy(fixture_dir / name, tmp_path / name) + return tmp_path + + +def test_load_profile_normalizes_and_indexes(profile_dir: Path) -> None: + profile = load_demographic_profile(profile_dir, "TEST", "v1") + female_bands = profile.bands_for_sex("F") + male_bands = profile.bands_for_sex("M") + assert len(female_bands) == 3 + assert len(male_bands) == 3 + assert pytest.approx(sum(b.weight for b in female_bands), rel=1e-9) == 1.0 + assert pytest.approx(sum(b.weight for b in male_bands), rel=1e-9) == 1.0 + rates = profile.conditions_for("Hypertension") + assert rates and rates[0].prevalence == pytest.approx(0.3) diff --git a/tests_ce/unit_tests/tests_demographics/test_overrides_precedence.py b/tests_ce/unit_tests/tests_demographics/test_overrides_precedence.py new file mode 100644 index 00000000..004affbc --- /dev/null +++ b/tests_ce/unit_tests/tests_demographics/test_overrides_precedence.py @@ -0,0 +1,45 @@ +"""Tests ensuring demographic overrides trump sampler priors.""" + +from __future__ import annotations + +import shutil +import sys +from pathlib import Path +from random import Random + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from datamimic_ce.domains.common.demographics.loader import load_demographic_profile +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler, DemographicSample +from datamimic_ce.domains.common.models.demographic_config import DemographicConfig +from datamimic_ce.domains.healthcare.generators.patient_generator import PatientGenerator + +_test_dir = Path(__file__).resolve().parent + +def _load_sampler(tmp_path: Path) -> DemographicSampler: + fixture_dir = Path(_test_dir / "data") + for name in ("age_pyramid.dmgrp.csv", "condition_rates.dmgrp.csv"): + shutil.copy(fixture_dir / name, tmp_path / name) + profile = load_demographic_profile(tmp_path, "TEST", "v1") + return DemographicSampler(profile) + + +def test_overrides_adjust_sampled_conditions(tmp_path: Path) -> None: + sampler = _load_sampler(tmp_path) + rng = Random(101) + age, sex = sampler.sample_age_sex(rng) + sampled_conditions = sampler.sample_conditions(age, sex, rng) + demo_sample = DemographicSample(age=age, sex=sex, conditions=sampled_conditions) + overrides = DemographicConfig( + conditions_include=frozenset({"Hypertension"}), + conditions_exclude=frozenset({"Asthma"}), + ) + generator = PatientGenerator( + dataset="US", + demographic_config=overrides, + demographic_sampler=sampler, + rng=Random(2024), + ) + conditions = generator.generate_age_appropriate_conditions(age, demo_sample) + assert "Hypertension" in set(conditions) + assert all(cond.lower() != "asthma" for cond in conditions) diff --git a/tests_ce/unit_tests/tests_demographics/test_sampler_determinism.py b/tests_ce/unit_tests/tests_demographics/test_sampler_determinism.py new file mode 100644 index 00000000..0e41d251 --- /dev/null +++ b/tests_ce/unit_tests/tests_demographics/test_sampler_determinism.py @@ -0,0 +1,46 @@ +"""Determinism tests for the demographic sampler.""" + +from __future__ import annotations + +import shutil +import sys +from pathlib import Path +from random import Random + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +import pytest + +from datamimic_ce.domains.common.demographics.loader import load_demographic_profile +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler + +_test_dir = Path(__file__).resolve().parent + +@pytest.fixture() +def sampler(tmp_path: Path) -> DemographicSampler: + fixture_dir = Path(_test_dir / "data") + for name in ("age_pyramid.dmgrp.csv", "condition_rates.dmgrp.csv"): + shutil.copy(fixture_dir / name, tmp_path / name) + profile = load_demographic_profile(tmp_path, "TEST", "v1") + return DemographicSampler(profile) + + +def test_age_sex_sampling_is_deterministic(sampler: DemographicSampler) -> None: + rng_a = Random(123) + rng_b = Random(123) + sequence_a = [sampler.sample_age_sex(rng_a) for _ in range(10)] + sequence_b = [sampler.sample_age_sex(rng_b) for _ in range(10)] + assert sequence_a == sequence_b + + +def test_condition_sampling_is_deterministic(sampler: DemographicSampler) -> None: + rng_a = Random(987) + rng_b = Random(987) + draws_a = [] + draws_b = [] + for _ in range(10): + age_a, sex_a = sampler.sample_age_sex(rng_a) + draws_a.append(sampler.sample_conditions(age_a, sex_a, rng_a)) + age_b, sex_b = sampler.sample_age_sex(rng_b) + draws_b.append(sampler.sample_conditions(age_b, sex_b, rng_b)) + assert draws_a == draws_b