12 changes: 12 additions & 0 deletions ccflow/base.py
@@ -30,6 +30,7 @@
from typing_extensions import Self

from .exttypes.pyobjectpath import PyObjectPath
from .local_persistence import register_local_subclass

log = logging.getLogger(__name__)

@@ -156,6 +157,15 @@ class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta
    - Registration by name, and coercion from string name to allow for object re-use from the configs
    """

    __ccflow_local_registration_kind__: ClassVar[str] = "model"

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs):
        # __pydantic_init_subclass__ is the supported hook point once Pydantic finishes wiring the subclass.
        super().__pydantic_init_subclass__(**kwargs)
        kind = getattr(cls, "__ccflow_local_registration_kind__", "model")
        register_local_subclass(cls, kind=kind)

    @computed_field(
        alias="_target_",
        repr=False,
Expand Down Expand Up @@ -820,6 +830,8 @@ class ContextBase(ResultBase):
    that is an input into another CallableModel.
    """

    __ccflow_local_registration_kind__: ClassVar[str] = "context"

    model_config = ConfigDict(
        frozen=True,
        arbitrary_types_allowed=False,
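For orientation (not part of the diff): with the new `__pydantic_init_subclass__` hook, any `BaseModel` subclass defined inside a function body is handed to `register_local_subclass` as soon as Pydantic finishes building it. A minimal sketch of the intended effect, using a hypothetical locally defined context:

from ccflow import ContextBase

def make_scratch_context():
    class ScratchContext(ContextBase):  # defined in a local scope, so __qualname__ contains "<locals>"
        value: int
    return ScratchContext

cls = make_scratch_context()
# After the hook runs, the class has been re-homed onto the synthetic artifacts module,
# so its metadata no longer points at a scope that cannot be re-imported.
assert cls.__module__ == "ccflow._local_artifacts"
assert "<locals>" not in cls.__qualname__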
2 changes: 2 additions & 0 deletions ccflow/callable.py
@@ -74,6 +74,8 @@ class _CallableModel(BaseModel, abc.ABC):
    The purpose of this class is to provide type definitions of context_type and return_type.
    """

    __ccflow_local_registration_kind__: ClassVar[str] = "callable_model"

    model_config = ConfigDict(
        ignored_types=(property,),
    )
139 changes: 139 additions & 0 deletions ccflow/local_persistence.py
@@ -0,0 +1,139 @@
"""Helpers for persisting BaseModel-derived classes defined inside local scopes."""

from __future__ import annotations

import importlib.abc
import importlib.util
import re
import sys
from collections import defaultdict
from itertools import count
from typing import TYPE_CHECKING, Any, DefaultDict, Type

if TYPE_CHECKING:
    from importlib.machinery import ModuleSpec
    from types import ModuleType

__all__ = ("LOCAL_ARTIFACTS_MODULE_NAME", "register_local_subclass")


LOCAL_ARTIFACTS_MODULE_NAME = "ccflow._local_artifacts"
_LOCAL_MODULE_DOC = "Auto-generated BaseModel subclasses created outside importable modules."

_SANITIZE_PATTERN = re.compile(r"[^0-9A-Za-z_]")
_LOCAL_KIND_COUNTERS: DefaultDict[str, count] = defaultdict(lambda: count())
_LOCAL_ARTIFACTS_MODULE: "ModuleType | None" = None


class _LocalArtifactsLoader(importlib.abc.Loader):
    """Minimal loader so importlib can reload our synthetic module if needed."""

    def __init__(self, *, doc: str) -> None:
        self._doc = doc

    def create_module(self, spec: "ModuleSpec") -> "ModuleType | None":
        """Return None so importlib falls back to its default module creation."""
        return None

    def exec_module(self, module: "ModuleType") -> None:
        module.__doc__ = module.__doc__ or self._doc


class _LocalArtifactsFinder(importlib.abc.MetaPathFinder):
    """Ensures importlib can rediscover the synthetic module when reloading."""

    def find_spec(
        self,
        fullname: str,
        path: Any,
        target: "ModuleType | None" = None,
    ) -> "ModuleSpec | None":
        if fullname != LOCAL_ARTIFACTS_MODULE_NAME:
            return None
        return _build_module_spec(fullname, _LOCAL_MODULE_DOC)


def _build_module_spec(name: str, doc: str) -> "ModuleSpec":
    loader = _LocalArtifactsLoader(doc=doc)
    spec = importlib.util.spec_from_loader(
        name,
        loader=loader,
        origin="ccflow.local_persistence:_ensure_module",
    )
    if spec is None:
        raise ImportError(f"Unable to create spec for dynamic module {name!r}.")
    spec.has_location = False
    return spec


def _ensure_finder_installed() -> None:
    for finder in sys.meta_path:
        if isinstance(finder, _LocalArtifactsFinder):
            return
    sys.meta_path.insert(0, _LocalArtifactsFinder())


def _ensure_module(name: str, doc: str) -> "ModuleType":
    """Materialize the synthetic module with a real spec/loader so importlib treats it like disk-backed code.

    We only build the module on demand, but once built, the finder/spec/loader plumbing
    keeps reload, pickling, and other importlib consumers happy. The Python docs recommend
    this approach over instantiating the module type directly.
    """
    _ensure_finder_installed()
    module = sys.modules.get(name)
    if module is None:
        # Create a proper ModuleSpec + loader so importlib reloads and introspection behave
        # the same way they would for filesystem-backed modules
        spec = _build_module_spec(name, doc)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)  # type: ignore[union-attr]
        sys.modules[name] = module
    parent_name, _, attr = name.rpartition(".")
    if parent_name:
        parent_module = sys.modules.get(parent_name)
        if parent_module and not hasattr(parent_module, attr):
            setattr(parent_module, attr, module)
    return module


def _get_local_artifacts_module() -> "ModuleType":
    """Lazily materialize the synthetic module so it is only created when first needed."""
    global _LOCAL_ARTIFACTS_MODULE
    if _LOCAL_ARTIFACTS_MODULE is None:
        _LOCAL_ARTIFACTS_MODULE = _ensure_module(LOCAL_ARTIFACTS_MODULE_NAME, _LOCAL_MODULE_DOC)
    return _LOCAL_ARTIFACTS_MODULE


def _needs_registration(cls: Type[Any]) -> bool:
    qualname = getattr(cls, "__qualname__", "")
    return "<locals>" in qualname


def _sanitize_identifier(value: str, fallback: str) -> str:
    sanitized = _SANITIZE_PATTERN.sub("_", value or "")
    sanitized = sanitized.strip("_") or fallback
    if sanitized[0].isdigit():
        sanitized = f"_{sanitized}"
    return sanitized


def _build_unique_name(*, kind_slug: str, name_hint: str) -> str:
    sanitized_hint = _sanitize_identifier(name_hint, "BaseModel")
    counter = _LOCAL_KIND_COUNTERS[kind_slug]
    return f"{kind_slug}__{sanitized_hint}__{next(counter)}"


def register_local_subclass(cls: Type[Any], *, kind: str = "model") -> None:
    """Re-home a BaseModel subclass defined in a local scope onto the synthetic artifacts module,
    so it stays importable via its __module__/__qualname__ after the defining scope exits."""
    if getattr(cls, "__module__", "").startswith(LOCAL_ARTIFACTS_MODULE_NAME):
        return
    if not _needs_registration(cls):
        return

    name_hint = f"{getattr(cls, '__module__', '')}.{getattr(cls, '__qualname__', '')}"
    kind_slug = _sanitize_identifier(kind, "model").lower()
    unique_name = _build_unique_name(kind_slug=kind_slug, name_hint=name_hint)
    artifacts_module = _get_local_artifacts_module()
    setattr(artifacts_module, unique_name, cls)
    cls.__module__ = artifacts_module.__name__
    cls.__qualname__ = unique_name
    setattr(cls, "__ccflow_dynamic_origin__", name_hint)
66 changes: 65 additions & 1 deletion ccflow/tests/evaluators/test_common.py
@@ -1,11 +1,24 @@
import logging
from datetime import date
from typing import ClassVar
from unittest import TestCase

import pandas as pd
import pyarrow as pa

from ccflow import DateContext, DateRangeContext, Evaluator, FlowOptionsOverride, ModelEvaluationContext, NullContext
from ccflow import (
    CallableModel,
    ContextBase,
    DateContext,
    DateRangeContext,
    Evaluator,
    Flow,
    FlowOptions,
    FlowOptionsOverride,
    GenericResult,
    ModelEvaluationContext,
    NullContext,
)
from ccflow.evaluators import (
    FallbackEvaluator,
    GraphEvaluator,
@@ -16,6 +29,7 @@
    combine_evaluators,
    get_dependency_graph,
)
from ccflow.tests.local_helpers import build_meta_sensor_planner, build_nested_graph_chain

from .util import CircularModel, MyDateCallable, MyDateRangeCallable, MyRaisingCallable, NodeModel, ResultModel

@@ -206,6 +220,22 @@ def test_logging_options_nested(self):
self.assertIn("End evaluation of __call__", captured.records[2].getMessage())
self.assertIn("time elapsed", captured.records[2].getMessage())

def test_meta_callable_logged_with_evaluator(self):
"""Meta callables can spin up request-scoped specialists and still inherit evaluator instrumentation."""
SensorQuery, MetaSensorPlanner, captured = build_meta_sensor_planner()
evaluator = LoggingEvaluator(log_level=logging.INFO, verbose=False)
request = SensorQuery(sensor_type="pressure-valve", site="orbital-lab", window=4)
meta = MetaSensorPlanner(warm_start=2)
with FlowOptionsOverride(options=FlowOptions(evaluator=evaluator)):
with self.assertLogs(level=logging.INFO) as captured_logs:
result = meta(request)
self.assertEqual(result.value, "planner:orbital-lab:pressure-valve:6")
start_messages = [record.getMessage() for record in captured_logs.records if "Start evaluation" in record.getMessage()]
self.assertEqual(len(start_messages), 2)
self.assertTrue(any("MetaSensorPlanner" in msg for msg in start_messages))
specialist_name = captured["callable_cls"].__name__
self.assertTrue(any(specialist_name in msg for msg in start_messages))


class SubContext(DateContext):
    pass
@@ -473,3 +503,37 @@ def test_graph_evaluator_circular(self):
        evaluator = GraphEvaluator()
        with FlowOptionsOverride(options={"evaluator": evaluator}):
            self.assertRaises(Exception, root, context)

    def test_graph_evaluator_with_local_models_and_cache(self):
        ParentCls, ChildCls = build_nested_graph_chain()
        ChildCls.call_count = 0
        model = ParentCls(child=ChildCls())
        evaluator = MultiEvaluator(evaluators=[GraphEvaluator(), MemoryCacheEvaluator()])
        with FlowOptionsOverride(options=FlowOptions(evaluator=evaluator, cacheable=True)):
            first = model(NullContext())
            second = model(NullContext())
        self.assertEqual(first.value, second.value)
        self.assertEqual(ChildCls.call_count, 1)


class TestMemoryCacheEvaluatorLocal(TestCase):
    def test_memory_cache_handles_local_context_and_callable(self):
        class LocalContext(ContextBase):
            value: int

        class LocalModel(CallableModel):
            call_count: ClassVar[int] = 0

            @Flow.call
            def __call__(self, context: LocalContext) -> GenericResult:
                type(self).call_count += 1
                return GenericResult(value=context.value * 2)

        evaluator = MemoryCacheEvaluator()
        LocalModel.call_count = 0
        model = LocalModel()
        with FlowOptionsOverride(options=FlowOptions(evaluator=evaluator, cacheable=True)):
            result1 = model(LocalContext(value=5))
            result2 = model(LocalContext(value=5))
        self.assertEqual(result1.value, result2.value)
        self.assertEqual(LocalModel.call_count, 1)
90 changes: 90 additions & 0 deletions ccflow/tests/local_helpers.py
@@ -0,0 +1,90 @@
"""Shared helpers for constructing local-scope contexts/models in tests."""

from typing import ClassVar, Dict, Tuple, Type

from ccflow import CallableModel, ContextBase, Flow, GenericResult, GraphDepList, NullContext


def build_meta_sensor_planner():
    """Return a (SensorQuery, MetaSensorPlanner, captured) tuple for meta-callable tests."""

    captured: Dict[str, Type] = {}

    class SensorQuery(ContextBase):
        sensor_type: str
        site: str
        window: int

    class MetaSensorPlanner(CallableModel):
        warm_start: int = 2

        @Flow.call
        def __call__(self, context: SensorQuery) -> GenericResult:
            # Define request-scoped specialist wiring with a bespoke context/model pair.
            class SpecialistContext(ContextBase):
                sensor_type: str
                window: int
                pipeline: str

            class SpecialistCallable(CallableModel):
                pipeline: str

                @Flow.call
                def __call__(self, context: SpecialistContext) -> GenericResult:
                    payload = f"{self.pipeline}:{context.sensor_type}:{context.window}"
                    return GenericResult(value=payload)

            captured["context_cls"] = SpecialistContext
            captured["callable_cls"] = SpecialistCallable

            window = context.window + self.warm_start
            local_context = SpecialistContext(
                sensor_type=context.sensor_type,
                window=window,
                pipeline=f"{context.site}-calibration",
            )
            specialist = SpecialistCallable(pipeline=f"planner:{context.site}")
            return specialist(local_context)

    return SensorQuery, MetaSensorPlanner, captured


def build_local_callable(name: str = "LocalCallable") -> Type[CallableModel]:
    class _LocalCallable(CallableModel):
        @Flow.call
        def __call__(self, context: NullContext) -> GenericResult:
            return GenericResult(value="local")

    _LocalCallable.__name__ = name
    return _LocalCallable


def build_local_context(name: str = "LocalContext") -> Type[ContextBase]:
    class _LocalContext(ContextBase):
        value: int

    _LocalContext.__name__ = name
    return _LocalContext


def build_nested_graph_chain() -> Tuple[Type[CallableModel], Type[CallableModel]]:
    class LocalLeaf(CallableModel):
        call_count: ClassVar[int] = 0

        @Flow.call
        def __call__(self, context: NullContext) -> GenericResult:
            type(self).call_count += 1
            return GenericResult(value="leaf")

    class LocalParent(CallableModel):
        child: LocalLeaf

        @Flow.call
        def __call__(self, context: NullContext) -> GenericResult:
            return self.child(context)

        @Flow.deps
        def __deps__(self, context: NullContext) -> GraphDepList:
            return [(self.child, [context])]

    return LocalParent, LocalLeaf