From 8f14444d000bcbab11ead47c7b7ad87d12262989 Mon Sep 17 00:00:00 2001 From: Alexander Kell Date: Mon, 6 Oct 2025 08:23:16 +0700 Subject: [PATCH] # Demographic Profiles Demographic profiles provide reusable population priors for age/sex distributions and condition prevalence. Profiles live as versioned CSV bundles that are loaded at runtime and exposed through an explicit `DemographicContext` so generators can honour priors without relying on globals. ## `.dmgrp.csv` contract All demographic files share the `.dmgrp.csv` suffix to simplify discovery and versioning. A complete profile currently consists of: - `age_pyramid.dmgrp.csv` - Columns: `dataset,version,sex,age_min,age_max,weight` - Weights must sum to 1.0 per sex bucket (or overall when `sex` is empty). - `condition_rates.dmgrp.csv` - Columns: `dataset,version,condition,sex,age_min,age_max,prevalence` - `prevalence` is expressed as a probability in `[0,1]`. - `profile_meta.dmgrp.csv` (optional metadata placeholder) - Columns: `dataset,version,source,notes,checksum` Additional `.dmgrp.csv` files are ignored until the corresponding loader support is implemented, which keeps the bundle forward compatible. ## XML usage Declare demographics inside `<setup>` to make the profile and seeded sampler available to all entities generated within the descriptor: ```xml ``` - The loader validates dataset/version columns, ensures age buckets do not overlap and raises descriptive errors for malformed rows. - Each entity supported by demographics (`Person`, `Patient`, `Doctor`, `PoliceOfficer`) receives the sampler together with an entity-scoped RNG derived from the `<demographics>` `rngSeed`. - `DemographicConfig` overrides (`ageMin`, `ageMax`, `conditionsInclude`, `conditionsExclude`) remain available per variable and win over sampler priors. ## Determinism and seeds The `<demographics>` node seeds a root RNG. 
Variable-level `rngSeed` attributes still work; when omitted, the task derives child RNGs from the demographic context so every entity consumes disjoint deterministic streams. Loader validation is pure, and the sampler has no hidden state, which keeps test runs reproducible. ## Limitations - Sex selection defaults to a uniform choice across available sex buckets until richer marginal distributions are supplied. - Condition prevalence is modelled as independent Bernoulli trials; comorbidity pairs are not sampled yet. - Birthdates outside `[0,100]` trigger warnings so unrealistic priors can be spotted early. --- datamimic_ce/constants/attribute_constants.py | 1 + datamimic_ce/constants/element_constants.py | 1 + datamimic_ce/contexts/demographic_context.py | 20 ++ datamimic_ce/contexts/setup_context.py | 12 ++ .../domains/common/demographics/__init__.py | 15 ++ .../domains/common/demographics/loader.py | 197 ++++++++++++++++++ .../domains/common/demographics/profile.py | 80 +++++++ .../domains/common/demographics/sampler.py | 118 +++++++++++ .../common/generators/person_generator.py | 27 ++- .../data_faker_generator.py | 2 +- .../nobility_title_generator.py | 2 +- datamimic_ce/domains/common/models/person.py | 16 ++ .../domains/common/services/person_service.py | 4 + .../healthcare/generators/doctor_generator.py | 8 +- .../generators/patient_generator.py | 21 +- .../domains/healthcare/models/patient.py | 4 +- .../healthcare/services/doctor_service.py | 10 +- .../healthcare/services/patient_service.py | 9 +- .../generators/police_officer_generator.py | 8 +- .../services/police_officer_service.py | 9 +- datamimic_ce/domains/utils/rng_uuid.py | 2 +- datamimic_ce/model/demographics_model.py | 23 ++ datamimic_ce/parsers/demographics_parser.py | 25 +++ datamimic_ce/parsers/parser_util.py | 6 + .../statements/demographics_statement.py | 31 +++ datamimic_ce/tasks/demographics_task.py | 37 ++++ datamimic_ce/tasks/task_util.py | 5 + datamimic_ce/tasks/variable_task.py | 20 +- 
docs/demographics.md | 69 ++++++ .../DE/2023Q4/age_pyramid.dmgrp.csv | 9 + .../DE/2023Q4/condition_rates.dmgrp.csv | 5 + .../DE/2023Q4/profile_meta.dmgrp.csv | 2 + .../test_demographics/demo.xml | 8 + .../test_demographics/test_demographics.py | 18 ++ .../test_nobility_title_generator.py | 2 +- .../data/age_pyramid.dmgrp.csv | 7 + .../data/condition_rates.dmgrp.csv | 6 + .../test_distribution_fit.py | 54 +++++ .../tests_demographics/test_loader_profile.py | 34 +++ .../test_overrides_precedence.py | 45 ++++ .../test_sampler_determinism.py | 46 ++++ 41 files changed, 1000 insertions(+), 18 deletions(-) create mode 100644 datamimic_ce/contexts/demographic_context.py create mode 100644 datamimic_ce/domains/common/demographics/__init__.py create mode 100644 datamimic_ce/domains/common/demographics/loader.py create mode 100644 datamimic_ce/domains/common/demographics/profile.py create mode 100644 datamimic_ce/domains/common/demographics/sampler.py create mode 100644 datamimic_ce/model/demographics_model.py create mode 100644 datamimic_ce/parsers/demographics_parser.py create mode 100644 datamimic_ce/statements/demographics_statement.py create mode 100644 datamimic_ce/tasks/demographics_task.py create mode 100644 docs/demographics.md create mode 100644 tests_ce/integration_tests/test_demographics/DE/2023Q4/age_pyramid.dmgrp.csv create mode 100644 tests_ce/integration_tests/test_demographics/DE/2023Q4/condition_rates.dmgrp.csv create mode 100644 tests_ce/integration_tests/test_demographics/DE/2023Q4/profile_meta.dmgrp.csv create mode 100644 tests_ce/integration_tests/test_demographics/demo.xml create mode 100644 tests_ce/integration_tests/test_demographics/test_demographics.py create mode 100644 tests_ce/unit_tests/tests_demographics/data/age_pyramid.dmgrp.csv create mode 100644 tests_ce/unit_tests/tests_demographics/data/condition_rates.dmgrp.csv create mode 100644 tests_ce/unit_tests/tests_demographics/test_distribution_fit.py create mode 100644 
tests_ce/unit_tests/tests_demographics/test_loader_profile.py create mode 100644 tests_ce/unit_tests/tests_demographics/test_overrides_precedence.py create mode 100644 tests_ce/unit_tests/tests_demographics/test_sampler_determinism.py diff --git a/datamimic_ce/constants/attribute_constants.py b/datamimic_ce/constants/attribute_constants.py index 597dbf97..4dd4a6b8 100644 --- a/datamimic_ce/constants/attribute_constants.py +++ b/datamimic_ce/constants/attribute_constants.py @@ -70,6 +70,7 @@ ATTR_DEFAULT_VARIABLE_SUFFIX: Final = "defaultVariableSuffix" ATTR_VARIABLE_PREFIX: Final = "variablePrefix" ATTR_VARIABLE_SUFFIX: Final = "variableSuffix" +ATTR_DIR: Final = "dir" ATTR_STRING: Final = "string" ATTR_BUCKET: Final = "bucket" ATTR_MP_PLATFORM: Final = "mpPlatform" diff --git a/datamimic_ce/constants/element_constants.py b/datamimic_ce/constants/element_constants.py index ee06e00e..2d8fbbfa 100644 --- a/datamimic_ce/constants/element_constants.py +++ b/datamimic_ce/constants/element_constants.py @@ -25,3 +25,4 @@ EL_CONDITION = "condition" EL_ELSE_IF = "else-if" EL_ELSE = "else" +EL_DEMOGRAPHICS = "demographics" diff --git a/datamimic_ce/contexts/demographic_context.py b/datamimic_ce/contexts/demographic_context.py new file mode 100644 index 00000000..aec6f3ae --- /dev/null +++ b/datamimic_ce/contexts/demographic_context.py @@ -0,0 +1,20 @@ +"""Context wiring for demographic samplers.""" + +from __future__ import annotations + +from dataclasses import dataclass +from random import Random + +from datamimic_ce.domains.common.demographics.profile import DemographicProfileId +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler +from datamimic_ce.domains.common.models.demographic_config import DemographicConfig + + +@dataclass(frozen=True) +class DemographicContext: + """Immutable container for the active demographic profile.""" + + profile_id: DemographicProfileId + sampler: DemographicSampler + overrides: DemographicConfig | None + rng: 
Random diff --git a/datamimic_ce/contexts/setup_context.py b/datamimic_ce/contexts/setup_context.py index 9bd8895e..5c9cf7b7 100644 --- a/datamimic_ce/contexts/setup_context.py +++ b/datamimic_ce/contexts/setup_context.py @@ -11,6 +11,7 @@ from datamimic_ce.clients.database_client import Client from datamimic_ce.contexts.context import Context +from datamimic_ce.contexts.demographic_context import DemographicContext from datamimic_ce.converter.converter import Converter from datamimic_ce.converter.custom_converter import CustomConverter from datamimic_ce.domains.domain_core.base_literal_generator import BaseLiteralGenerator @@ -49,6 +50,7 @@ def __init__( generators: dict | None = None, default_source_scripted: bool | None = None, report_logging: bool = True, + demographic_context: DemographicContext | None = None, ): # SetupContext is always its root_context super().__init__(self) @@ -86,6 +88,7 @@ def __init__( self._report_logging = report_logging self._current_seed = current_seed self._task_exporters: dict[str, dict[str, Any]] = {} + self._demographic_context = demographic_context def __deepcopy__(self, memo): """ @@ -125,6 +128,7 @@ def __deepcopy__(self, memo): default_source_scripted=self._default_source_scripted, report_logging=copy.deepcopy(self._report_logging), current_seed=self._current_seed, + demographic_context=copy.deepcopy(self._demographic_context, memo), ) def _deepcopy_clients(self, memo): @@ -226,6 +230,14 @@ def update_with_stmt(self, stmt: SetupStatement): if value is not None: setattr(self, key, value) + @property + def demographic_context(self) -> DemographicContext | None: + return self._demographic_context + + def set_demographic_context(self, context: DemographicContext) -> None: + # Keep demographics explicit on the root context instead of mutable module globals. 
+ self._demographic_context = context + @property def clients(self) -> dict: return self._clients diff --git a/datamimic_ce/domains/common/demographics/__init__.py b/datamimic_ce/domains/common/demographics/__init__.py new file mode 100644 index 00000000..6958999a --- /dev/null +++ b/datamimic_ce/domains/common/demographics/__init__.py @@ -0,0 +1,15 @@ +"""Demographic profile domain package.""" + +from .loader import DemographicProfileError, load_demographic_profile +from .profile import DemographicProfile, DemographicProfileId, normalize_sex +from .sampler import DemographicSample, DemographicSampler + +__all__ = [ + "DemographicProfile", + "DemographicProfileId", + "DemographicProfileError", + "DemographicSampler", + "DemographicSample", + "load_demographic_profile", + "normalize_sex", +] diff --git a/datamimic_ce/domains/common/demographics/loader.py b/datamimic_ce/domains/common/demographics/loader.py new file mode 100644 index 00000000..38294890 --- /dev/null +++ b/datamimic_ce/domains/common/demographics/loader.py @@ -0,0 +1,197 @@ +"""CSV loader for demographic profiles.""" + +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Iterable +from pathlib import Path + +from datamimic_ce.logger import logger +from datamimic_ce.utils.file_util import FileUtil + +from .profile import ( + DemographicAgeBand, + DemographicConditionRate, + DemographicProfile, + DemographicProfileId, + SexKey, + normalize_sex, +) + +_REQUIRED_FILES = { + "age_pyramid.dmgrp.csv", + "condition_rates.dmgrp.csv", +} + + +class DemographicProfileError(ValueError): + """Raised when demographic CSV files are invalid.""" + + +def load_demographic_profile(directory: Path, dataset: str, version: str) -> DemographicProfile: + """Load a demographic profile from the given directory.""" + + base_dir = Path(directory) + if not base_dir.exists(): + raise DemographicProfileError(f"Demographic directory '{base_dir}' does not exist") + + files = 
{path.name: path for path in base_dir.glob("*.dmgrp.csv")} + missing = _REQUIRED_FILES.difference(files) + if missing: + raise DemographicProfileError( + f"Missing demographic files {sorted(missing)} in '{base_dir}'. Ensure required CSVs exist." + ) + + age_bands = _load_age_bands(files["age_pyramid.dmgrp.csv"], dataset, version) + condition_rates = _load_condition_rates(files["condition_rates.dmgrp.csv"], dataset, version) + + profile = DemographicProfile( + profile_id=DemographicProfileId(dataset=dataset, version=version), + age_bands=age_bands, + condition_rates=condition_rates, + ) + return profile + + +def _load_age_bands(file_path: Path, dataset: str, version: str) -> dict[SexKey, tuple[DemographicAgeBand, ...]]: + rows = FileUtil.read_csv_to_dict_list(file_path, separator=",") + grouped: defaultdict[SexKey, list[DemographicAgeBand]] = defaultdict(list) + for idx, row in enumerate(rows, start=2): + _ensure_dataset_version(row, dataset, version, file_path, idx) + sex = normalize_sex(row.get("sex")) + try: + age_min = int(row["age_min"]) + age_max = int(row["age_max"]) + weight = float(row["weight"]) + except (TypeError, ValueError) as exc: + raise DemographicProfileError( + f"Invalid numeric value in '{file_path}' line {idx}: {exc}." + " Expected integers for age_min/age_max and float for weight." + ) from exc + if age_min > age_max: + raise DemographicProfileError( + f"age_min must be <= age_max in '{file_path}' line {idx}: got {age_min}>{age_max}." 
+ ) + if weight < 0: + raise DemographicProfileError(f"weight must be non-negative in '{file_path}' line {idx}: got {weight}.") + grouped[sex].append( + DemographicAgeBand( + sex=sex, + age_min=age_min, + age_max=age_max, + weight=weight, + ) + ) + + normalized: dict[SexKey, tuple[DemographicAgeBand, ...]] = {} + for sex, bands in grouped.items(): + sorted_bands = sorted(bands, key=lambda b: (b.age_min, b.age_max)) + _validate_band_coverage(sorted_bands, file_path, sex) + total_weight = sum(b.weight for b in sorted_bands) + if not _is_close(total_weight, 1.0): + raise DemographicProfileError( + f"Weights must sum to 1.0 per sex in '{file_path}' for sex='{sex or ''}' (sum={total_weight:.6f})." + ) + normalized[sex] = tuple(sorted_bands) + if not normalized: + raise DemographicProfileError(f"No rows parsed from '{file_path}'.") + return normalized + + +def _load_condition_rates( + file_path: Path, dataset: str, version: str +) -> dict[str, tuple[DemographicConditionRate, ...]]: + rows = FileUtil.read_csv_to_dict_list(file_path, separator=",") + grouped: defaultdict[str, list[DemographicConditionRate]] = defaultdict(list) + for idx, row in enumerate(rows, start=2): + _ensure_dataset_version(row, dataset, version, file_path, idx) + condition = (row.get("condition") or "").strip() + if not condition: + raise DemographicProfileError( + f"Condition name missing in '{file_path}' line {idx}. Provide canonical condition labels." + ) + sex = normalize_sex(row.get("sex")) + try: + age_min = int(row["age_min"]) + age_max = int(row["age_max"]) + prevalence = float(row["prevalence"]) + except (TypeError, ValueError) as exc: + raise DemographicProfileError( + f"Invalid numeric value in '{file_path}' line {idx}: {exc}." + " Expected integers for age_min/age_max and float for prevalence." + ) from exc + if age_min > age_max: + raise DemographicProfileError( + f"age_min must be <= age_max in '{file_path}' line {idx}: got {age_min}>{age_max}." 
+ ) + if not 0.0 <= prevalence <= 1.0: + raise DemographicProfileError( + f"prevalence must be within [0,1] in '{file_path}' line {idx}: got {prevalence}." + ) + grouped[condition].append( + DemographicConditionRate( + condition=condition, + sex=sex, + age_min=age_min, + age_max=age_max, + prevalence=prevalence, + ) + ) + + normalized: dict[str, tuple[DemographicConditionRate, ...]] = {} + for condition, rates in grouped.items(): + normalized[condition] = tuple(sorted(rates, key=_condition_sort_key)) + return normalized + + +def _condition_sort_key(rate: DemographicConditionRate) -> tuple[int, int, int]: + # Stable ordering ensures deterministic sampling and makes tests reproducible. + return (0 if rate.sex is not None else 1, rate.age_min, rate.age_max) + + +def _ensure_dataset_version( + row: dict, + dataset: str, + version: str, + file_path: Path, + line_number: int, +) -> None: + if (row.get("dataset") or "").strip() != dataset: + raise DemographicProfileError(f"Dataset mismatch in '{file_path}' line {line_number}: expected '{dataset}'.") + if (row.get("version") or "").strip() != version: + raise DemographicProfileError(f"Version mismatch in '{file_path}' line {line_number}: expected '{version}'.") + + +def _validate_band_coverage(bands: Iterable[DemographicAgeBand], file_path: Path, sex: SexKey) -> None: + sorted_bands = list(bands) + previous = None + for band in sorted_bands: + if previous and band.age_min <= previous.age_max: + raise DemographicProfileError( + f"Overlapping age bands for sex='{sex or ''}' in '{file_path}':" + f" [{previous.age_min},{previous.age_max}] overlaps [{band.age_min},{band.age_max}]." 
+ ) + if previous and band.age_min > previous.age_max + 1: + logger.warning( + "Gap detected between age bands for sex='%s' in '%s': [%s,%s] -> [%s,%s]", # type: ignore[str-format] + sex or "", + file_path, + previous.age_min, + previous.age_max, + band.age_min, + band.age_max, + ) + previous = band + if sorted_bands: + if sorted_bands[0].age_min > 0: + logger.warning( + "Age coverage for sex='%s' in '%s' starts at %s (>0).", sex or "", file_path, sorted_bands[0].age_min + ) + if sorted_bands[-1].age_max < 100: + logger.warning( + "Age coverage for sex='%s' in '%s' ends at %s (<100).", sex or "", file_path, sorted_bands[-1].age_max + ) + + +def _is_close(value: float, target: float, *, tolerance: float = 1e-6) -> bool: + return abs(value - target) <= tolerance diff --git a/datamimic_ce/domains/common/demographics/profile.py b/datamimic_ce/domains/common/demographics/profile.py new file mode 100644 index 00000000..c0f7d0b2 --- /dev/null +++ b/datamimic_ce/domains/common/demographics/profile.py @@ -0,0 +1,80 @@ +"""Demographic profile domain objects.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass + +SexKey = str | None + + +def normalize_sex(sex: str | None) -> SexKey: + """Normalize raw sex codes coming from CSV files.""" + + if sex is None: + return None + value = sex.strip().upper() + return value or None + + +@dataclass(frozen=True) +class DemographicProfileId: + """Stable identifier for a demographic profile dataset.""" + + dataset: str + version: str + + +@dataclass(frozen=True) +class DemographicAgeBand: + """Age distribution entry scoped to a sex bucket.""" + + sex: SexKey + age_min: int + age_max: int + weight: float + + def contains(self, age: int) -> bool: + return self.age_min <= age <= self.age_max + + +@dataclass(frozen=True) +class DemographicConditionRate: + """Condition prevalence entry used for Bernoulli sampling.""" + + condition: str + sex: SexKey + age_min: int + 
age_max: int + prevalence: float + + def matches(self, *, age: int, sex: SexKey) -> bool: + return self.age_min <= age <= self.age_max and (self.sex is None or self.sex == sex) + + +@dataclass(frozen=True) +class DemographicProfile: + """Collection of demographic priors used by samplers.""" + + profile_id: DemographicProfileId + age_bands: Mapping[SexKey, tuple[DemographicAgeBand, ...]] + condition_rates: Mapping[str, tuple[DemographicConditionRate, ...]] + + def bands_for_sex(self, sex: SexKey) -> tuple[DemographicAgeBand, ...]: + """Return ordered bands for a given sex, falling back to combined data.""" + + normalized = normalize_sex(sex) + if normalized in self.age_bands: + return self.age_bands[normalized] + # Combined (sex-less) priors should backstop missing sex specific data without failing generation. + return self.age_bands.get(None, ()) + + def conditions_for(self, condition: str) -> tuple[DemographicConditionRate, ...]: + """Return ordered prevalence rows for a condition.""" + + return self.condition_rates.get(condition, ()) + + def sexes(self) -> Sequence[SexKey]: + """Expose the known sex buckets for downstream logic.""" + + return tuple(self.age_bands.keys()) diff --git a/datamimic_ce/domains/common/demographics/sampler.py b/datamimic_ce/domains/common/demographics/sampler.py new file mode 100644 index 00000000..4ad014a0 --- /dev/null +++ b/datamimic_ce/domains/common/demographics/sampler.py @@ -0,0 +1,118 @@ +"""Pure demographic sampler built on top of demographic profiles.""" + +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import dataclass +from random import Random + +from .profile import ( + DemographicAgeBand, + DemographicConditionRate, + DemographicProfile, + DemographicProfileId, + SexKey, + normalize_sex, +) + + +@dataclass(frozen=True) +class DemographicSample: + """Sampled demographic attributes for a single entity.""" + + age: int | None + sex: SexKey + conditions: frozenset[str] + + +class 
DemographicSampler: + """Pure sampler that draws ages, sexes and conditions from a profile.""" + + def __init__(self, profile: DemographicProfile): + self._profile = profile + self._age_cdf = self._build_age_cdfs(profile) + self._sex_choices = self._build_sex_choices(profile) + + @property + def profile_id(self) -> DemographicProfileId: + return self._profile.profile_id + + def sample_age_sex(self, rng: Random) -> tuple[int, SexKey]: + """Sample an age band and sex using cumulative weights.""" + + sex = self._choose_sex(rng) + bands = self._age_cdf.get(sex) or self._age_cdf.get(None) + if not bands: + raise ValueError("No age bands available for demographic sampling") + pick = rng.random() + for threshold, band in bands: + if pick <= threshold: + age = rng.randint(band.age_min, band.age_max) + return age, sex + # Numerical precision might leave pick slightly above the last threshold. + last_band = bands[-1][1] + # Guarantee a return even when float rounding pushes us past the final threshold. + return rng.randint(last_band.age_min, last_band.age_max), sex + + def sample_conditions(self, age: int, sex: SexKey, rng: Random) -> frozenset[str]: + """Sample independent conditions based on prevalence tables.""" + + selected: set[str] = set() + normalized_sex = normalize_sex(sex) + for condition, rates in self._profile.condition_rates.items(): + rate = self._select_rate(rates, age, normalized_sex) + if rate is None: + continue + if rng.random() <= rate.prevalence: + selected.add(condition) + return frozenset(selected) + + def _select_rate( + self, + rates: Iterable[DemographicConditionRate], + age: int, + sex: SexKey, + ) -> DemographicConditionRate | None: + # Prefer exact sex match, otherwise fall back to combined entries. 
+ exact_match = None + fallback = None + for rate in rates: + if not rate.matches(age=age, sex=sex): + continue + if rate.sex is None: + fallback = rate + else: + exact_match = rate + break + return exact_match or fallback + + def _build_age_cdfs(self, profile: DemographicProfile) -> dict[SexKey, list[tuple[float, DemographicAgeBand]]]: + cdfs: dict[SexKey, list[tuple[float, DemographicAgeBand]]] = {} + for sex, bands in profile.age_bands.items(): + cumulative = 0.0 + cdf: list[tuple[float, DemographicAgeBand]] = [] + for band in bands: + cumulative += band.weight + cdf.append((cumulative, band)) + if not cdf: + continue + # Normalize trailing rounding differences so cumulative ends at 1.0. + final_threshold, last_band = cdf[-1] + if abs(final_threshold - 1.0) > 1e-9: + adjust = final_threshold + cdf = [(threshold / adjust, band) for threshold, band in cdf] + cdfs[sex] = cdf + return cdfs + + def _build_sex_choices(self, profile: DemographicProfile) -> tuple[SexKey, ...]: + # Use deterministic ordering so seeded RNGs replay the same sex sequence across runs. + sexes = sorted(sex for sex in profile.sexes() if sex is not None) + if not sexes: + return (None,) + return tuple(sexes) + + def _choose_sex(self, rng: Random) -> SexKey: + if not self._sex_choices: + return None + # Treat provided sex buckets as equally likely until we collect richer priors (future roadmap). 
+ return rng.choice(self._sex_choices) diff --git a/datamimic_ce/domains/common/generators/person_generator.py b/datamimic_ce/domains/common/generators/person_generator.py index 47540828..cfdad00b 100644 --- a/datamimic_ce/domains/common/generators/person_generator.py +++ b/datamimic_ce/domains/common/generators/person_generator.py @@ -7,9 +7,11 @@ from __future__ import annotations +from datetime import datetime from pathlib import Path from random import Random +from datamimic_ce.domains.common.demographics.sampler import DemographicSample, DemographicSampler from datamimic_ce.domains.common.generators.address_generator import AddressGenerator from datamimic_ce.domains.common.literal_generators.academic_title_generator import AcademicTitleGenerator from datamimic_ce.domains.common.literal_generators.birthdate_generator import BirthdateGenerator @@ -42,10 +44,12 @@ def __init__( noble_quota: float = 0.001, academic_title_quota: float = 0.5, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, rng: Random | None = None, ): self._dataset = dataset or "US" self._rng: Random = rng or Random() + self._demographic_sampler = demographic_sampler # Normalize demographic overrides once to keep SPOT and reuse downstream. 
resolved_config = (demographic_config or DemographicConfig()).with_defaults( default_age_min=min_age, @@ -84,12 +88,12 @@ def __init__( rng=self._derive_rng() if rng is not None else None, ) self._demographic_config = resolved_config - birth_min = self._demographic_config.age_min if self._demographic_config.age_min is not None else min_age - birth_max = self._demographic_config.age_max if self._demographic_config.age_max is not None else max_age + self._birth_min = self._demographic_config.age_min if self._demographic_config.age_min is not None else min_age + self._birth_max = self._demographic_config.age_max if self._demographic_config.age_max is not None else max_age # Clamp birthdate sampling to caller-provided bounds without scattering defaults. self._birthdate_generator = BirthdateGenerator( - min_age=birth_min, - max_age=birth_max, + min_age=self._birth_min, + max_age=self._birth_max, rng=self._derive_rng() if rng is not None else None, ) self._academic_title_generator = AcademicTitleGenerator( @@ -102,11 +106,26 @@ def __init__( noble_quota=noble_quota, rng=self._derive_rng() if rng is not None else None, ) + self._demographic_rng = self._derive_rng() if demographic_sampler is not None and rng is not None else Random() def _derive_rng(self) -> Random: # Spawn child RNGs from the base seed so seeded descriptors replay without entangling independent draws. return Random(self._rng.randrange(2**63)) if isinstance(self._rng, Random) else Random() + def reserve_demographic_sample(self) -> DemographicSample: + if self._demographic_sampler is None: + return DemographicSample(age=None, sex=None, conditions=frozenset()) + age, sex = self._demographic_sampler.sample_age_sex(self._demographic_rng) + clamped_age = max(self._birth_min, min(self._birth_max, age)) + conditions = self._demographic_sampler.sample_conditions(clamped_age, sex, self._demographic_rng) + # Sample once per entity so downstream services share consistent demographics. 
+ return DemographicSample(age=clamped_age, sex=sex, conditions=conditions) + + def generate_birthdate_for_age(self, age: int) -> datetime: + # Dedicated generator keeps demographic birthdates independent from other literal draws. + generator = BirthdateGenerator(min_age=age, max_age=age, rng=self._derive_rng()) + return generator.generate() + @property def gender_generator(self) -> GenderGenerator: return self._gender_generator diff --git a/datamimic_ce/domains/common/literal_generators/data_faker_generator.py b/datamimic_ce/domains/common/literal_generators/data_faker_generator.py index de67c78e..c2ded6d4 100644 --- a/datamimic_ce/domains/common/literal_generators/data_faker_generator.py +++ b/datamimic_ce/domains/common/literal_generators/data_faker_generator.py @@ -36,7 +36,7 @@ def __init__( raise ValueError(f"Faker method '{method}' is not supported") self._faker = Faker(locale) if rng is not None: - # WHY: faker.Faker exposes a dynamic `random` attribute; cast to a protocol so mypy accepts the assignment. + # faker.Faker exposes a dynamic `random` attribute; cast to a protocol so mypy accepts the assignment. faker_with_random = cast(_SupportsRandom, self._faker) faker_with_random.random = rng self._method = method diff --git a/datamimic_ce/domains/common/literal_generators/nobility_title_generator.py b/datamimic_ce/domains/common/literal_generators/nobility_title_generator.py index 1fab655b..ecc40d0c 100644 --- a/datamimic_ce/domains/common/literal_generators/nobility_title_generator.py +++ b/datamimic_ce/domains/common/literal_generators/nobility_title_generator.py @@ -71,7 +71,7 @@ def generate_with_gender(self, gender: str) -> str: values = list(self._female_values) weights = list(self._female_weights) else: - # WHY: Merge available titles for non-binary genders instead of returning None when quota triggers. + # Merge available titles for non-binary genders instead of returning None when quota triggers. 
values = list(self._male_values) + list(self._female_values) weights = list(self._male_weights) + list(self._female_weights) diff --git a/datamimic_ce/domains/common/models/person.py b/datamimic_ce/domains/common/models/person.py index caa59d97..36d89a6d 100644 --- a/datamimic_ce/domains/common/models/person.py +++ b/datamimic_ce/domains/common/models/person.py @@ -14,6 +14,7 @@ from datetime import datetime from typing import Any +from datamimic_ce.domains.common.demographics.sampler import DemographicSample from datamimic_ce.domains.common.generators.person_generator import PersonGenerator from datamimic_ce.domains.common.models.address import Address from datamimic_ce.domains.domain_core import BaseEntity @@ -30,6 +31,7 @@ class Person(BaseEntity): def __init__(self, person_generator: PersonGenerator): super().__init__() self._person_generator = person_generator + self._demographic_sample: DemographicSample = person_generator.reserve_demographic_sample() @property @property_cache @@ -39,6 +41,13 @@ def gender(self) -> str: Returns: The gender of the person. """ + sample_sex = self._demographic_sample.sex + if sample_sex is not None: + normalized = sample_sex.upper() + if normalized == "F": + return "female" + if normalized == "M": + return "male" return self._person_generator.gender_generator.generate() @property @@ -175,6 +184,9 @@ def birthdate(self) -> datetime: Returns: The birthdate of the person. """ + if self._demographic_sample.age is not None: + # Align birthdate with demographic priors so age property matches sampled intent. + return self._person_generator.generate_birthdate_for_age(self._demographic_sample.age) return self._person_generator.birthdate_generator.generate() @property @@ -185,6 +197,10 @@ def transaction_profile(self) -> str | Mapping[str, float] | None: # Allow downstream finance generators to reuse demographic intent. 
return self._person_generator.demographic_config.transaction_profile + @property + def demographic_sample(self) -> DemographicSample: + return self._demographic_sample + @property @property_cache def academic_title(self) -> str | None: diff --git a/datamimic_ce/domains/common/services/person_service.py b/datamimic_ce/domains/common/services/person_service.py index e8d4094d..c0fed018 100644 --- a/datamimic_ce/domains/common/services/person_service.py +++ b/datamimic_ce/domains/common/services/person_service.py @@ -7,6 +7,7 @@ from random import Random +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler from datamimic_ce.domains.common.generators.person_generator import PersonGenerator from datamimic_ce.domains.common.models.demographic_config import DemographicConfig from datamimic_ce.domains.common.models.person import Person @@ -27,6 +28,7 @@ def __init__( female_quota: float = 0.5, other_gender_quota: float = 0.0, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, rng: Random | None = None, noble_quota: float = 0.001, academic_title_quota: float = 0.5, @@ -46,6 +48,8 @@ def __init__( female_quota=female_quota, other_gender_quota=other_gender_quota, demographic_config=resolved_config, + # Pass sampler through the service so entity generation can honour population priors. + demographic_sampler=demographic_sampler, rng=rng, # Thread descriptor-level overrides for noble/title quotas into the generator for determinism. 
noble_quota=noble_quota, diff --git a/datamimic_ce/domains/healthcare/generators/doctor_generator.py b/datamimic_ce/domains/healthcare/generators/doctor_generator.py index 677b4099..3f24709b 100644 --- a/datamimic_ce/domains/healthcare/generators/doctor_generator.py +++ b/datamimic_ce/domains/healthcare/generators/doctor_generator.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from datamimic_ce.domains.common.demographics.sampler import DemographicSampler from datamimic_ce.domains.common.models.demographic_config import DemographicConfig import random @@ -35,6 +36,7 @@ def __init__( dataset: str | None = None, rng: random.Random | None = None, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, ): # normalize dataset to ISO-3166 alpha-2 and keep lookup consistent self._dataset = (dataset or "US").upper() @@ -43,7 +45,11 @@ def __init__( demo = demographic_config if demographic_config is not None else _DC() self._person_generator = PersonGenerator( - dataset=self._dataset, demographic_config=demo, rng=self._rng, min_age=25 + dataset=self._dataset, + demographic_config=demo, + demographic_sampler=demographic_sampler, + rng=self._rng, + min_age=25, ) self._hospital_generator = HospitalGenerator(dataset=self._dataset, rng=self._rng) self._last_specialty: str | None = None diff --git a/datamimic_ce/domains/healthcare/generators/patient_generator.py b/datamimic_ce/domains/healthcare/generators/patient_generator.py index 7fb758ea..0687c89f 100644 --- a/datamimic_ce/domains/healthcare/generators/patient_generator.py +++ b/datamimic_ce/domains/healthcare/generators/patient_generator.py @@ -14,6 +14,7 @@ from pathlib import Path from random import Random +from datamimic_ce.domains.common.demographics.sampler import DemographicSample, DemographicSampler from datamimic_ce.domains.common.generators.person_generator import PersonGenerator from 
datamimic_ce.domains.common.literal_generators.family_name_generator import FamilyNameGenerator from datamimic_ce.domains.common.literal_generators.given_name_generator import GivenNameGenerator @@ -37,6 +38,7 @@ def __init__( self, dataset: str | None = None, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, rng: Random | None = None, ): self._dataset = dataset or "US" @@ -45,8 +47,10 @@ def __init__( self._person_generator = PersonGenerator( dataset=self._dataset, demographic_config=self._demographic_config, + demographic_sampler=demographic_sampler, rng=self._rng, ) + self._demographic_sampler = demographic_sampler # Fan out deterministic RNG so seeded patient cohorts remain reproducible across dependent literals. self._family_name_generator = FamilyNameGenerator( dataset=self._dataset, @@ -103,12 +107,27 @@ def pick_blood_type(self, *, start_path: str | None = None) -> str: self._last_blood_type = choice return choice - def generate_age_appropriate_conditions(self, age: int) -> list[str]: + def generate_age_appropriate_conditions( + self, age: int, demographic_sample: DemographicSample | None = None + ) -> list[str]: """Generate weighted medical conditions for the given age.""" base_weights = _load_condition_base_weights(self._dataset) include = self._demographic_config.normalized_includes() exclude = self._demographic_config.normalized_excludes() + if demographic_sample is not None and demographic_sample.conditions: + selections = [ + condition + for condition in demographic_sample.conditions + if condition.lower() not in {name.lower() for name in exclude} + ] + normalized_includes = {name.lower(): name for name in include} + present = {name.lower() for name in selections} + for key, original in normalized_includes.items(): + if key not in present and key not in {name.lower() for name in exclude}: + selections.append(original) + # Deterministic ordering keeps tests stable when sampler emits 
frozensets. + return sorted(selections) if not base_weights: return [] # Apply age multipliers while keeping CSV as the single source of names and base weights. diff --git a/datamimic_ce/domains/healthcare/models/patient.py b/datamimic_ce/domains/healthcare/models/patient.py index ab58975f..f4271c78 100644 --- a/datamimic_ce/domains/healthcare/models/patient.py +++ b/datamimic_ce/domains/healthcare/models/patient.py @@ -260,7 +260,9 @@ def conditions(self) -> list[str]: Returns: A list of medical conditions. """ - return self._patient_generator.generate_age_appropriate_conditions(self.age) + sample = self.person_data.demographic_sample + # Feed sampler-backed conditions first so overrides only adjust instead of resampling. + return self._patient_generator.generate_age_appropriate_conditions(self.age, sample) @property @property_cache diff --git a/datamimic_ce/domains/healthcare/services/doctor_service.py b/datamimic_ce/domains/healthcare/services/doctor_service.py index 0529ca0d..4566a8a8 100644 --- a/datamimic_ce/domains/healthcare/services/doctor_service.py +++ b/datamimic_ce/domains/healthcare/services/doctor_service.py @@ -12,6 +12,7 @@ from random import Random +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler from datamimic_ce.domains.common.models.demographic_config import DemographicConfig from datamimic_ce.domains.domain_core import BaseDomainService from datamimic_ce.domains.healthcare.generators.doctor_generator import DoctorGenerator @@ -29,12 +30,19 @@ def __init__( self, dataset: str | None = None, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, rng: Random | None = None, ) -> None: import random as _r super().__init__( - DoctorGenerator(dataset=dataset, rng=rng or _r.Random(), demographic_config=demographic_config), Doctor + DoctorGenerator( + dataset=dataset, + rng=rng or _r.Random(), + demographic_config=demographic_config, + 
demographic_sampler=demographic_sampler, + ), + Doctor, ) @staticmethod diff --git a/datamimic_ce/domains/healthcare/services/patient_service.py b/datamimic_ce/domains/healthcare/services/patient_service.py index 68645e95..c722d474 100644 --- a/datamimic_ce/domains/healthcare/services/patient_service.py +++ b/datamimic_ce/domains/healthcare/services/patient_service.py @@ -12,6 +12,7 @@ from random import Random +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler from datamimic_ce.domains.common.models.demographic_config import DemographicConfig from datamimic_ce.domains.domain_core import BaseDomainService from datamimic_ce.domains.healthcare.generators.patient_generator import PatientGenerator @@ -29,11 +30,17 @@ def __init__( self, dataset: str | None = None, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, rng: Random | None = None, ): # Thread demographic and RNG overrides through the service layer. 
super().__init__( - PatientGenerator(dataset=dataset, demographic_config=demographic_config, rng=rng), + PatientGenerator( + dataset=dataset, + demographic_config=demographic_config, + demographic_sampler=demographic_sampler, + rng=rng, + ), Patient, ) diff --git a/datamimic_ce/domains/public_sector/generators/police_officer_generator.py b/datamimic_ce/domains/public_sector/generators/police_officer_generator.py index d8aebfd4..bb013714 100644 --- a/datamimic_ce/domains/public_sector/generators/police_officer_generator.py +++ b/datamimic_ce/domains/public_sector/generators/police_officer_generator.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from datamimic_ce.domains.common.demographics.sampler import DemographicSampler from datamimic_ce.domains.common.models.demographic_config import DemographicConfig from pathlib import Path @@ -29,6 +30,7 @@ def __init__( dataset: str | None = None, rng: random.Random | None = None, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, ): """Initialize the police officer generator. @@ -41,7 +43,11 @@ def __init__( demo = demographic_config if demographic_config is not None else _DC() self._person_generator = PersonGenerator( - dataset=self._dataset, demographic_config=demo, rng=self._rng, min_age=21 + dataset=self._dataset, + demographic_config=demo, + demographic_sampler=demographic_sampler, + rng=self._rng, + min_age=21, ) # Hand child generators a derived RNG so seeded officers replay consistently across runs. 
self._address_generator = AddressGenerator( diff --git a/datamimic_ce/domains/public_sector/services/police_officer_service.py b/datamimic_ce/domains/public_sector/services/police_officer_service.py index c31a4b0d..9b94711e 100644 --- a/datamimic_ce/domains/public_sector/services/police_officer_service.py +++ b/datamimic_ce/domains/public_sector/services/police_officer_service.py @@ -12,6 +12,7 @@ from random import Random +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler from datamimic_ce.domains.common.models.demographic_config import DemographicConfig from datamimic_ce.domains.domain_core import BaseDomainService from datamimic_ce.domains.public_sector.generators.police_officer_generator import PoliceOfficerGenerator @@ -29,12 +30,18 @@ def __init__( self, dataset: str | None = None, demographic_config: DemographicConfig | None = None, + demographic_sampler: DemographicSampler | None = None, rng: Random | None = None, ): import random as _r super().__init__( - PoliceOfficerGenerator(dataset=dataset, rng=rng or _r.Random(), demographic_config=demographic_config), + PoliceOfficerGenerator( + dataset=dataset, + rng=rng or _r.Random(), + demographic_config=demographic_config, + demographic_sampler=demographic_sampler, + ), PoliceOfficer, ) diff --git a/datamimic_ce/domains/utils/rng_uuid.py b/datamimic_ce/domains/utils/rng_uuid.py index 98212266..81db6b6b 100644 --- a/datamimic_ce/domains/utils/rng_uuid.py +++ b/datamimic_ce/domains/utils/rng_uuid.py @@ -21,7 +21,7 @@ def uuid4_from_random(rng: Random) -> str: """ value = rng.getrandbits(128) - # WHY: Enforce UUID version/variant bits so downstream validators accept the value. + # Enforce UUID version/variant bits so downstream validators accept the value. 
value &= ~(0xF << 76) value |= 0x4 << 76 # Set version 4 value &= ~(0x3 << 62) diff --git a/datamimic_ce/model/demographics_model.py b/datamimic_ce/model/demographics_model.py new file mode 100644 index 00000000..c2d19810 --- /dev/null +++ b/datamimic_ce/model/demographics_model.py @@ -0,0 +1,23 @@ +"""Pydantic model for XML node.""" + +from __future__ import annotations + +from pydantic import BaseModel, Field, model_validator + +from datamimic_ce.constants.attribute_constants import ATTR_DATASET, ATTR_DIR +from datamimic_ce.model.model_util import ModelUtil + + +class DemographicsModel(BaseModel): + dataset: str = Field(..., alias=ATTR_DATASET) + version: str + directory: str = Field(..., alias=ATTR_DIR) + rng_seed: int | None = Field(None, alias="rngSeed") + + @model_validator(mode="before") + @classmethod + def check_valid_attributes(cls, values: dict): + return ModelUtil.check_valid_attributes( + values=values, + valid_attributes={ATTR_DATASET, "version", ATTR_DIR, "rngSeed"}, + ) diff --git a/datamimic_ce/parsers/demographics_parser.py b/datamimic_ce/parsers/demographics_parser.py new file mode 100644 index 00000000..63fb9e00 --- /dev/null +++ b/datamimic_ce/parsers/demographics_parser.py @@ -0,0 +1,25 @@ +"""Parser for elements.""" + +from __future__ import annotations + +from pathlib import Path +from xml.etree.ElementTree import Element + +from datamimic_ce.constants.element_constants import EL_DEMOGRAPHICS +from datamimic_ce.model.demographics_model import DemographicsModel +from datamimic_ce.parsers.statement_parser import StatementParser +from datamimic_ce.statements.demographics_statement import DemographicsStatement + + +class DemographicsParser(StatementParser): + def __init__(self, element: Element, properties: dict | None): + super().__init__(element, properties, valid_element_tag=EL_DEMOGRAPHICS) + + def parse(self, descriptor_dir: Path) -> DemographicsStatement: + model = self.validate_attributes(DemographicsModel) + directory = 
Path(model.directory) + if not directory.is_absolute(): + # Resolve relative demographic directories against the descriptor to keep XML portable. + directory = (descriptor_dir / directory).resolve() + model.directory = str(directory) + return DemographicsStatement(model) diff --git a/datamimic_ce/parsers/parser_util.py b/datamimic_ce/parsers/parser_util.py index b34a6468..aabd1ebd 100644 --- a/datamimic_ce/parsers/parser_util.py +++ b/datamimic_ce/parsers/parser_util.py @@ -16,6 +16,7 @@ EL_ARRAY, EL_CONDITION, EL_DATABASE, + EL_DEMOGRAPHICS, EL_ECHO, EL_ELEMENT, EL_ELSE, @@ -97,6 +98,7 @@ def get_valid_sub_elements_set_by_tag(ele_tag: str) -> set | None: EL_ECHO, EL_VARIABLE, EL_GENERATOR, + EL_DEMOGRAPHICS, }, EL_NESTED_KEY: { EL_KEY, @@ -188,6 +190,10 @@ def _get_parser_by_element(element: Element, properties: dict): return ElementParser(element, properties) elif tag == EL_GENERATOR: return GeneratorParser(element, properties) + elif tag == EL_DEMOGRAPHICS: + from datamimic_ce.parsers.demographics_parser import DemographicsParser + + return DemographicsParser(element, properties) else: raise ValueError(f"Cannot get parser for element <{tag}>") diff --git a/datamimic_ce/statements/demographics_statement.py b/datamimic_ce/statements/demographics_statement.py new file mode 100644 index 00000000..fa42f3cc --- /dev/null +++ b/datamimic_ce/statements/demographics_statement.py @@ -0,0 +1,31 @@ +"""Statement object representing the node.""" + +from __future__ import annotations + +from datamimic_ce.model.demographics_model import DemographicsModel +from datamimic_ce.statements.statement import Statement + + +class DemographicsStatement(Statement): + def __init__(self, model: DemographicsModel): + super().__init__(None, None) + self._dataset = model.dataset + self._version = model.version + self._directory = model.directory + self._rng_seed = model.rng_seed + + @property + def dataset(self) -> str: + return self._dataset + + @property + def version(self) -> str: + return 
self._version + + @property + def directory(self) -> str: + return self._directory + + @property + def rng_seed(self) -> int | None: + return self._rng_seed diff --git a/datamimic_ce/tasks/demographics_task.py b/datamimic_ce/tasks/demographics_task.py new file mode 100644 index 00000000..633e6a67 --- /dev/null +++ b/datamimic_ce/tasks/demographics_task.py @@ -0,0 +1,37 @@ +"""Task that installs demographic profiles into the setup context.""" + +from __future__ import annotations + +from pathlib import Path +from random import Random + +from datamimic_ce.contexts.demographic_context import DemographicContext +from datamimic_ce.contexts.setup_context import SetupContext +from datamimic_ce.domains.common.demographics.loader import load_demographic_profile +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler +from datamimic_ce.domains.common.models.demographic_config import DemographicConfig +from datamimic_ce.statements.demographics_statement import DemographicsStatement +from datamimic_ce.tasks.task import SetupSubTask + + +class DemographicsTask(SetupSubTask): + def __init__(self, statement: DemographicsStatement): + self._statement = statement + + def execute(self, ctx: SetupContext) -> None: + directory = Path(self._statement.directory) + profile = load_demographic_profile(directory, self._statement.dataset, self._statement.version) + sampler = DemographicSampler(profile) + rng = Random(self._statement.rng_seed) if self._statement.rng_seed is not None else Random() + demographic_context = DemographicContext( + profile_id=profile.profile_id, + sampler=sampler, + overrides=DemographicConfig(), + rng=rng, + ) + # Store the context once so every entity derives deterministic child RNGs without hidden globals. 
+ ctx.set_demographic_context(demographic_context) + + @property + def statement(self) -> DemographicsStatement: + return self._statement diff --git a/datamimic_ce/tasks/task_util.py b/datamimic_ce/tasks/task_util.py index 7d1ae1eb..4b570635 100644 --- a/datamimic_ce/tasks/task_util.py +++ b/datamimic_ce/tasks/task_util.py @@ -41,6 +41,7 @@ from datamimic_ce.statements.array_statement import ArrayStatement from datamimic_ce.statements.condition_statement import ConditionStatement from datamimic_ce.statements.database_statement import DatabaseStatement +from datamimic_ce.statements.demographics_statement import DemographicsStatement from datamimic_ce.statements.echo_statement import EchoStatement from datamimic_ce.statements.element_statement import ElementStatement from datamimic_ce.statements.else_if_statement import ElseIfStatement @@ -127,6 +128,10 @@ def get_task_by_statement( from datamimic_ce.tasks.condition_task import ConditionTask return ConditionTask(stmt) + elif isinstance(stmt, DemographicsStatement): + from datamimic_ce.tasks.demographics_task import DemographicsTask + + return DemographicsTask(stmt) elif isinstance(stmt, ElseIfStatement): from datamimic_ce.tasks.else_if_task import ElseIfTask diff --git a/datamimic_ce/tasks/variable_task.py b/datamimic_ce/tasks/variable_task.py index 2260b0a8..a35ff32e 100644 --- a/datamimic_ce/tasks/variable_task.py +++ b/datamimic_ce/tasks/variable_task.py @@ -249,6 +249,7 @@ def _get_entity_generator( entity_class_name, kwargs = StringUtil.parse_constructor_string(entity_name) # Inject dataset if not explicitly provided in constructor kwargs.setdefault("dataset", dataset) + demographic_context = getattr(ctx.root, "demographic_context", None) # Build demographic config + rng from statement attributes when present demo_cfg = None if any( @@ -272,6 +273,13 @@ def _get_entity_generator( conditions_exclude=excludes, ) rng_obj = Random(statement.rng_seed) if statement.rng_seed is not None else None + if demo_cfg is None 
and demographic_context is not None and demographic_context.overrides is not None: + # Share profile-level defaults when no per-variable overrides are provided. + demo_cfg = demographic_context.overrides + if rng_obj is None and demographic_context is not None: + # Derive entity-level RNGs from the demographics root seed to keep sampling reproducible. + rng_obj = Random(demographic_context.rng.randrange(2**63)) + demographic_sampler = demographic_context.sampler if demographic_context is not None else None # Build from the last parsed VariableTask (self is not accessible in staticmethod); use closure via locals() # Check if entity_class_name contains dots indicating a domain path @@ -318,16 +326,22 @@ def _get_entity_generator( if entity_class_name in entity_mappings: domain_entity_path = entity_mappings[entity_class_name] # Only attach demographic_config/rng for supported services - if demo_cfg is not None and entity_class_name in { + supported_demographic_entities = { "Patient", "Doctor", "PoliceOfficer", "Person", + } + config_enabled_entities = supported_demographic_entities | { "InsurancePolicy", "MedicalDevice", "CreditCard", - }: - kwargs["demographic_config"] = demo_cfg + } + if demo_cfg is not None and entity_class_name in config_enabled_entities: + kwargs.setdefault("demographic_config", demo_cfg) + if demographic_sampler is not None and entity_class_name in supported_demographic_entities: + # Thread pure sampler through to entities that know how to consume demographics. + kwargs.setdefault("demographic_sampler", demographic_sampler) if rng_obj is not None and entity_class_name in { "Patient", "Doctor", diff --git a/docs/demographics.md b/docs/demographics.md new file mode 100644 index 00000000..348fe5fa --- /dev/null +++ b/docs/demographics.md @@ -0,0 +1,69 @@ +# Demographic Profiles + +Demographic profiles provide reusable population priors for age/sex distributions and +condition prevalence. 
Profiles live as versioned CSV bundles that are loaded at runtime +and exposed through an explicit `DemographicContext` so generators can honour priors +without relying on globals. + +## `.dmgrp.csv` contract + +All demographic files share the `.dmgrp.csv` suffix to simplify discovery and +versioning. A complete profile currently consists of: + +- `age_pyramid.dmgrp.csv` + - Columns: `dataset,version,sex,age_min,age_max,weight` + - Weights normalise to 1.0 per sex bucket (or overall when `sex` is empty). +- `condition_rates.dmgrp.csv` + - Columns: `dataset,version,condition,sex,age_min,age_max,prevalence` + - `prevalence` is expressed as a probability in `[0,1]`. +- `profile_meta.dmgrp.csv` (optional metadata placeholder) + - Columns: `dataset,version,source,notes,checksum` + +Additional `.dmgrp.csv` files are ignored until the corresponding loader support is +implemented, which keeps the bundle forward compatible. + +## XML usage + +Declare demographics inside `<setup>` to make the profile and seeded sampler available +to all entities generated within the descriptor: + +```xml + + + + + + + + +``` + +- The loader validates dataset/version columns, ensures age buckets do not overlap and + raises descriptive errors for malformed rows. +- Each entity supported by demographics (`Person`, `Patient`, `Doctor`, + `PoliceOfficer`) receives the sampler together with an entity-scoped RNG derived from + the `<demographics>` `rngSeed`. +- `DemographicConfig` overrides (`ageMin`, `ageMax`, `conditionsInclude`, + `conditionsExclude`) remain available per variable and win over sampler priors. + +## Determinism and seeds + +The `<demographics>` node seeds a root RNG. Variable-level `rngSeed` attributes still +work; when omitted, the task derives child RNGs from the demographic context so every +entity consumes disjoint deterministic streams. Loader validation is pure, and the +sampler has no hidden state, which keeps test runs reproducible.
+ +## Limitations + +- Sex selection defaults to a uniform choice across available sex buckets until richer + marginal distributions are supplied. +- Condition prevalence is modelled as independent Bernoulli trials; comorbidity pairs + are not sampled yet. +- Birthdates outside `[0,100]` trigger warnings so unrealistic priors can be spotted + early. + +## Roadmap + +Pairwise odds ratios will eventually be supported behind a feature flag. Until then, +set `DEMOGRAPHY_PAIRS=false` in the environment to document the current limitation. diff --git a/tests_ce/integration_tests/test_demographics/DE/2023Q4/age_pyramid.dmgrp.csv b/tests_ce/integration_tests/test_demographics/DE/2023Q4/age_pyramid.dmgrp.csv new file mode 100644 index 00000000..5f2a2768 --- /dev/null +++ b/tests_ce/integration_tests/test_demographics/DE/2023Q4/age_pyramid.dmgrp.csv @@ -0,0 +1,9 @@ +dataset,version,sex,age_min,age_max,weight +DE,2023Q4,F,0,17,0.18 +DE,2023Q4,F,18,39,0.28 +DE,2023Q4,F,40,64,0.34 +DE,2023Q4,F,65,100,0.20 +DE,2023Q4,M,0,17,0.19 +DE,2023Q4,M,18,39,0.29 +DE,2023Q4,M,40,64,0.33 +DE,2023Q4,M,65,100,0.19 diff --git a/tests_ce/integration_tests/test_demographics/DE/2023Q4/condition_rates.dmgrp.csv b/tests_ce/integration_tests/test_demographics/DE/2023Q4/condition_rates.dmgrp.csv new file mode 100644 index 00000000..e645f3b5 --- /dev/null +++ b/tests_ce/integration_tests/test_demographics/DE/2023Q4/condition_rates.dmgrp.csv @@ -0,0 +1,5 @@ +dataset,version,condition,sex,age_min,age_max,prevalence +DE,2023Q4,Diabetes,F,65,100,0.19 +DE,2023Q4,Diabetes,M,65,100,0.22 +DE,2023Q4,Hypertension,,40,64,0.27 +DE,2023Q4,Hypertension,,65,100,0.48 diff --git a/tests_ce/integration_tests/test_demographics/DE/2023Q4/profile_meta.dmgrp.csv b/tests_ce/integration_tests/test_demographics/DE/2023Q4/profile_meta.dmgrp.csv new file mode 100644 index 00000000..749a1a4c --- /dev/null +++ b/tests_ce/integration_tests/test_demographics/DE/2023Q4/profile_meta.dmgrp.csv @@ -0,0 +1,2 @@ 
+dataset,version,source,notes,checksum +DE,2023Q4,example,illustrative-only,abc123 diff --git a/tests_ce/integration_tests/test_demographics/demo.xml b/tests_ce/integration_tests/test_demographics/demo.xml new file mode 100644 index 00000000..748a81bb --- /dev/null +++ b/tests_ce/integration_tests/test_demographics/demo.xml @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/tests_ce/integration_tests/test_demographics/test_demographics.py b/tests_ce/integration_tests/test_demographics/test_demographics.py new file mode 100644 index 00000000..28444ce7 --- /dev/null +++ b/tests_ce/integration_tests/test_demographics/test_demographics.py @@ -0,0 +1,18 @@ +# DATAMIMIC +# Copyright (c) 2023-2025 Rapiddweller Asia Co., Ltd. +# This software is licensed under the MIT License. +# See LICENSE file for the full text of the license. +# For questions and support, contact: info@rapiddweller.com + + +from pathlib import Path + +from datamimic_ce.data_mimic_test import DataMimicTest + + +class TestDemographics: + _test_dir = Path(__file__).resolve().parent + + def test_simple(self): + test_engine = DataMimicTest(test_dir=self._test_dir, filename="demo.xml") + test_engine.test_with_timer() diff --git a/tests_ce/unit_tests/test_generator/test_nobility_title_generator.py b/tests_ce/unit_tests/test_generator/test_nobility_title_generator.py index 11098104..53119c44 100644 --- a/tests_ce/unit_tests/test_generator/test_nobility_title_generator.py +++ b/tests_ce/unit_tests/test_generator/test_nobility_title_generator.py @@ -10,7 +10,7 @@ def test_other_gender_returns_string_when_quota_hits() -> None: generator = NobilityTitleGenerator(dataset="US", noble_quota=1.0, rng=Random(1337)) - # WHY: Ensure non-binary genders still receive a string title when quota triggers. + # Ensure non-binary genders still receive a string title when quota triggers. 
title = generator.generate_with_gender("other") assert isinstance(title, str) diff --git a/tests_ce/unit_tests/tests_demographics/data/age_pyramid.dmgrp.csv b/tests_ce/unit_tests/tests_demographics/data/age_pyramid.dmgrp.csv new file mode 100644 index 00000000..03111ea4 --- /dev/null +++ b/tests_ce/unit_tests/tests_demographics/data/age_pyramid.dmgrp.csv @@ -0,0 +1,7 @@ +dataset,version,sex,age_min,age_max,weight +TEST,v1,F,0,17,0.2 +TEST,v1,F,18,35,0.3 +TEST,v1,F,36,80,0.5 +TEST,v1,M,0,17,0.25 +TEST,v1,M,18,35,0.35 +TEST,v1,M,36,80,0.4 diff --git a/tests_ce/unit_tests/tests_demographics/data/condition_rates.dmgrp.csv b/tests_ce/unit_tests/tests_demographics/data/condition_rates.dmgrp.csv new file mode 100644 index 00000000..a1c92578 --- /dev/null +++ b/tests_ce/unit_tests/tests_demographics/data/condition_rates.dmgrp.csv @@ -0,0 +1,6 @@ +dataset,version,condition,sex,age_min,age_max,prevalence +TEST,v1,Asthma,,0,17,0.1 +TEST,v1,Asthma,,18,80,0.05 +TEST,v1,Diabetes,F,36,80,0.15 +TEST,v1,Diabetes,M,36,80,0.12 +TEST,v1,Hypertension,,36,80,0.3 diff --git a/tests_ce/unit_tests/tests_demographics/test_distribution_fit.py b/tests_ce/unit_tests/tests_demographics/test_distribution_fit.py new file mode 100644 index 00000000..74b71597 --- /dev/null +++ b/tests_ce/unit_tests/tests_demographics/test_distribution_fit.py @@ -0,0 +1,54 @@ +"""Statistical checks for demographic sampling.""" + +from __future__ import annotations + +import shutil +import sys +from collections import Counter +from pathlib import Path +from random import Random + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +import pytest + +from datamimic_ce.domains.common.demographics.loader import load_demographic_profile +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler + +_test_dir = Path(__file__).resolve().parent + +@pytest.fixture() +def profile_and_sampler(tmp_path: Path): + fixture_dir = Path(_test_dir / "data") + for name in ("age_pyramid.dmgrp.csv", 
"condition_rates.dmgrp.csv"): + shutil.copy(fixture_dir / name, tmp_path / name) + profile = load_demographic_profile(tmp_path, "TEST", "v1") + return profile, DemographicSampler(profile) + + +def test_age_distribution_matches_profile(profile_and_sampler) -> None: + profile, sampler = profile_and_sampler + rng = Random(2024) + sexes = [sex for sex in profile.sexes() if sex is not None] + counts = {sex: Counter() for sex in sexes} + samples = 50000 + for _ in range(samples): + age, sex = sampler.sample_age_sex(rng) + bands = profile.bands_for_sex(sex) + for band in bands: + if band.contains(age): + counts[sex][(band.age_min, band.age_max)] += 1 + break + chi_square = 0.0 + degrees_of_freedom = 0 + expected_per_sex = samples / max(len(sexes), 1) + for sex in sexes: + bands = profile.bands_for_sex(sex) + degrees_of_freedom += max(len(bands) - 1, 0) + for band in bands: + expected = expected_per_sex * band.weight + observed = counts[sex][(band.age_min, band.age_max)] + chi_square += (observed - expected) ** 2 / expected + # Critical value for df=4, alpha=0.05 is ~9.488 + assert degrees_of_freedom == 4 + assert chi_square <= 9.488 diff --git a/tests_ce/unit_tests/tests_demographics/test_loader_profile.py b/tests_ce/unit_tests/tests_demographics/test_loader_profile.py new file mode 100644 index 00000000..3ec57c5c --- /dev/null +++ b/tests_ce/unit_tests/tests_demographics/test_loader_profile.py @@ -0,0 +1,34 @@ +"""Loader tests for demographic profiles.""" + +from __future__ import annotations + +import shutil +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +import pytest + +from datamimic_ce.domains.common.demographics.loader import load_demographic_profile + +_test_dir = Path(__file__).resolve().parent + +@pytest.fixture() +def profile_dir(tmp_path: Path) -> Path: + fixture_dir = Path(_test_dir/ "data") + for name in ("age_pyramid.dmgrp.csv", "condition_rates.dmgrp.csv"): + shutil.copy(fixture_dir / name, tmp_path 
/ name) + return tmp_path + + +def test_load_profile_normalizes_and_indexes(profile_dir: Path) -> None: + profile = load_demographic_profile(profile_dir, "TEST", "v1") + female_bands = profile.bands_for_sex("F") + male_bands = profile.bands_for_sex("M") + assert len(female_bands) == 3 + assert len(male_bands) == 3 + assert pytest.approx(sum(b.weight for b in female_bands), rel=1e-9) == 1.0 + assert pytest.approx(sum(b.weight for b in male_bands), rel=1e-9) == 1.0 + rates = profile.conditions_for("Hypertension") + assert rates and rates[0].prevalence == pytest.approx(0.3) diff --git a/tests_ce/unit_tests/tests_demographics/test_overrides_precedence.py b/tests_ce/unit_tests/tests_demographics/test_overrides_precedence.py new file mode 100644 index 00000000..004affbc --- /dev/null +++ b/tests_ce/unit_tests/tests_demographics/test_overrides_precedence.py @@ -0,0 +1,45 @@ +"""Tests ensuring demographic overrides trump sampler priors.""" + +from __future__ import annotations + +import shutil +import sys +from pathlib import Path +from random import Random + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from datamimic_ce.domains.common.demographics.loader import load_demographic_profile +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler, DemographicSample +from datamimic_ce.domains.common.models.demographic_config import DemographicConfig +from datamimic_ce.domains.healthcare.generators.patient_generator import PatientGenerator + +_test_dir = Path(__file__).resolve().parent + +def _load_sampler(tmp_path: Path) -> DemographicSampler: + fixture_dir = Path(_test_dir / "data") + for name in ("age_pyramid.dmgrp.csv", "condition_rates.dmgrp.csv"): + shutil.copy(fixture_dir / name, tmp_path / name) + profile = load_demographic_profile(tmp_path, "TEST", "v1") + return DemographicSampler(profile) + + +def test_overrides_adjust_sampled_conditions(tmp_path: Path) -> None: + sampler = _load_sampler(tmp_path) + rng = Random(101) + age, 
sex = sampler.sample_age_sex(rng) + sampled_conditions = sampler.sample_conditions(age, sex, rng) + demo_sample = DemographicSample(age=age, sex=sex, conditions=sampled_conditions) + overrides = DemographicConfig( + conditions_include=frozenset({"Hypertension"}), + conditions_exclude=frozenset({"Asthma"}), + ) + generator = PatientGenerator( + dataset="US", + demographic_config=overrides, + demographic_sampler=sampler, + rng=Random(2024), + ) + conditions = generator.generate_age_appropriate_conditions(age, demo_sample) + assert "Hypertension" in set(conditions) + assert all(cond.lower() != "asthma" for cond in conditions) diff --git a/tests_ce/unit_tests/tests_demographics/test_sampler_determinism.py b/tests_ce/unit_tests/tests_demographics/test_sampler_determinism.py new file mode 100644 index 00000000..0e41d251 --- /dev/null +++ b/tests_ce/unit_tests/tests_demographics/test_sampler_determinism.py @@ -0,0 +1,46 @@ +"""Determinism tests for the demographic sampler.""" + +from __future__ import annotations + +import shutil +import sys +from pathlib import Path +from random import Random + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +import pytest + +from datamimic_ce.domains.common.demographics.loader import load_demographic_profile +from datamimic_ce.domains.common.demographics.sampler import DemographicSampler + +_test_dir = Path(__file__).resolve().parent + +@pytest.fixture() +def sampler(tmp_path: Path) -> DemographicSampler: + fixture_dir = Path(_test_dir / "data") + for name in ("age_pyramid.dmgrp.csv", "condition_rates.dmgrp.csv"): + shutil.copy(fixture_dir / name, tmp_path / name) + profile = load_demographic_profile(tmp_path, "TEST", "v1") + return DemographicSampler(profile) + + +def test_age_sex_sampling_is_deterministic(sampler: DemographicSampler) -> None: + rng_a = Random(123) + rng_b = Random(123) + sequence_a = [sampler.sample_age_sex(rng_a) for _ in range(10)] + sequence_b = [sampler.sample_age_sex(rng_b) for _ in range(10)] + 
assert sequence_a == sequence_b + + +def test_condition_sampling_is_deterministic(sampler: DemographicSampler) -> None: + rng_a = Random(987) + rng_b = Random(987) + draws_a = [] + draws_b = [] + for _ in range(10): + age_a, sex_a = sampler.sample_age_sex(rng_a) + draws_a.append(sampler.sample_conditions(age_a, sex_a, rng_a)) + age_b, sex_b = sampler.sample_age_sex(rng_b) + draws_b.append(sampler.sample_conditions(age_b, sex_b, rng_b)) + assert draws_a == draws_b