Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datamimic_ce/domains/common/demographics/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _validate_band_coverage(bands: Iterable[DemographicAgeBand], file_path: Path
)
if previous and band.age_min > previous.age_max + 1:
logger.warning(
"Gap detected between age bands for sex='%s' in '%s': [%s,%s] -> [%s,%s]", # type: ignore[str-format]
"Gap detected between age bands for sex='%s' in '%s': [%s,%s] -> [%s,%s]",
sex or "",
file_path,
previous.age_min,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,31 +98,38 @@ def generate_next_maintenance_date(self) -> str:
def generate_device_type(self) -> str:
file_path = dataset_path("healthcare", "medical", f"device_types_{self._dataset}.csv", start=Path(__file__))
loaded_data = FileUtil.read_weight_csv(file_path)
return self._rng.choices(loaded_data[0], weights=loaded_data[1], k=1)[0] # type: ignore[arg-type]
values: list[str] = [str(v) for v in loaded_data[0].tolist() if v is not None]
weights: list[float] = [float(w) for w in loaded_data[1].tolist()]
return self._rng.choices(values, weights=weights, k=1)[0]

def generate_manufacturer(self) -> str:
file_path = dataset_path("healthcare", "medical", f"manufacturers_{self._dataset}.csv", start=Path(__file__))
loaded_data = FileUtil.read_weight_csv(file_path)
values, weights = loaded_data[0], loaded_data[1]
values: list[str] = [str(v) for v in loaded_data[0].tolist() if v is not None]
weights: list[float] = [float(w) for w in loaded_data[1].tolist()]
# avoid immediate repetition when possible
if self._last_manufacturer in values and len(values) > 1:
pool = [(v, float(w)) for v, w in zip(values, weights, strict=False) if v != self._last_manufacturer]
p_vals, p_w = zip(*pool, strict=False)
choice = self._rng.choices(list(p_vals), weights=list(p_w), k=1)[0]
else:
choice = self._rng.choices(values, weights=weights, k=1)[0] # type: ignore[arg-type]
choice = self._rng.choices(values, weights=weights, k=1)[0]
self._last_manufacturer = choice
return choice # type: ignore[return-value]
return choice

def generate_device_status(self) -> str:
file_path = dataset_path("healthcare", "medical", f"device_statuses_{self._dataset}.csv", start=Path(__file__))
loaded_data = FileUtil.read_weight_csv(file_path)
return self._rng.choices(loaded_data[0], weights=loaded_data[1], k=1)[0] # type: ignore[arg-type]
values: list[str] = [str(v) for v in loaded_data[0].tolist() if v is not None]
weights: list[float] = [float(w) for w in loaded_data[1].tolist()]
return self._rng.choices(values, weights=weights, k=1)[0]

def generate_location(self) -> str:
file_path = dataset_path("healthcare", "medical", f"locations_{self._dataset}.csv", start=Path(__file__))
loaded_data = FileUtil.read_weight_csv(file_path) # type: ignore[arg-type]
return self._rng.choices(loaded_data[0], weights=loaded_data[1], k=1)[0] # type: ignore[arg-type]
loaded_data = FileUtil.read_weight_csv(file_path)
values: list[str] = [str(v) for v in loaded_data[0].tolist() if v is not None]
weights: list[float] = [float(w) for w in loaded_data[1].tolist()]
return self._rng.choices(values, weights=weights, k=1)[0]

# Helpers for specifications (modes, detector types, etc.) used by the model
def pick_ventilator_mode(self) -> str:
Expand Down
5 changes: 5 additions & 0 deletions datamimic_ce/mcp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Public entrypoints for the MCP integration."""

from .server import create_server, mount_mcp

__all__ = ["create_server", "mount_mcp"]
101 changes: 101 additions & 0 deletions datamimic_ce/mcp/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Command line interface for serving the FastMCP endpoint.

WHY: Typer in some environments doesn't support ``typing.Literal`` for
option types, raising "Type not yet supported". To keep broad
compatibility without upgrading tooling, we use ``Enum`` for choices.
This keeps the CLI thin, explicit, and compatible across Typer versions.
"""

from __future__ import annotations

import os
from enum import Enum
from typing import Annotated

import typer
import uvicorn

from datamimic_ce.mcp.server import (
HTTP_MIDDLEWARE_ATTR,
build_sse_app,
create_server,
)

app = typer.Typer(help="Run the DataMimic MCP server")

_DEFAULT_HOST = "127.0.0.1"
_DEFAULT_PORT = 8765

# WHY: Avoid ruff-bugbear B008 (no function calls in argument defaults).
# We compute environment-derived defaults at module load and use them as
# plain defaults, while passing Typer options via `Annotated` metadata.
_ENV_HOST_DEFAULT = os.getenv("DATAMIMIC_MCP_HOST", _DEFAULT_HOST)
_ENV_PORT_DEFAULT = int(os.getenv("DATAMIMIC_MCP_PORT", str(_DEFAULT_PORT)))


class Transport(str, Enum):
    """Supported transport mechanisms for FastMCP.

    WHY: Replaces ``Literal['sse', 'stdio']`` to avoid Typer limitations.
    """

    # Network-facing transport served over HTTP by uvicorn (see `serve`).
    sse = "sse"
    # Embedded transport for agent runtimes; host/port options are ignored.
    stdio = "stdio"


class LogLevel(str, Enum):
    """Supported log levels for uvicorn.

    WHY: Replaces ``Literal[...]`` to avoid Typer limitations.
    Values mirror the strings uvicorn accepts for its ``log_level`` argument.
    """

    critical = "critical"
    error = "error"
    warning = "warning"
    info = "info"
    debug = "debug"


@app.command()
def serve(
    host: Annotated[
        str,
        typer.Option(help="Host interface to bind for SSE transport"),
    ] = _ENV_HOST_DEFAULT,
    port: Annotated[
        int,
        typer.Option(help="TCP port to bind for SSE transport"),
    ] = _ENV_PORT_DEFAULT,
    transport: Annotated[
        Transport,
        typer.Option(
            help="FastMCP transport (sse for network clients, stdio for agent runtimes)",
        ),
    ] = Transport.sse,
    log_level: Annotated[
        LogLevel,
        typer.Option(help="Log level when running the SSE server"),
    ] = LogLevel.info,
) -> None:
    """Start the FastMCP server with optional API key gating."""

    # Gating is opt-in: an absent DATAMIMIC_MCP_API_KEY means open access.
    api_key = os.getenv("DATAMIMIC_MCP_API_KEY")
    server = create_server(api_key=api_key)
    middleware = getattr(server, HTTP_MIDDLEWARE_ATTR, None)

    if transport is Transport.stdio:
        # WHY: stdio transport is typically embedded; it does not honour host/port.
        server.run(Transport.stdio.value)
    else:
        # Network-facing SSE transport: hand the ASGI app to uvicorn.
        uvicorn.run(
            build_sse_app(server, middleware),
            host=host,
            port=port,
            log_level=log_level.value,
        )


if __name__ == "__main__": # pragma: no cover
app()
159 changes: 159 additions & 0 deletions datamimic_ce/mcp/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""Pydantic models backing the MCP surface."""

from __future__ import annotations

import csv
from functools import cache
from typing import Any

from pydantic import BaseModel, Field, model_validator

from datamimic_ce.domains.common.locale_registry import dataset_code_for_locale
from datamimic_ce.domains.locales import SUPPORTED_DATASET_CODES
from datamimic_ce.domains.utils.dataset_path import dataset_path

MAX_COUNT = 10_000
_DEFAULT_LOCALE = "en_US"
_DEFAULT_CLOCK = "2025-01-01T00:00:00Z"

_DATASET_LOCALE_DEFAULTS: dict[str, str] = {
"US": "en_US",
"DE": "de_DE",
"VN": "vi_VN",
}


@cache
def _default_locale_for_dataset(dataset: str) -> str | None:
    """Infer the canonical locale for a dataset using packaged metadata.

    Checks the explicit override table first, then scans the packaged
    ``country_<CODE>.csv`` file. Returns ``None`` when nothing matches.
    """

    code = dataset.upper()
    if explicit := _DATASET_LOCALE_DEFAULTS.get(code):
        return explicit

    csv_name = f"country_{code}.csv"
    try:
        path = dataset_path("common", csv_name)
    except OSError:  # pragma: no cover - defensive fall-back for env overrides
        return None

    # Reject redirected/overridden paths that do not point at the expected file.
    if path.name != csv_name or not path.exists():
        return None

    with path.open("r", encoding="utf-8") as handle:
        for row in csv.reader(handle):
            # Need at least [code, locale] and a matching country code.
            if len(row) < 2 or row[0].strip().upper() != code:
                continue
            locale = row[1].strip().replace("-", "_")
            if locale:
                return locale
    return None


class GenerateArgs(BaseModel):
    """Validated request arguments for the ``generate`` MCP tool.

    Per-field constraints live on the ``Field`` declarations; cross-field
    invariants (profile vs. component selection, locale vs. dataset
    agreement) are enforced in ``_enforce_contracts`` after field validation.
    """

    domain: str = Field(..., description="Target domain identifier")
    version: str = Field("v1", description="Domain contract version")
    count: int = Field(
        1,
        ge=0,
        le=MAX_COUNT,
        description="Number of records to generate (capped at 10k)",
    )
    seed: int | str = Field(
        "0",
        description="Seed propagated to the deterministic generators",
    )
    locale: str | None = Field(
        None,
        description="Locale identifier (e.g. en_US). Defaults to dataset derived locale",
    )
    dataset: str | None = Field(
        None,
        description=("Dataset code that implies a locale (e.g. US). Mutually exclusive with custom locale overrides."),
    )
    constraints: dict[str, Any] | None = Field(
        None,
        description="Optional domain-specific constraint overrides",
    )
    profile_id: str | None = Field(
        None,
        description="Explicit profile selection (mutually exclusive with component_id)",
    )
    component_id: str | None = Field(
        None,
        description="Component-driven profile selection (mutually exclusive with profile_id)",
    )
    clock: str = Field(
        _DEFAULT_CLOCK,
        description="ISO8601 timestamp establishing the deterministic reference clock",
    )

    @model_validator(mode="after")
    def _enforce_contracts(self) -> GenerateArgs:
        """Enforce cross-field invariants and normalize dataset/locale.

        Raises:
            ValueError: if ``profile_id`` and ``component_id`` are combined,
                the dataset code is unsupported, or locale and dataset
                resolve to different data packs.
        """
        if self.profile_id and self.component_id:
            raise ValueError("profile_id and component_id cannot be combined")

        resolved_dataset = self._normalize_dataset(self.dataset)
        resolved_locale = self._normalize_locale(self.locale, resolved_dataset)

        # Cross-check only when the caller supplied BOTH values; a lone
        # value is simply normalized/defaulted by the helpers above.
        if self.locale and resolved_dataset:
            dataset_for_locale = dataset_code_for_locale(self.locale)
            if dataset_for_locale.upper() != resolved_dataset:
                raise ValueError("locale and dataset disagree on the backing data pack")

        # NOTE(review): object.__setattr__ bypasses pydantic's assignment
        # hooks — presumably to write normalized values back without
        # re-running validation; confirm against the model config.
        object.__setattr__(self, "dataset", resolved_dataset)
        object.__setattr__(self, "locale", resolved_locale)
        return self

    @staticmethod
    def _normalize_dataset(raw: str | None) -> str | None:
        """Canonicalize dataset identifiers and enforce the supported list."""
        if raw is None:
            return None
        normalized = raw.strip().upper()
        if not normalized:
            # Blank/whitespace-only input is treated the same as "not provided".
            return None
        if normalized not in SUPPORTED_DATASET_CODES:
            supported = ", ".join(SUPPORTED_DATASET_CODES)
            raise ValueError(
                f"Unsupported dataset '{normalized}'. Expected one of [{supported}]",
            )
        return normalized

    @staticmethod
    def _normalize_locale(locale: str | None, dataset: str | None) -> str:
        """Derive the locale, falling back to dataset defaults or the global default."""
        if locale:
            return locale
        if dataset:
            inferred = _default_locale_for_dataset(dataset)
            if inferred:
                return inferred
        return _DEFAULT_LOCALE

    def to_payload(self) -> dict[str, Any]:
        """Serialize the request into the facade payload shape."""
        payload: dict[str, Any] = {
            "domain": self.domain,
            "version": self.version,
            "count": self.count,
            "seed": self.seed,
            "locale": self.locale,
            "constraints": self.constraints or {},
            "clock": self.clock,
        }
        # Selection keys are emitted only when set; the validator guarantees
        # at most one of the two is present.
        if self.profile_id:
            payload["profile_id"] = self.profile_id
        if self.component_id:
            payload["component_id"] = self.component_id
        return payload


__all__ = ["GenerateArgs", "MAX_COUNT"]
56 changes: 56 additions & 0 deletions datamimic_ce/mcp/resources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Expose JSON Schema resources for MCP clients."""

from __future__ import annotations

import json
from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, cast

from datamimic_ce.domains import facade

SchemaKind = Literal["request", "response"]


@dataclass(frozen=True)
class SchemaResource:
    """Immutable descriptor identifying one domain schema document."""

    domain: str
    version: str
    kind: SchemaKind

    @property
    def uri(self) -> str:
        """Canonical ``resource://`` URI advertised to MCP clients."""
        tail = f"{self.domain}/{self.version}/{self.kind}.json"
        return "resource://datamimic/schemas/" + tail

    @property
    def path(self) -> Path:
        """On-disk location of the schema file inside the package tree."""
        package_root = Path(__file__).resolve().parents[1]
        return package_root / "domains" / "schemas" / f"{self.domain}.{self.version}.{self.kind}.json"


def iter_schema_resources() -> Iterator[SchemaResource]:
    """Yield request then response schema descriptors for every registered domain/version."""
    for domain, version in sorted(facade.REGISTRY):
        for schema_kind in ("request", "response"):
            yield SchemaResource(domain=domain, version=version, kind=schema_kind)


def load_schema(domain: str, version: str, kind: SchemaKind) -> dict[str, Any]:
    """Load a schema document from disk using the canonical registry layout."""
    resource = SchemaResource(domain=domain, version=version, kind=kind)
    if not resource.path.exists():
        raise FileNotFoundError(f"Missing schema file for {resource.uri}")
    document = json.loads(resource.path.read_text(encoding="utf-8"))
    # Defensive guardrail for schema drift: a valid schema is a JSON object.
    if not isinstance(document, dict):
        raise TypeError(f"Schema at {resource.path} must decode into an object")
    return cast(dict[str, Any], document)


__all__ = ["SchemaResource", "SchemaKind", "iter_schema_resources", "load_schema"]
Loading