From 58b4650c901fb6a910e80689e1133d8efac2a183 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Fri, 19 Dec 2025 19:42:50 -0500 Subject: [PATCH 1/8] Register locally defined CallableModels and Contexts in a module to work with PyObjectPath Signed-off-by: Nijat Khanbabayev --- ccflow/base.py | 12 ++ ccflow/callable.py | 2 + ccflow/local_persistence.py | 72 +++++++++ ccflow/tests/evaluators/test_common.py | 50 ++++++- ccflow/tests/local_helpers.py | 46 ++++++ ccflow/tests/test_base.py | 70 ++++++++- ccflow/tests/test_callable.py | 193 ++++++++++++++++++++++++- ccflow/tests/test_local_persistence.py | 81 +++++++++++ docs/wiki/Key-Features.md | 3 + 9 files changed, 524 insertions(+), 5 deletions(-) create mode 100644 ccflow/local_persistence.py create mode 100644 ccflow/tests/local_helpers.py create mode 100644 ccflow/tests/test_local_persistence.py diff --git a/ccflow/base.py b/ccflow/base.py index 18e2eac..adfec2e 100644 --- a/ccflow/base.py +++ b/ccflow/base.py @@ -30,6 +30,7 @@ from typing_extensions import Self from .exttypes.pyobjectpath import PyObjectPath +from .local_persistence import register_local_subclass log = logging.getLogger(__name__) @@ -156,6 +157,15 @@ class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta - Registration by name, and coercion from string name to allow for object re-use from the configs """ + __ccflow_local_registration_kind__: ClassVar[str] = "model" + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs): + # __pydantic_init_subclass__ is the supported hook point once Pydantic finishes wiring the subclass. + super().__pydantic_init_subclass__(**kwargs) + kind = getattr(cls, "__ccflow_local_registration_kind__", "model") + register_local_subclass(cls, kind=kind) + @computed_field( alias="_target_", repr=False, @@ -820,6 +830,8 @@ class ContextBase(ResultBase): that is an input into another CallableModel. """ + __ccflow_local_registration_kind__: ClassVar[str] = "context" + model_config = ConfigDict( frozen=True, arbitrary_types_allowed=False, diff --git a/ccflow/callable.py b/ccflow/callable.py index b09eaea..595079d 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -74,6 +74,8 @@ class _CallableModel(BaseModel, abc.ABC): The purpose of this class is to provide type definitions of context_type and return_type. """ + __ccflow_local_registration_kind__: ClassVar[str] = "callable_model" + model_config = ConfigDict( ignored_types=(property,), ) diff --git a/ccflow/local_persistence.py b/ccflow/local_persistence.py new file mode 100644 index 0000000..7331a9b --- /dev/null +++ b/ccflow/local_persistence.py @@ -0,0 +1,72 @@ +"""Helpers for persisting BaseModel-derived classes defined inside local scopes.""" + +from __future__ import annotations + +import re +import sys +from collections import defaultdict +from itertools import count +from types import ModuleType +from typing import Any, DefaultDict, Type + +__all__ = ("LOCAL_ARTIFACTS_MODULE_NAME", "register_local_subclass") + + +LOCAL_ARTIFACTS_MODULE_NAME = "ccflow._local_artifacts" +_LOCAL_MODULE_DOC = "Auto-generated BaseModel subclasses created outside importable modules." + +_SANITIZE_PATTERN = re.compile(r"[^0-9A-Za-z_]") +_LOCAL_KIND_COUNTERS: DefaultDict[str, count] = defaultdict(lambda: count()) + + +def _ensure_module(name: str, doc: str) -> ModuleType: + """Ensure the dynamic module exists so import paths remain stable.""" + module = sys.modules.get(name) + if module is None: + module = ModuleType(name, doc) + sys.modules[name] = module + parent_name, _, attr = name.rpartition(".") + if parent_name: + parent_module = sys.modules.get(parent_name) + if parent_module and not hasattr(parent_module, attr): + setattr(parent_module, attr, module) + return module + + +_LOCAL_ARTIFACTS_MODULE = _ensure_module(LOCAL_ARTIFACTS_MODULE_NAME, _LOCAL_MODULE_DOC) + + +def _needs_registration(cls: Type[Any]) -> bool: + module = getattr(cls, "__module__", "") + qualname = getattr(cls, "__qualname__", "") + return "" in qualname or module == "__main__" + + +def _sanitize_identifier(value: str, fallback: str) -> str: + sanitized = _SANITIZE_PATTERN.sub("_", value or "") + sanitized = sanitized.strip("_") or fallback + if sanitized[0].isdigit(): + sanitized = f"_{sanitized}" + return sanitized + + +def _build_unique_name(*, kind_slug: str, name_hint: str) -> str: + sanitized_hint = _sanitize_identifier(name_hint, "BaseModel") + counter = _LOCAL_KIND_COUNTERS[kind_slug] + return f"{kind_slug}__{sanitized_hint}__{next(counter)}" + + +def register_local_subclass(cls: Type[Any], *, kind: str = "model") -> None: + """Register BaseModel subclasses created in local scopes.""" + if getattr(cls, "__module__", "").startswith(LOCAL_ARTIFACTS_MODULE_NAME): + return + if not _needs_registration(cls): + return + + name_hint = f"{getattr(cls, '__module__', '')}.{getattr(cls, '__qualname__', '')}" + kind_slug = _sanitize_identifier(kind, "model").lower() + unique_name = _build_unique_name(kind_slug=kind_slug, name_hint=name_hint) + setattr(_LOCAL_ARTIFACTS_MODULE, unique_name, cls) + cls.__module__ = _LOCAL_ARTIFACTS_MODULE.__name__ + cls.__qualname__ = unique_name + setattr(cls, "__ccflow_dynamic_origin__", name_hint) diff --git a/ccflow/tests/evaluators/test_common.py b/ccflow/tests/evaluators/test_common.py index 7b5abf0..4a4b503 100644 --- a/ccflow/tests/evaluators/test_common.py +++ b/ccflow/tests/evaluators/test_common.py @@ -1,11 +1,24 @@ import logging from datetime import date +from typing import ClassVar from unittest import TestCase import pandas as pd import pyarrow as pa -from ccflow import DateContext, DateRangeContext, Evaluator, FlowOptionsOverride, ModelEvaluationContext, NullContext +from ccflow import ( + CallableModel, + ContextBase, + DateContext, + DateRangeContext, + Evaluator, + Flow, + FlowOptions, + FlowOptionsOverride, + GenericResult, + ModelEvaluationContext, + NullContext, +) from ccflow.evaluators import ( FallbackEvaluator, GraphEvaluator, @@ -16,6 +29,7 @@ combine_evaluators, get_dependency_graph, ) +from ccflow.tests.local_helpers import build_nested_graph_chain from .util import CircularModel, MyDateCallable, MyDateRangeCallable, MyRaisingCallable, NodeModel, ResultModel @@ -473,3 +487,37 @@ def test_graph_evaluator_circular(self): evaluator = GraphEvaluator() with FlowOptionsOverride(options={"evaluator": evaluator}): self.assertRaises(Exception, root, context) + + def test_graph_evaluator_with_local_models_and_cache(self): + ParentCls, ChildCls = build_nested_graph_chain() + ChildCls.call_count = 0 + model = ParentCls(child=ChildCls()) + evaluator = MultiEvaluator(evaluators=[GraphEvaluator(), MemoryCacheEvaluator()]) + with FlowOptionsOverride(options=FlowOptions(evaluator=evaluator, cacheable=True)): + first = model(NullContext()) + second = model(NullContext()) + self.assertEqual(first.value, second.value) + self.assertEqual(ChildCls.call_count, 1) + + +class TestMemoryCacheEvaluatorLocal(TestCase): + def test_memory_cache_handles_local_context_and_callable(self): + class LocalContext(ContextBase): + value: int + + class LocalModel(CallableModel): + call_count: ClassVar[int] = 0 + + @Flow.call + def __call__(self, context: LocalContext) -> GenericResult: + type(self).call_count += 1 + return GenericResult(value=context.value * 2) + + evaluator = MemoryCacheEvaluator() + LocalModel.call_count = 0 + model = LocalModel() + with FlowOptionsOverride(options=FlowOptions(evaluator=evaluator, cacheable=True)): + result1 = model(LocalContext(value=5)) + result2 = model(LocalContext(value=5)) + self.assertEqual(result1.value, result2.value) + self.assertEqual(LocalModel.call_count, 1) diff --git a/ccflow/tests/local_helpers.py b/ccflow/tests/local_helpers.py new file mode 100644 index 0000000..5f616cf --- /dev/null +++ b/ccflow/tests/local_helpers.py @@ -0,0 +1,46 @@ +"""Shared helpers for constructing local-scope contexts/models in tests.""" + +from typing import ClassVar, Tuple, Type + +from ccflow import CallableModel, ContextBase, Flow, GenericResult, GraphDepList, NullContext + + +def build_local_callable(name: str = "LocalCallable") -> Type[CallableModel]: + class _LocalCallable(CallableModel): + @Flow.call + def __call__(self, context: NullContext) -> GenericResult: + return GenericResult(value="local") + + _LocalCallable.__name__ = name + return _LocalCallable + + +def build_local_context(name: str = "LocalContext") -> Type[ContextBase]: + class _LocalContext(ContextBase): + value: int + + _LocalContext.__name__ = name + return _LocalContext + + +def build_nested_graph_chain() -> Tuple[Type[CallableModel], Type[CallableModel]]: + class LocalLeaf(CallableModel): + call_count: ClassVar[int] = 0 + + @Flow.call + def __call__(self, context: NullContext) -> GenericResult: + type(self).call_count += 1 + return GenericResult(value="leaf") + + class LocalParent(CallableModel): + child: LocalLeaf + + @Flow.call + def __call__(self, context: NullContext) -> GenericResult: + return self.child(context) + + @Flow.deps + def __deps__(self, context: NullContext) -> GraphDepList: + return [(self.child, [context])] + + return LocalParent, LocalLeaf diff --git a/ccflow/tests/test_base.py b/ccflow/tests/test_base.py index 3c6fc02..d6d9413 100644 --- a/ccflow/tests/test_base.py +++ b/ccflow/tests/test_base.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List -from unittest import TestCase +from unittest import TestCase, mock -from pydantic import ConfigDict, ValidationError +from pydantic import BaseModel as PydanticBaseModel, ConfigDict, ValidationError -from ccflow import BaseModel, PyObjectPath +from ccflow import BaseModel, CallableModel, ContextBase, Flow, GenericResult, NullContext, PyObjectPath +from ccflow.local_persistence import LOCAL_ARTIFACTS_MODULE_NAME class ModelA(BaseModel): @@ -106,6 +107,20 @@ def test_type_after_assignment(self): self.assertIsInstance(m.type_, PyObjectPath) self.assertEqual(m.type_, path) + def test_pyobjectpath_requires_ccflow_local_registration(self): + class PlainLocalModel(PydanticBaseModel): + value: int + + with self.assertRaises(ValueError): + PyObjectPath.validate(PlainLocalModel) + + class LocalCcflowModel(BaseModel): + value: int + + path = PyObjectPath.validate(LocalCcflowModel) + self.assertEqual(path.object, LocalCcflowModel) + self.assertTrue(str(path).startswith(f"{LOCAL_ARTIFACTS_MODULE_NAME}.")) + def test_validate(self): self.assertEqual(ModelA.model_validate({"x": "foo"}), ModelA(x="foo")) type_ = "ccflow.tests.test_base.ModelA" @@ -157,3 +172,52 @@ def test_widget(self): {"expanded": True, "root": "bar"}, ), ) + + +class TestLocalRegistrationKind(TestCase): + def test_base_model_defaults_to_model_kind(self): + with mock.patch("ccflow.base.register_local_subclass") as register: + + class LocalModel(BaseModel): + value: int + + register.assert_called_once() + args, kwargs = register.call_args + self.assertIs(args[0], LocalModel) + self.assertEqual(kwargs["kind"], "model") + + def test_context_defaults_to_context_kind(self): + with mock.patch("ccflow.base.register_local_subclass") as register: + + class LocalContext(ContextBase): + value: int + + register.assert_called_once() + args, kwargs = register.call_args + self.assertIs(args[0], LocalContext) + self.assertEqual(kwargs["kind"], "context") + + def test_callable_defaults_to_callable_kind(self): + with mock.patch("ccflow.base.register_local_subclass") as register: + + class LocalCallable(CallableModel): + @Flow.call + def __call__(self, context: NullContext) -> GenericResult: + return GenericResult(value="ok") + + register.assert_called_once() + args, kwargs = register.call_args + self.assertIs(args[0], LocalCallable) + self.assertEqual(kwargs["kind"], "callable_model") + + def test_explicit_override_respected(self): + with mock.patch("ccflow.base.register_local_subclass") as register: + + class CustomKind(BaseModel): + __ccflow_local_registration_kind__ = "custom" + value: int + + register.assert_called_once() + args, kwargs = register.call_args + self.assertIs(args[0], CustomKind) + self.assertEqual(kwargs["kind"], "custom") diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 0ba24ea..5b0f0a9 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -1,5 +1,6 @@ +import sys from pickle import dumps as pdumps, loads as ploads -from typing import Generic, List, Optional, Tuple, Type, TypeVar, Union +from typing import ClassVar, Generic, List, Optional, Tuple, Type, TypeVar, Union from unittest import TestCase import ray @@ -21,6 +22,39 @@ ResultType, WrapperModel, ) +from ccflow.local_persistence import LOCAL_ARTIFACTS_MODULE_NAME +from ccflow.tests.local_helpers import build_local_callable, build_local_context + + +def _find_registered_name(module, cls): + for name, value in vars(module).items(): + if value is cls: + return name + raise AssertionError(f"{cls} not found in {module.__name__}") + + +def _build_main_module_callable(): + namespace = { + "__name__": "__main__", + "ClassVar": ClassVar, + "CallableModel": CallableModel, + "Flow": Flow, + "GenericResult": GenericResult, + "NullContext": NullContext, + } + exec( + """ +class MainModuleCallable(CallableModel): + call_count: ClassVar[int] = 0 + + @Flow.call + def __call__(self, context: NullContext) -> GenericResult: + type(self).call_count += 1 + return GenericResult(value=\"main\") +""", + namespace, + ) + return namespace["MainModuleCallable"] class MyContext(ContextBase): @@ -493,6 +527,163 @@ def test_union_return(self): self.assertEqual(result.a, 1) +class TestCallableModelRegistration(TestCase): + def test_module_level_class_retains_module(self): + self.assertEqual(MyCallable.__module__, __name__) + dynamic_module = sys.modules.get(LOCAL_ARTIFACTS_MODULE_NAME) + if dynamic_module: + self.assertFalse(any(value is MyCallable for value in vars(dynamic_module).values())) + + def test_module_level_context_retains_module(self): + self.assertEqual(MyContext.__module__, __name__) + dynamic_module = sys.modules.get(LOCAL_ARTIFACTS_MODULE_NAME) + if dynamic_module: + self.assertFalse(any(value is MyContext for value in vars(dynamic_module).values())) + + def test_local_class_moves_under_dynamic_namespace(self): + LocalCallable = build_local_callable() + module_name = LocalCallable.__module__ + self.assertEqual(module_name, LOCAL_ARTIFACTS_MODULE_NAME) + dynamic_module = sys.modules[module_name] + self.assertIs(getattr(dynamic_module, LocalCallable.__qualname__), LocalCallable) + self.assertIn("", getattr(LocalCallable, "__ccflow_dynamic_origin__")) + result = LocalCallable()(NullContext()) + self.assertEqual(result.value, "local") + + def test_multiple_local_definitions_have_unique_identifiers(self): + first = build_local_callable() + second = build_local_callable() + self.assertNotEqual(first.__qualname__, second.__qualname__) + dynamic_module = sys.modules[LOCAL_ARTIFACTS_MODULE_NAME] + self.assertIs(getattr(dynamic_module, first.__qualname__), first) + self.assertIs(getattr(dynamic_module, second.__qualname__), second) + + def test_context_and_callable_same_name_do_not_collide(self): + def build_conflicting(): + class LocalThing(ContextBase): + value: int + + context_cls = LocalThing + + class LocalThing(CallableModel): + @Flow.call + def __call__(self, context: context_cls) -> GenericResult: + return GenericResult(value=context.value) + + callable_cls = LocalThing + return context_cls, callable_cls + + LocalContext, LocalCallable = build_conflicting() + locals_module = sys.modules[LOCAL_ARTIFACTS_MODULE_NAME] + ctx_attr = _find_registered_name(locals_module, LocalContext) + model_attr = _find_registered_name(locals_module, LocalCallable) + self.assertTrue(ctx_attr.startswith("context__")) + self.assertTrue(model_attr.startswith("callable_model__")) + ctx_hint = ctx_attr.partition("__")[2].rsplit("__", 1)[0] + model_hint = model_attr.partition("__")[2].rsplit("__", 1)[0] + self.assertEqual(ctx_hint, model_hint) + self.assertEqual(getattr(locals_module, ctx_attr), LocalContext) + self.assertEqual(getattr(locals_module, model_attr), LocalCallable) + self.assertNotEqual(ctx_attr, model_attr, "Kind-prefixed names keep contexts and callables distinct.") + + def test_local_callable_type_path_roundtrip(self): + LocalCallable = build_local_callable() + instance = LocalCallable() + path = instance.type_ + self.assertEqual(path.object, LocalCallable) + self.assertTrue(str(path).startswith(f"{LOCAL_ARTIFACTS_MODULE_NAME}.")) + + def test_local_context_type_path_roundtrip(self): + LocalContext = build_local_context() + ctx = LocalContext(value=10) + path = ctx.type_ + self.assertEqual(path.object, LocalContext) + self.assertTrue(str(path).startswith(f"{LOCAL_ARTIFACTS_MODULE_NAME}.")) + + def test_exec_defined_main_module_class_registered(self): + MainCallable = _build_main_module_callable() + self.assertEqual(MainCallable.__module__, LOCAL_ARTIFACTS_MODULE_NAME) + self.assertTrue(getattr(MainCallable, "__ccflow_dynamic_origin__").startswith("__main__.")) + model = MainCallable() + MainCallable.call_count = 0 + result = model(NullContext()) + self.assertEqual(result.value, "main") + self.assertEqual(MainCallable.call_count, 1) + + def test_local_context_and_model_serialization_roundtrip(self): + class LocalContext(ContextBase): + value: int + + class LocalModel(CallableModel): + factor: int = 2 + + @Flow.call + def __call__(self, context: LocalContext) -> GenericResult: + return GenericResult(value=context.value * self.factor) + + instance = LocalModel(factor=5) + context = LocalContext(value=7) + serialized_model = instance.model_dump(mode="python") + restored_model = LocalModel.model_validate(serialized_model) + self.assertEqual(restored_model, instance) + serialized_context = context.model_dump(mode="python") + restored_context = LocalContext.model_validate(serialized_context) + self.assertEqual(restored_context, context) + + def test_multiple_nested_levels_unique_paths(self): + created = [] + + def layer(depth: int): + class LocalContext(ContextBase): + value: int + + class LocalModel(CallableModel): + multiplier: int = depth + 1 + call_count: ClassVar[int] = 0 + + @Flow.call + def __call__(self, context: LocalContext) -> GenericResult: + type(self).call_count += 1 + return GenericResult(value=context.value * self.multiplier) + + created.append((depth, LocalContext, LocalModel)) + + if depth < 2: + + def inner(): + layer(depth + 1) + + inner() + + def sibling_group(): + class LocalContext(ContextBase): + value: int + + class LocalModel(CallableModel): + @Flow.call + def __call__(self, context: LocalContext) -> GenericResult: + return GenericResult(value=context.value) + + created.append(("sibling", LocalContext, LocalModel)) + + layer(0) + sibling_group() + sibling_group() + + locals_module = sys.modules[LOCAL_ARTIFACTS_MODULE_NAME] + + context_names = {ctx.__qualname__ for _, ctx, _ in created} + model_names = {model.__qualname__ for _, _, model in created} + self.assertEqual(len(context_names), len(created)) + self.assertEqual(len(model_names), len(created)) + + for _, ctx_cls, model_cls in created: + self.assertIs(getattr(locals_module, ctx_cls.__qualname__), ctx_cls) + self.assertIs(getattr(locals_module, model_cls.__qualname__), model_cls) + self.assertIn("", getattr(ctx_cls, "__ccflow_dynamic_origin__")) + self.assertIn("", getattr(model_cls, "__ccflow_dynamic_origin__")) + + class TestWrapperModel(TestCase): def test_wrapper(self): md = MetaData(name="foo", description="My Foo") diff --git a/ccflow/tests/test_local_persistence.py b/ccflow/tests/test_local_persistence.py new file mode 100644 index 0000000..dccd2ef --- /dev/null +++ b/ccflow/tests/test_local_persistence.py @@ -0,0 +1,81 @@ +from collections import defaultdict +from itertools import count +from unittest import TestCase, mock + +import ccflow.local_persistence as local_persistence +from ccflow import BaseModel, CallableModel, ContextBase, Flow, GenericResult, NullContext + + +class ModuleLevelModel(BaseModel): + value: int + + +class ModuleLevelContext(ContextBase): + value: int + + +class ModuleLevelCallable(CallableModel): + @Flow.call + def __call__(self, context: NullContext) -> GenericResult: + return GenericResult(value="ok") + + +class TestNeedsRegistration(TestCase): + def test_module_level_ccflow_classes_do_not_need_registration(self): + for cls in (ModuleLevelModel, ModuleLevelContext, ModuleLevelCallable): + with self.subTest(cls=cls): + self.assertFalse(local_persistence._needs_registration(cls)) + + def test_local_scope_class_needs_registration(self): + def build_class(): + class LocalClass: + pass + + return LocalClass + + LocalClass = build_class() + self.assertTrue(local_persistence._needs_registration(LocalClass)) + + def test_main_module_class_needs_registration(self): + cls = type("MainModuleClass", (), {}) + cls.__module__ = "__main__" + cls.__qualname__ = "MainModuleClass" + self.assertTrue(local_persistence._needs_registration(cls)) + + def test_module_level_non_ccflow_class_does_not_need_registration(self): + cls = type("ExternalClass", (), {}) + cls.__module__ = "ccflow.tests.test_local_persistence" + cls.__qualname__ = "ExternalClass" + self.assertFalse(local_persistence._needs_registration(cls)) + + +class TestBuildUniqueName(TestCase): + def test_build_unique_name_sanitizes_hint_and_increments_counter(self): + with mock.patch.object(local_persistence, "_LOCAL_KIND_COUNTERS", defaultdict(lambda: count())): + name = local_persistence._build_unique_name( + kind_slug="callable_model", + name_hint="module.path:MyCallable", + ) + self.assertTrue(name.startswith("callable_model__module_path_MyCallable_One_")) + self.assertTrue(name.endswith("__0")) + + second = local_persistence._build_unique_name( + kind_slug="callable_model", + name_hint="module.path:MyCallable", + ) + self.assertTrue(second.endswith("__1")) + + def test_counters_are_namespaced_by_kind(self): + with mock.patch.object(local_persistence, "_LOCAL_KIND_COUNTERS", defaultdict(lambda: count())): + first_context = local_persistence._build_unique_name(kind_slug="context", name_hint="Context") + first_callable = local_persistence._build_unique_name(kind_slug="callable_model", name_hint="Callable") + second_context = local_persistence._build_unique_name(kind_slug="context", name_hint="Other") + + self.assertTrue(first_context.endswith("__0")) + self.assertTrue(first_callable.endswith("__0")) + self.assertTrue(second_context.endswith("__1")) + + def test_empty_hint_uses_fallback(self): + with mock.patch.object(local_persistence, "_LOCAL_KIND_COUNTERS", defaultdict(lambda: count())): + name = local_persistence._build_unique_name(kind_slug="model", name_hint="") + self.assertEqual(name, "model__BaseModel__0") diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 616e3d8..1f5bd18 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -22,6 +22,9 @@ The naming was inspired by the open source library [Pydantic](https://docs.pydan `CallableModel`'s are called with a context (something that derives from `ContextBase`) and returns a result (something that derives from `ResultBase`). As an example, you may have a `SQLReader` callable model that when called with a `DateRangeContext` returns a `ArrowResult` (wrapper around a Arrow table) with data in the date range defined by the context by querying some SQL database. +> [!NOTE] +> `CallableModel`, `ContextBase`, and other `ccflow.BaseModel` subclasses declared inside local scopes (functions, tests, notebooks, REPLs, etc.) are automatically persisted under `ccflow._local_artifacts`. Each stored class is prefixed with its kind (for example, `callable_model__...` versus `context__...`) so PyObjectPath-based evaluators can serialize/deserialize them without collisions. The backing module is append-only; long-lived processes should avoid generating unbounded unique classes if cleanup is required. + ## Model Registry A `ModelRegistry` is a named collection of models. From 7d852437010b4be02ffb386aba99d57357694c59 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Fri, 19 Dec 2025 19:52:15 -0500 Subject: [PATCH 2/8] Try to improve code coverage Signed-off-by: Nijat Khanbabayev --- ccflow/tests/test_base.py | 9 ++++++--- ccflow/tests/test_callable.py | 12 +++++++++++- docs/wiki/Key-Features.md | 2 +- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/ccflow/tests/test_base.py b/ccflow/tests/test_base.py index d6d9413..656d841 100644 --- a/ccflow/tests/test_base.py +++ b/ccflow/tests/test_base.py @@ -205,10 +205,13 @@ class LocalCallable(CallableModel): def __call__(self, context: NullContext) -> GenericResult: return GenericResult(value="ok") - register.assert_called_once() - args, kwargs = register.call_args - self.assertIs(args[0], LocalCallable) + result = LocalCallable()(NullContext()) + + calls_for_local = [(args, kwargs) for args, kwargs in register.call_args_list if args and args[0] is LocalCallable] + self.assertEqual(len(calls_for_local), 1) + _, kwargs = calls_for_local[0] self.assertEqual(kwargs["kind"], "callable_model") + self.assertEqual(result.value, "ok") def test_explicit_override_respected(self): with mock.patch("ccflow.base.register_local_subclass") as register: diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 5b0f0a9..80ada25 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -623,12 +623,16 @@ def __call__(self, context: LocalContext) -> GenericResult: instance = LocalModel(factor=5) context = LocalContext(value=7) + result = instance(context) + self.assertEqual(result.value, 35) serialized_model = instance.model_dump(mode="python") restored_model = LocalModel.model_validate(serialized_model) self.assertEqual(restored_model, instance) serialized_context = context.model_dump(mode="python") restored_context = LocalContext.model_validate(serialized_context) self.assertEqual(restored_context, context) + restored_result = restored_model(restored_context) + self.assertEqual(restored_result.value, 35) def test_multiple_nested_levels_unique_paths(self): created = [] @@ -677,11 +681,17 @@ def __call__(self, context: LocalContext) -> GenericResult: self.assertEqual(len(context_names), len(created)) self.assertEqual(len(model_names), len(created)) - for _, ctx_cls, model_cls in created: + for label, ctx_cls, model_cls in created: self.assertIs(getattr(locals_module, ctx_cls.__qualname__), ctx_cls) self.assertIs(getattr(locals_module, model_cls.__qualname__), model_cls) self.assertIn("", getattr(ctx_cls, "__ccflow_dynamic_origin__")) self.assertIn("", getattr(model_cls, "__ccflow_dynamic_origin__")) + ctx_instance = ctx_cls(value=4) + result = model_cls()(ctx_instance) + if isinstance(label, str): # sibling group + self.assertEqual(result.value, ctx_instance.value) + else: + self.assertEqual(result.value, ctx_instance.value * (label + 1)) class TestWrapperModel(TestCase): diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 1f5bd18..ea78c63 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -23,7 +23,7 @@ The naming was inspired by the open source library [Pydantic](https://docs.pydan As an example, you may have a `SQLReader` callable model that when called with a `DateRangeContext` returns a `ArrowResult` (wrapper around a Arrow table) with data in the date range defined by the context by querying some SQL database. > [!NOTE] -> `CallableModel`, `ContextBase`, and other `ccflow.BaseModel` subclasses declared inside local scopes (functions, tests, notebooks, REPLs, etc.) are automatically persisted under `ccflow._local_artifacts`. Each stored class is prefixed with its kind (for example, `callable_model__...` versus `context__...`) so PyObjectPath-based evaluators can serialize/deserialize them without collisions. The backing module is append-only; long-lived processes should avoid generating unbounded unique classes if cleanup is required. +> `CallableModel`, `ContextBase`, and other `ccflow.BaseModel` subclasses declared inside local scopes (functions, tests, etc.) are automatically persisted under `ccflow._local_artifacts`. Each stored class is prefixed with its kind (for example, `callable_model__...` versus `context__...`) so PyObjectPath-based evaluators can serialize/deserialize them without collisions. The backing module is append-only; long-lived processes should avoid generating unbounded unique classes if cleanup is required. ## Model Registry From 9fb2a924c8582cdb23461935090d0d1c11078ccd Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Fri, 19 Dec 2025 20:20:17 -0500 Subject: [PATCH 3/8] Only register not __main__ Signed-off-by: Nijat Khanbabayev --- ccflow/local_persistence.py | 3 +-- ccflow/tests/test_callable.py | 34 -------------------------- ccflow/tests/test_local_persistence.py | 4 +-- 3 files changed, 3 insertions(+), 38 deletions(-) diff --git a/ccflow/local_persistence.py b/ccflow/local_persistence.py index 7331a9b..7a2cc7b 100644 --- a/ccflow/local_persistence.py +++ b/ccflow/local_persistence.py @@ -37,9 +37,8 @@ def _ensure_module(name: str, doc: str) -> ModuleType: def _needs_registration(cls: Type[Any]) -> bool: - module = getattr(cls, "__module__", "") qualname = getattr(cls, "__qualname__", "") - return "" in qualname or module == "__main__" + return "" in qualname def _sanitize_identifier(value: str, fallback: str) -> str: diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 80ada25..329c21a 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -33,30 +33,6 @@ def _find_registered_name(module, cls): raise AssertionError(f"{cls} not found in {module.__name__}") -def _build_main_module_callable(): - namespace = { - "__name__": "__main__", - "ClassVar": ClassVar, - "CallableModel": CallableModel, - "Flow": Flow, - "GenericResult": GenericResult, - "NullContext": NullContext, - } - exec( - """ -class MainModuleCallable(CallableModel): - call_count: ClassVar[int] = 0 - - @Flow.call - def __call__(self, context: NullContext) -> GenericResult: - type(self).call_count += 1 - return GenericResult(value=\"main\") -""", - namespace, - ) - return namespace["MainModuleCallable"] - - class MyContext(ContextBase): a: str @@ -600,16 +576,6 @@ def test_local_context_type_path_roundtrip(self): self.assertEqual(path.object, LocalContext) self.assertTrue(str(path).startswith(f"{LOCAL_ARTIFACTS_MODULE_NAME}.")) - def test_exec_defined_main_module_class_registered(self): - MainCallable = _build_main_module_callable() - self.assertEqual(MainCallable.__module__, LOCAL_ARTIFACTS_MODULE_NAME) - self.assertTrue(getattr(MainCallable, "__ccflow_dynamic_origin__").startswith("__main__.")) - model = MainCallable() - MainCallable.call_count = 0 - result = model(NullContext()) - self.assertEqual(result.value, "main") - self.assertEqual(MainCallable.call_count, 1) - def test_local_context_and_model_serialization_roundtrip(self): class LocalContext(ContextBase): value: int diff --git a/ccflow/tests/test_local_persistence.py b/ccflow/tests/test_local_persistence.py index dccd2ef..bebf0bf 100644 --- a/ccflow/tests/test_local_persistence.py +++ b/ccflow/tests/test_local_persistence.py @@ -36,11 +36,11 @@ class LocalClass: LocalClass = build_class() self.assertTrue(local_persistence._needs_registration(LocalClass)) - def test_main_module_class_needs_registration(self): + def test_main_module_class_does_not_need_registration(self): cls = type("MainModuleClass", (), {}) cls.__module__ = "__main__" cls.__qualname__ = "MainModuleClass" - self.assertTrue(local_persistence._needs_registration(cls)) + self.assertFalse(local_persistence._needs_registration(cls)) def test_module_level_non_ccflow_class_does_not_need_registration(self): cls = type("ExternalClass", (), {}) From 32ae2196d2679a44a85dc99a6f5cec0a20f17173 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Fri, 19 Dec 2025 20:35:22 -0500 Subject: [PATCH 4/8] Add dummy example with CallableModel making callable models Signed-off-by: Nijat Khanbabayev --- ccflow/tests/evaluators/test_common.py | 18 +++++++++- ccflow/tests/local_helpers.py | 46 +++++++++++++++++++++++++- ccflow/tests/test_callable.py | 17 +++++++++- 3 files changed, 78 insertions(+), 3 deletions(-) diff --git a/ccflow/tests/evaluators/test_common.py b/ccflow/tests/evaluators/test_common.py index 4a4b503..fa85958 100644 --- a/ccflow/tests/evaluators/test_common.py +++ b/ccflow/tests/evaluators/test_common.py @@ -29,7 +29,7 @@ combine_evaluators, get_dependency_graph, ) -from ccflow.tests.local_helpers import build_nested_graph_chain +from ccflow.tests.local_helpers import build_meta_sensor_planner, build_nested_graph_chain from .util import CircularModel, MyDateCallable, MyDateRangeCallable, MyRaisingCallable, NodeModel, ResultModel @@ -220,6 +220,22 @@ def test_logging_options_nested(self): self.assertIn("End evaluation of __call__", captured.records[2].getMessage()) self.assertIn("time elapsed", captured.records[2].getMessage()) + def test_meta_callable_logged_with_evaluator(self): + """Meta callables can spin up request-scoped specialists and still inherit evaluator instrumentation.""" + SensorQuery, MetaSensorPlanner, captured = build_meta_sensor_planner() + evaluator = LoggingEvaluator(log_level=logging.INFO, verbose=False) + request = SensorQuery(sensor_type="pressure-valve", site="orbital-lab", window=4) + meta = MetaSensorPlanner(warm_start=2) + with FlowOptionsOverride(options=FlowOptions(evaluator=evaluator)): + with self.assertLogs(level=logging.INFO) as captured_logs: + result = meta(request) + self.assertEqual(result.value, "planner:orbital-lab:pressure-valve:6") + start_messages = [record.getMessage() for record in captured_logs.records if "Start evaluation" in record.getMessage()] + self.assertEqual(len(start_messages), 2) + self.assertTrue(any("MetaSensorPlanner" in msg for msg in start_messages)) + specialist_name = captured["callable_cls"].__name__ + self.assertTrue(any(specialist_name in msg for msg in start_messages)) + class SubContext(DateContext): pass diff --git a/ccflow/tests/local_helpers.py b/ccflow/tests/local_helpers.py index 5f616cf..b5ffcdf 100644 --- a/ccflow/tests/local_helpers.py +++ b/ccflow/tests/local_helpers.py @@ -1,10 +1,54 @@ """Shared helpers for constructing local-scope contexts/models in tests.""" -from typing import ClassVar, Tuple, Type +from typing import ClassVar, Dict, Tuple, Type from ccflow import CallableModel, ContextBase, Flow, GenericResult, GraphDepList, NullContext +def build_meta_sensor_planner(): + """Return a (SensorQuery, MetaSensorPlanner, captured) tuple for meta-callable tests.""" + + captured: Dict[str, Type] = {} + + class SensorQuery(ContextBase): + sensor_type: str + site: str + window: int + + class MetaSensorPlanner(CallableModel): + warm_start: int = 2 + + @Flow.call + def __call__(self, context: SensorQuery) -> GenericResult: + # Define request-scoped specialist wiring with a bespoke context/model pair. + class SpecialistContext(ContextBase): + sensor_type: str + window: int + pipeline: str + + class SpecialistCallable(CallableModel): + pipeline: str + + @Flow.call + def __call__(self, context: SpecialistContext) -> GenericResult: + payload = f"{self.pipeline}:{context.sensor_type}:{context.window}" + return GenericResult(value=payload) + + captured["context_cls"] = SpecialistContext + captured["callable_cls"] = SpecialistCallable + + window = context.window + self.warm_start + local_context = SpecialistContext( + sensor_type=context.sensor_type, + window=window, + pipeline=f"{context.site}-calibration", + ) + specialist = SpecialistCallable(pipeline=f"planner:{context.site}") + return specialist(local_context) + + return SensorQuery, MetaSensorPlanner, captured + + def build_local_callable(name: str = "LocalCallable") -> Type[CallableModel]: class _LocalCallable(CallableModel): @Flow.call diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 329c21a..fd4a77c 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -23,7 +23,7 @@ WrapperModel, ) from ccflow.local_persistence import LOCAL_ARTIFACTS_MODULE_NAME -from ccflow.tests.local_helpers import build_local_callable, build_local_context +from ccflow.tests.local_helpers import build_local_callable, build_local_context, build_meta_sensor_planner def _find_registered_name(module, cls): @@ -659,6 +659,21 @@ def __call__(self, context: LocalContext) -> GenericResult: else: self.assertEqual(result.value, ctx_instance.value * (label + 1)) + def test_meta_callable_builds_dynamic_specialist(self): + SensorQuery, MetaSensorPlanner, captured = build_meta_sensor_planner() + request = SensorQuery(sensor_type="wind-turbine", site="ridge-line", window=5) + meta = MetaSensorPlanner(warm_start=3) + result = meta(request) + self.assertEqual(result.value, "planner:ridge-line:wind-turbine:8") + + locals_module = sys.modules[LOCAL_ARTIFACTS_MODULE_NAME] + SpecialistContext = captured["context_cls"] + SpecialistCallable = captured["callable_cls"] + self.assertEqual(SpecialistContext.__module__, LOCAL_ARTIFACTS_MODULE_NAME) + self.assertEqual(SpecialistCallable.__module__, LOCAL_ARTIFACTS_MODULE_NAME) + self.assertIs(getattr(locals_module, SpecialistContext.__qualname__), SpecialistContext) + self.assertIs(getattr(locals_module, SpecialistCallable.__qualname__), SpecialistCallable) + class TestWrapperModel(TestCase): def test_wrapper(self): From 7504a89135b61b9c3f81a94784cc2aca8226310b Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Fri, 19 Dec 2025 21:37:38 -0500 Subject: [PATCH 5/8] Add tests for pickling and unpickling local callable models and contexts Signed-off-by: Nijat Khanbabayev --- ccflow/tests/test_callable.py | 38 +++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index fd4a77c..0d4a642 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -33,6 +33,20 @@ def _find_registered_name(module, cls): raise AssertionError(f"{cls} not found in {module.__name__}") +def create_sensor_scope_models(): + class SensorScope(ContextBase): + reading: int + + class SensorCalibrator(CallableModel): + offset: int = 1 + + @Flow.call + def __call__(self, context: SensorScope) -> GenericResult: + return GenericResult(value=context.reading + self.offset) + + return SensorScope, SensorCalibrator + + class MyContext(ContextBase): a: str @@ -674,6 +688,30 @@ def test_meta_callable_builds_dynamic_specialist(self): self.assertIs(getattr(locals_module, SpecialistContext.__qualname__), SpecialistContext) self.assertIs(getattr(locals_module, SpecialistCallable.__qualname__), SpecialistCallable) + def test_dynamic_factory_pickle_roundtrip(self): + serializers = [(pdumps, ploads), (rcpdumps, rcploads)] + for dumps, loads in serializers: + factory = loads(dumps(create_sensor_scope_models)) + SensorContext, SensorModel = factory() + self.assertEqual(SensorContext.__module__, LOCAL_ARTIFACTS_MODULE_NAME) + self.assertEqual(SensorModel.__module__, LOCAL_ARTIFACTS_MODULE_NAME) + locals_module = sys.modules[LOCAL_ARTIFACTS_MODULE_NAME] + self.assertIs(getattr(locals_module, SensorContext.__qualname__), SensorContext) + self.assertIs(getattr(locals_module, SensorModel.__qualname__), SensorModel) + + context = SensorContext(reading=41) + model = SensorModel(offset=1) + self.assertEqual(model(context).value, 42) + + restored_context_cls = loads(dumps(SensorContext)) + restored_model_cls = loads(dumps(SensorModel)) + self.assertIs(restored_context_cls, SensorContext) + self.assertIs(restored_model_cls, SensorModel) + + restored_context = loads(dumps(context)) + restored_model = loads(dumps(model)) + self.assertEqual(restored_model(restored_context).value, 42) + class TestWrapperModel(TestCase): def test_wrapper(self): From dea6dae884849eb4b0a6f06b4155120ec754b57d Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Mon, 22 Dec 2025 04:55:05 -0500 Subject: [PATCH 6/8] Utilize finder/loader pattern for dynamic module, add subprocess tests for isolation. Signed-off-by: Nijat Khanbabayev --- ccflow/local_persistence.py | 84 +++++++++++++++++++++++--- ccflow/tests/test_global_state.py | 5 +- ccflow/tests/test_local_persistence.py | 80 +++++++++++++++++++++++- 3 files changed, 157 insertions(+), 12 deletions(-) diff --git a/ccflow/local_persistence.py b/ccflow/local_persistence.py index 7a2cc7b..e173410 100644 --- a/ccflow/local_persistence.py +++ b/ccflow/local_persistence.py @@ -2,12 +2,17 @@ from __future__ import annotations +import importlib.abc +import importlib.util import re import sys from collections import defaultdict from itertools import count -from types import ModuleType -from typing import Any, DefaultDict, Type +from typing import TYPE_CHECKING, Any, DefaultDict, Type + +if TYPE_CHECKING: + from importlib.machinery import ModuleSpec + from types import ModuleType __all__ = ("LOCAL_ARTIFACTS_MODULE_NAME", "register_local_subclass") @@ -17,13 +22,70 @@ _SANITIZE_PATTERN = re.compile(r"[^0-9A-Za-z_]") _LOCAL_KIND_COUNTERS: DefaultDict[str, count] = defaultdict(lambda: count()) +_LOCAL_ARTIFACTS_MODULE: "ModuleType | None" = None + + +class _LocalArtifactsLoader(importlib.abc.Loader): + """Minimal loader so importlib can reload our synthetic module if needed.""" + + def __init__(self, *, doc: str) -> None: + self._doc = doc + + def create_module(self, spec: "ModuleSpec") -> "ModuleType | None": + """Defer to default module creation (keeping importlib from recursing).""" + return None + + def exec_module(self, module: "ModuleType") -> None: + module.__doc__ = module.__doc__ or self._doc + + +class _LocalArtifactsFinder(importlib.abc.MetaPathFinder): + """Ensures importlib can rediscover the synthetic module when reloading.""" + + def find_spec( + self, + fullname: str, + path: Any, + target: "ModuleType | None" = None, + ) -> "ModuleSpec | None": + if fullname != LOCAL_ARTIFACTS_MODULE_NAME: + return None + return _build_module_spec(fullname, _LOCAL_MODULE_DOC) + + +def _build_module_spec(name: str, doc: str) -> "ModuleSpec": + loader = _LocalArtifactsLoader(doc=doc) + spec = importlib.util.spec_from_loader( + name, + loader=loader, + origin="ccflow.local_persistence:_ensure_module", + ) + if spec is None: + raise ImportError(f"Unable to create spec for dynamic module {name!r}.") + spec.has_location = False + return spec + + +def _ensure_finder_installed() -> None: + for finder in sys.meta_path: + if isinstance(finder, _LocalArtifactsFinder): + return + sys.meta_path.insert(0, _LocalArtifactsFinder()) + +def _ensure_module(name: str, doc: str) -> "ModuleType": + """Materialize the synthetic module with a real spec/loader so importlib treats it like disk-backed code. -def _ensure_module(name: str, doc: str) -> ModuleType: - """Ensure the dynamic module exists so import paths remain stable.""" + We only do this on demand, but once built the finder/spec/loader plumbing + keeps reload, pickling, and other importlib consumers happy. The Python docs recommend this approach instead of creating modules directly via the constructor.""" + _ensure_finder_installed() module = sys.modules.get(name) if module is None: - module = ModuleType(name, doc) + # Create a proper ModuleSpec + loader so importlib reloads and introspection behave + # the same way they would for filesystem-backed modules + spec = _build_module_spec(name, doc) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) # type: ignore[union-attr] sys.modules[name] = module parent_name, _, attr = name.rpartition(".") if parent_name: @@ -33,7 +95,12 @@ def _ensure_module(name: str, doc: str) -> ModuleType: return module -_LOCAL_ARTIFACTS_MODULE = _ensure_module(LOCAL_ARTIFACTS_MODULE_NAME, _LOCAL_MODULE_DOC) +def _get_local_artifacts_module() -> "ModuleType": + """Lazily materialize the synthetic module to avoid errors during creation until needed.""" + global _LOCAL_ARTIFACTS_MODULE + if _LOCAL_ARTIFACTS_MODULE is None: + _LOCAL_ARTIFACTS_MODULE = _ensure_module(LOCAL_ARTIFACTS_MODULE_NAME, _LOCAL_MODULE_DOC) + return _LOCAL_ARTIFACTS_MODULE def _needs_registration(cls: Type[Any]) -> bool: @@ -65,7 +132,8 @@ def register_local_subclass(cls: Type[Any], *, kind: str = "model") -> None: name_hint = f"{getattr(cls, '__module__', '')}.{getattr(cls, '__qualname__', '')}" kind_slug = _sanitize_identifier(kind, "model").lower() unique_name = _build_unique_name(kind_slug=kind_slug, name_hint=name_hint) - setattr(_LOCAL_ARTIFACTS_MODULE, unique_name, cls) - cls.__module__ = _LOCAL_ARTIFACTS_MODULE.__name__ + artifacts_module = _get_local_artifacts_module() + setattr(artifacts_module, unique_name, cls) + cls.__module__ = artifacts_module.__name__ cls.__qualname__ = unique_name setattr(cls, "__ccflow_dynamic_origin__", name_hint) diff --git a/ccflow/tests/test_global_state.py b/ccflow/tests/test_global_state.py index e5b0d68..470afec 100644 --- a/ccflow/tests/test_global_state.py +++ b/ccflow/tests/test_global_state.py @@ -48,9 +48,8 @@ def test_global_state(root_registry): assert state3.open_overrides == {} -def test_global_state_pickle(): - r = ModelRegistry.root() - r.add("foo", DummyModel(name="foo")) +def test_global_state_pickle(root_registry): + root_registry.add("foo", DummyModel(name="foo")) evaluator = DummyEvaluator() with FlowOptionsOverride(options=dict(evaluator=evaluator)) as override: state = GlobalState() diff --git a/ccflow/tests/test_local_persistence.py b/ccflow/tests/test_local_persistence.py index bebf0bf..e9121e6 100644 --- a/ccflow/tests/test_local_persistence.py +++ b/ccflow/tests/test_local_persistence.py @@ -1,3 +1,6 @@ +import subprocess +import sys +import textwrap from collections import defaultdict from itertools import count from unittest import TestCase, mock @@ -23,7 +26,7 @@ def __call__(self, context: NullContext) -> GenericResult: class TestNeedsRegistration(TestCase): def test_module_level_ccflow_classes_do_not_need_registration(self): for cls in (ModuleLevelModel, ModuleLevelContext, ModuleLevelCallable): - with self.subTest(cls=cls): + with self.subTest(cls=cls.__name__): self.assertFalse(local_persistence._needs_registration(cls)) def test_local_scope_class_needs_registration(self): @@ -79,3 +82,78 @@ def test_empty_hint_uses_fallback(self): with mock.patch.object(local_persistence, "_LOCAL_KIND_COUNTERS", defaultdict(lambda: count())): name = local_persistence._build_unique_name(kind_slug="model", name_hint="") self.assertEqual(name, "model__BaseModel__0") + + +def _run_subprocess(code: str) -> str: + """Execute code in a clean interpreter so sys.modules starts empty.""" + result = subprocess.run( + [sys.executable, "-c", textwrap.dedent(code)], + check=True, + capture_output=True, + text=True, + ) + return result.stdout.strip() + + +def test_local_artifacts_module_is_lazy(): + output = _run_subprocess( + """ + import sys + import ccflow.local_persistence as lp + + print(lp.LOCAL_ARTIFACTS_MODULE_NAME in sys.modules) + """ + ) + assert output == "False" + + +def test_local_artifacts_module_reload_preserves_dynamic_attrs_and_qualname(): + output = _run_subprocess( + """ + import importlib + import ccflow.local_persistence as lp + + def build_cls(): + class _Temp: + pass + return _Temp + + Temp = build_cls() + lp.register_local_subclass(Temp, kind="demo") + module = importlib.import_module(lp.LOCAL_ARTIFACTS_MODULE_NAME) + qual_before = Temp.__qualname__ + before = getattr(module, qual_before) is Temp + module = importlib.reload(module) + after = getattr(module, qual_before) is Temp + print(before, after, qual_before == Temp.__qualname__) + """ + ) + assert output.split() == ["True", "True", "True"] + + +def test_register_local_subclass_sets_module_qualname_and_origin(): + output = _run_subprocess( + """ + import sys + import ccflow.local_persistence as lp + + def build(): + class Foo: + pass + return Foo + + Foo = build() + lp.register_local_subclass(Foo, kind="ModelThing") + module = sys.modules[lp.LOCAL_ARTIFACTS_MODULE_NAME] + print(Foo.__module__) + print(Foo.__qualname__) + print(hasattr(module, Foo.__qualname__)) + print(Foo.__ccflow_dynamic_origin__) + """ + ) + lines = output.splitlines() + assert lines[0] == "ccflow._local_artifacts" + assert lines[2] == "True" + assert lines[3] == "__main__.build..Foo" + assert lines[1].startswith("modelthing__") + assert lines[1].endswith("__0") From 9e52ab94ef65f5754cdc96ed6640861e47c0a122 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Mon, 22 Dec 2025 05:00:23 -0500 Subject: [PATCH 7/8] Make dynamic module functions private Signed-off-by: Nijat Khanbabayev --- ccflow/base.py | 4 ++-- ccflow/local_persistence.py | 5 ++--- ccflow/tests/test_base.py | 8 ++++---- ccflow/tests/test_local_persistence.py | 4 ++-- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/ccflow/base.py b/ccflow/base.py index adfec2e..e111731 100644 --- a/ccflow/base.py +++ b/ccflow/base.py @@ -30,7 +30,7 @@ from typing_extensions import Self from .exttypes.pyobjectpath import PyObjectPath -from .local_persistence import register_local_subclass +from .local_persistence import _register_local_subclass log = logging.getLogger(__name__) @@ -164,7 +164,7 @@ def __pydantic_init_subclass__(cls, **kwargs): # __pydantic_init_subclass__ is the supported hook point once Pydantic finishes wiring the subclass. super().__pydantic_init_subclass__(**kwargs) kind = getattr(cls, "__ccflow_local_registration_kind__", "model") - register_local_subclass(cls, kind=kind) + _register_local_subclass(cls, kind=kind) @computed_field( alias="_target_", diff --git a/ccflow/local_persistence.py b/ccflow/local_persistence.py index e173410..478b2b1 100644 --- a/ccflow/local_persistence.py +++ b/ccflow/local_persistence.py @@ -14,8 +14,7 @@ from importlib.machinery import ModuleSpec from types import ModuleType -__all__ = ("LOCAL_ARTIFACTS_MODULE_NAME", "register_local_subclass") - +__all__ = ("LOCAL_ARTIFACTS_MODULE_NAME",) LOCAL_ARTIFACTS_MODULE_NAME = "ccflow._local_artifacts" _LOCAL_MODULE_DOC = "Auto-generated BaseModel subclasses created outside importable modules." @@ -122,7 +121,7 @@ def _build_unique_name(*, kind_slug: str, name_hint: str) -> str: return f"{kind_slug}__{sanitized_hint}__{next(counter)}" -def register_local_subclass(cls: Type[Any], *, kind: str = "model") -> None: +def _register_local_subclass(cls: Type[Any], *, kind: str = "model") -> None: """Register BaseModel subclasses created in local scopes.""" if getattr(cls, "__module__", "").startswith(LOCAL_ARTIFACTS_MODULE_NAME): return diff --git a/ccflow/tests/test_base.py b/ccflow/tests/test_base.py index 656d841..bd24cc1 100644 --- a/ccflow/tests/test_base.py +++ b/ccflow/tests/test_base.py @@ -176,7 +176,7 @@ def test_widget(self): class TestLocalRegistrationKind(TestCase): def test_base_model_defaults_to_model_kind(self): - with mock.patch("ccflow.base.register_local_subclass") as register: + with mock.patch("ccflow.base._register_local_subclass") as register: class LocalModel(BaseModel): value: int @@ -187,7 +187,7 @@ class LocalModel(BaseModel): self.assertEqual(kwargs["kind"], "model") def test_context_defaults_to_context_kind(self): - with mock.patch("ccflow.base.register_local_subclass") as register: + with mock.patch("ccflow.base._register_local_subclass") as register: class LocalContext(ContextBase): value: int @@ -198,7 +198,7 @@ class LocalContext(ContextBase): self.assertEqual(kwargs["kind"], "context") def test_callable_defaults_to_callable_kind(self): - with mock.patch("ccflow.base.register_local_subclass") as register: + with mock.patch("ccflow.base._register_local_subclass") as register: class LocalCallable(CallableModel): @Flow.call @@ -214,7 +214,7 @@ def __call__(self, context: NullContext) -> GenericResult: self.assertEqual(result.value, "ok") def test_explicit_override_respected(self): - with mock.patch("ccflow.base.register_local_subclass") as register: + with mock.patch("ccflow.base._register_local_subclass") as register: class CustomKind(BaseModel): __ccflow_local_registration_kind__ = "custom" diff --git a/ccflow/tests/test_local_persistence.py b/ccflow/tests/test_local_persistence.py index e9121e6..72d6be4 100644 --- a/ccflow/tests/test_local_persistence.py +++ b/ccflow/tests/test_local_persistence.py @@ -119,7 +119,7 @@ class _Temp: return _Temp Temp = build_cls() - lp.register_local_subclass(Temp, kind="demo") + lp._register_local_subclass(Temp, kind="demo") module = importlib.import_module(lp.LOCAL_ARTIFACTS_MODULE_NAME) qual_before = Temp.__qualname__ before = getattr(module, qual_before) is Temp @@ -143,7 +143,7 @@ class Foo: return Foo Foo = build() - lp.register_local_subclass(Foo, kind="ModelThing") + lp._register_local_subclass(Foo, kind="ModelThing") module = sys.modules[lp.LOCAL_ARTIFACTS_MODULE_NAME] print(Foo.__module__) print(Foo.__qualname__) From ff6acf168437929b1453e7882f79060f736435f0 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Fri, 26 Dec 2025 00:58:30 -0500 Subject: [PATCH 8/8] Adjust local module persistence to work with cloudpickle across processes Signed-off-by: Nijat Khanbabayev --- ccflow/exttypes/pyobjectpath.py | 11 +- ccflow/local_persistence.py | 70 ++- ccflow/tests/test_callable.py | 139 +++-- ccflow/tests/test_local_persistence.py | 692 ++++++++++++++++++++++++- 4 files changed, 845 insertions(+), 67 deletions(-) diff --git a/ccflow/exttypes/pyobjectpath.py b/ccflow/exttypes/pyobjectpath.py index 9c91b2f..9ef2bed 100644 --- a/ccflow/exttypes/pyobjectpath.py +++ b/ccflow/exttypes/pyobjectpath.py @@ -8,6 +8,8 @@ from pydantic_core import core_schema from typing_extensions import Self +from ccflow.local_persistence import _ensure_registered_at_import_path + _import_string_adapter = TypeAdapter(ImportString) @@ -56,7 +58,14 @@ def _validate(cls, value: Any): origin = get_origin(value) if origin: value = origin - if hasattr(value, "__module__") and hasattr(value, "__qualname__"): + + # Check for ccflow's import path override first (used for local-scope classes) + # This allows classes with '' in __qualname__ to remain importable + # while preserving cloudpickle's ability to serialize the class definition + if hasattr(value, "__ccflow_import_path__"): + _ensure_registered_at_import_path(value) + value = cls(value.__ccflow_import_path__) + elif hasattr(value, "__module__") and hasattr(value, "__qualname__"): if value.__module__ == "__builtin__": module = "builtins" else: diff --git a/ccflow/local_persistence.py b/ccflow/local_persistence.py index 478b2b1..5fd0efa 100644 --- a/ccflow/local_persistence.py +++ b/ccflow/local_persistence.py @@ -1,4 +1,20 @@ -"""Helpers for persisting BaseModel-derived classes defined inside local scopes.""" +"""Helpers for persisting BaseModel-derived classes defined inside local scopes. + +This module enables PyObjectPath validation for classes defined inside functions (which have +'' in their __qualname__ and aren't normally importable). + +Key design decision: We DON'T modify __module__ or __qualname__. This preserves cloudpickle's +ability to serialize the class definition for cross-process transfer. Instead, we set a +separate __ccflow_import_path__ attribute that PyObjectPath uses. + +Cross-process cloudpickle flow: +1. Process A creates a local class -> we set __ccflow_import_path__ on it +2. cloudpickle.dumps() serializes the class definition (because '' in __qualname__) + INCLUDING the __ccflow_import_path__ attribute we set +3. Process B: cloudpickle.loads() reconstructs the class with __ccflow_import_path__ already set +4. Process B: __pydantic_init_subclass__ runs, sees __ccflow_import_path__ exists, + re-registers the class in this process's _local_artifacts module +""" from __future__ import annotations @@ -121,10 +137,49 @@ def _build_unique_name(*, kind_slug: str, name_hint: str) -> str: return f"{kind_slug}__{sanitized_hint}__{next(counter)}" +def _ensure_registered_at_import_path(cls: Type[Any]) -> None: + """Ensure a class with __ccflow_import_path__ is actually registered in _local_artifacts. + + This handles the cross-process cloudpickle case: when cloudpickle reconstructs a class, + it has __ccflow_import_path__ set (serialized with the class definition), but the class + isn't registered in _local_artifacts in the new process yet. + + Called from both _register_local_subclass (during class creation/unpickling) and + PyObjectPath validation (when accessing type_). + """ + import_path = getattr(cls, "__ccflow_import_path__", None) + if import_path is None or not import_path.startswith(LOCAL_ARTIFACTS_MODULE_NAME + "."): + return + + registered_name = import_path.rsplit(".", 1)[-1] + artifacts_module = _get_local_artifacts_module() + + # Re-register if not present or points to different class + if getattr(artifacts_module, registered_name, None) is not cls: + setattr(artifacts_module, registered_name, cls) + + def _register_local_subclass(cls: Type[Any], *, kind: str = "model") -> None: - """Register BaseModel subclasses created in local scopes.""" - if getattr(cls, "__module__", "").startswith(LOCAL_ARTIFACTS_MODULE_NAME): + """Register BaseModel subclasses created in local scopes. + + This enables PyObjectPath validation for classes that aren't normally importable + (e.g., classes defined inside functions). The class is registered in a synthetic + module (`ccflow._local_artifacts`) so it can be imported via the stored path. + + IMPORTANT: This function does NOT change __module__ or __qualname__. This is + intentional - it preserves cloudpickle's ability to serialize the class definition + for cross-process transfer. If __qualname__ contains '', cloudpickle + recognizes the class isn't normally importable and serializes its full definition. + + Args: + cls: The class to register. + kind: A slug identifying the type of class (e.g., "model", "context", "callable_model"). + """ + # If already has import path, just ensure it's registered (handles cross-process unpickling) + if hasattr(cls, "__ccflow_import_path__"): + _ensure_registered_at_import_path(cls) return + if not _needs_registration(cls): return @@ -133,6 +188,9 @@ def _register_local_subclass(cls: Type[Any], *, kind: str = "model") -> None: unique_name = _build_unique_name(kind_slug=kind_slug, name_hint=name_hint) artifacts_module = _get_local_artifacts_module() setattr(artifacts_module, unique_name, cls) - cls.__module__ = artifacts_module.__name__ - cls.__qualname__ = unique_name - setattr(cls, "__ccflow_dynamic_origin__", name_hint) + + # Store the import path as a separate attribute - DON'T change __module__ or __qualname__ + # This preserves cloudpickle's ability to serialize the class definition. + # The original module/qualname can still be retrieved via cls.__module__ and cls.__qualname__. + import_path = f"{artifacts_module.__name__}.{unique_name}" + setattr(cls, "__ccflow_import_path__", import_path) diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 0d4a642..56c5737 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -530,23 +530,46 @@ def test_module_level_context_retains_module(self): if dynamic_module: self.assertFalse(any(value is MyContext for value in vars(dynamic_module).values())) - def test_local_class_moves_under_dynamic_namespace(self): + def test_local_class_registered_with_import_path(self): + """Test that local-scope classes preserve __module__/__qualname__ but get an import path. + + The new behavior: + - __module__ stays as the original (preserves cloudpickle serialization) + - __qualname__ stays with '' (preserves cloudpickle serialization) + - __ccflow_import_path__ is set for PyObjectPath validation + - Class is registered in _local_artifacts under a unique name + """ LocalCallable = build_local_callable() - module_name = LocalCallable.__module__ - self.assertEqual(module_name, LOCAL_ARTIFACTS_MODULE_NAME) - dynamic_module = sys.modules[module_name] - self.assertIs(getattr(dynamic_module, LocalCallable.__qualname__), LocalCallable) - self.assertIn("", getattr(LocalCallable, "__ccflow_dynamic_origin__")) + # __module__ should NOT change to _local_artifacts + self.assertNotEqual(LocalCallable.__module__, LOCAL_ARTIFACTS_MODULE_NAME) + # __qualname__ should still have '' + self.assertIn("", LocalCallable.__qualname__) + # But __ccflow_import_path__ should be set and point to _local_artifacts + self.assertTrue(hasattr(LocalCallable, "__ccflow_import_path__")) + import_path = LocalCallable.__ccflow_import_path__ + self.assertTrue(import_path.startswith(f"{LOCAL_ARTIFACTS_MODULE_NAME}.")) + # Class should be registered in _local_artifacts under the import path + dynamic_module = sys.modules[LOCAL_ARTIFACTS_MODULE_NAME] + registered_name = import_path.split(".")[-1] + self.assertIs(getattr(dynamic_module, registered_name), LocalCallable) + # Functionality should work result = LocalCallable()(NullContext()) self.assertEqual(result.value, "local") - def test_multiple_local_definitions_have_unique_identifiers(self): + def test_multiple_local_definitions_have_unique_import_paths(self): + """Test that multiple local classes get unique import paths.""" first = build_local_callable() second = build_local_callable() - self.assertNotEqual(first.__qualname__, second.__qualname__) + # __qualname__ stays the same (both are 'build_local_callable.._LocalCallable') + self.assertEqual(first.__qualname__, second.__qualname__) + # But __ccflow_import_path__ should be unique + self.assertNotEqual(first.__ccflow_import_path__, second.__ccflow_import_path__) + # Both should be registered in _local_artifacts dynamic_module = sys.modules[LOCAL_ARTIFACTS_MODULE_NAME] - self.assertIs(getattr(dynamic_module, first.__qualname__), first) - self.assertIs(getattr(dynamic_module, second.__qualname__), second) + first_name = first.__ccflow_import_path__.split(".")[-1] + second_name = second.__ccflow_import_path__.split(".")[-1] + self.assertIs(getattr(dynamic_module, first_name), first) + self.assertIs(getattr(dynamic_module, second_name), second) def test_context_and_callable_same_name_do_not_collide(self): def build_conflicting(): @@ -614,7 +637,8 @@ def __call__(self, context: LocalContext) -> GenericResult: restored_result = restored_model(restored_context) self.assertEqual(restored_result.value, 35) - def test_multiple_nested_levels_unique_paths(self): + def test_multiple_nested_levels_unique_import_paths(self): + """Test that multiple nested local classes all get unique import paths.""" created = [] def layer(depth: int): @@ -656,16 +680,21 @@ def __call__(self, context: LocalContext) -> GenericResult: locals_module = sys.modules[LOCAL_ARTIFACTS_MODULE_NAME] - context_names = {ctx.__qualname__ for _, ctx, _ in created} - model_names = {model.__qualname__ for _, _, model in created} - self.assertEqual(len(context_names), len(created)) - self.assertEqual(len(model_names), len(created)) + # __ccflow_import_path__ should be unique for each class + context_import_paths = {ctx.__ccflow_import_path__ for _, ctx, _ in created} + model_import_paths = {model.__ccflow_import_path__ for _, _, model in created} + self.assertEqual(len(context_import_paths), len(created)) + self.assertEqual(len(model_import_paths), len(created)) for label, ctx_cls, model_cls in created: - self.assertIs(getattr(locals_module, ctx_cls.__qualname__), ctx_cls) - self.assertIs(getattr(locals_module, model_cls.__qualname__), model_cls) - self.assertIn("", getattr(ctx_cls, "__ccflow_dynamic_origin__")) - self.assertIn("", getattr(model_cls, "__ccflow_dynamic_origin__")) + # Each class should be registered under its import path + ctx_name = ctx_cls.__ccflow_import_path__.split(".")[-1] + model_name = model_cls.__ccflow_import_path__.split(".")[-1] + self.assertIs(getattr(locals_module, ctx_name), ctx_cls) + self.assertIs(getattr(locals_module, model_name), model_cls) + # Original qualname should still have '' (preserved for cloudpickle) + self.assertIn("", ctx_cls.__qualname__) + self.assertIn("", model_cls.__qualname__) ctx_instance = ctx_cls(value=4) result = model_cls()(ctx_instance) if isinstance(label, str): # sibling group @@ -683,34 +712,54 @@ def test_meta_callable_builds_dynamic_specialist(self): locals_module = sys.modules[LOCAL_ARTIFACTS_MODULE_NAME] SpecialistContext = captured["context_cls"] SpecialistCallable = captured["callable_cls"] - self.assertEqual(SpecialistContext.__module__, LOCAL_ARTIFACTS_MODULE_NAME) - self.assertEqual(SpecialistCallable.__module__, LOCAL_ARTIFACTS_MODULE_NAME) - self.assertIs(getattr(locals_module, SpecialistContext.__qualname__), SpecialistContext) - self.assertIs(getattr(locals_module, SpecialistCallable.__qualname__), SpecialistCallable) + # __module__ should NOT change (preserves cloudpickle) + self.assertNotEqual(SpecialistContext.__module__, LOCAL_ARTIFACTS_MODULE_NAME) + self.assertNotEqual(SpecialistCallable.__module__, LOCAL_ARTIFACTS_MODULE_NAME) + # But __ccflow_import_path__ should be set and classes should be registered + ctx_name = SpecialistContext.__ccflow_import_path__.split(".")[-1] + model_name = SpecialistCallable.__ccflow_import_path__.split(".")[-1] + self.assertIs(getattr(locals_module, ctx_name), SpecialistContext) + self.assertIs(getattr(locals_module, model_name), SpecialistCallable) def test_dynamic_factory_pickle_roundtrip(self): - serializers = [(pdumps, ploads), (rcpdumps, rcploads)] - for dumps, loads in serializers: - factory = loads(dumps(create_sensor_scope_models)) - SensorContext, SensorModel = factory() - self.assertEqual(SensorContext.__module__, LOCAL_ARTIFACTS_MODULE_NAME) - self.assertEqual(SensorModel.__module__, LOCAL_ARTIFACTS_MODULE_NAME) - locals_module = sys.modules[LOCAL_ARTIFACTS_MODULE_NAME] - self.assertIs(getattr(locals_module, SensorContext.__qualname__), SensorContext) - self.assertIs(getattr(locals_module, SensorModel.__qualname__), SensorModel) - - context = SensorContext(reading=41) - model = SensorModel(offset=1) - self.assertEqual(model(context).value, 42) - - restored_context_cls = loads(dumps(SensorContext)) - restored_model_cls = loads(dumps(SensorModel)) - self.assertIs(restored_context_cls, SensorContext) - self.assertIs(restored_model_cls, SensorModel) - - restored_context = loads(dumps(context)) - restored_model = loads(dumps(model)) - self.assertEqual(restored_model(restored_context).value, 42) + """Test that dynamically created local classes can be pickled with cloudpickle. + + Note: Standard pickle can't handle classes with '' in __qualname__ + because it tries to import the class by module.qualname. Cloudpickle can + serialize the class definition, which is essential for distributed computing + (e.g., Ray tasks). This tradeoff enables cross-process cloudpickle support + while PyObjectPath still works via __ccflow_import_path__. + """ + # Test with cloudpickle (which can serialize class definitions) + factory = rcploads(rcpdumps(create_sensor_scope_models)) + SensorContext, SensorModel = factory() + # __module__ should NOT change (preserves cloudpickle) + self.assertNotEqual(SensorContext.__module__, LOCAL_ARTIFACTS_MODULE_NAME) + self.assertNotEqual(SensorModel.__module__, LOCAL_ARTIFACTS_MODULE_NAME) + # But __ccflow_import_path__ should be set and point to _local_artifacts + self.assertTrue(SensorContext.__ccflow_import_path__.startswith(f"{LOCAL_ARTIFACTS_MODULE_NAME}.")) + self.assertTrue(SensorModel.__ccflow_import_path__.startswith(f"{LOCAL_ARTIFACTS_MODULE_NAME}.")) + # Classes should be registered in _local_artifacts + locals_module = sys.modules[LOCAL_ARTIFACTS_MODULE_NAME] + ctx_name = SensorContext.__ccflow_import_path__.split(".")[-1] + model_name = SensorModel.__ccflow_import_path__.split(".")[-1] + self.assertIs(getattr(locals_module, ctx_name), SensorContext) + self.assertIs(getattr(locals_module, model_name), SensorModel) + + context = SensorContext(reading=41) + model = SensorModel(offset=1) + self.assertEqual(model(context).value, 42) + + # Class roundtrip with cloudpickle + restored_context_cls = rcploads(rcpdumps(SensorContext)) + restored_model_cls = rcploads(rcpdumps(SensorModel)) + self.assertIs(restored_context_cls, SensorContext) + self.assertIs(restored_model_cls, SensorModel) + + # Instance roundtrip with cloudpickle + restored_context = rcploads(rcpdumps(context)) + restored_model = rcploads(rcpdumps(model)) + self.assertEqual(restored_model(restored_context).value, 42) class TestWrapperModel(TestCase): diff --git a/ccflow/tests/test_local_persistence.py b/ccflow/tests/test_local_persistence.py index 72d6be4..085aec0 100644 --- a/ccflow/tests/test_local_persistence.py +++ b/ccflow/tests/test_local_persistence.py @@ -5,6 +5,8 @@ from itertools import count from unittest import TestCase, mock +import ray + import ccflow.local_persistence as local_persistence from ccflow import BaseModel, CallableModel, ContextBase, Flow, GenericResult, NullContext @@ -107,7 +109,7 @@ def test_local_artifacts_module_is_lazy(): assert output == "False" -def test_local_artifacts_module_reload_preserves_dynamic_attrs_and_qualname(): +def test_local_artifacts_module_reload_preserves_dynamic_attrs(): output = _run_subprocess( """ import importlib @@ -121,17 +123,23 @@ class _Temp: Temp = build_cls() lp._register_local_subclass(Temp, kind="demo") module = importlib.import_module(lp.LOCAL_ARTIFACTS_MODULE_NAME) - qual_before = Temp.__qualname__ - before = getattr(module, qual_before) is Temp + + # Extract the registered name from __ccflow_import_path__ + import_path = Temp.__ccflow_import_path__ + registered_name = import_path.split(".")[-1] + + before = getattr(module, registered_name) is Temp module = importlib.reload(module) - after = getattr(module, qual_before) is Temp - print(before, after, qual_before == Temp.__qualname__) + after = getattr(module, registered_name) is Temp + + # __qualname__ should NOT have changed (preserves cloudpickle behavior) + print(before, after, "" in Temp.__qualname__) """ ) assert output.split() == ["True", "True", "True"] -def test_register_local_subclass_sets_module_qualname_and_origin(): +def test_register_local_subclass_preserves_module_qualname_and_sets_import_path(): output = _run_subprocess( """ import sys @@ -143,17 +151,671 @@ class Foo: return Foo Foo = build() + original_module = Foo.__module__ + original_qualname = Foo.__qualname__ lp._register_local_subclass(Foo, kind="ModelThing") module = sys.modules[lp.LOCAL_ARTIFACTS_MODULE_NAME] - print(Foo.__module__) - print(Foo.__qualname__) - print(hasattr(module, Foo.__qualname__)) - print(Foo.__ccflow_dynamic_origin__) + + # __module__ and __qualname__ should NOT change (preserves cloudpickle) + print(Foo.__module__ == original_module) + print(Foo.__qualname__ == original_qualname) + print("" in Foo.__qualname__) + + # __ccflow_import_path__ should be set and point to the registered class + import_path = Foo.__ccflow_import_path__ + registered_name = import_path.split(".")[-1] + print(hasattr(module, registered_name)) + print(getattr(module, registered_name) is Foo) + print(import_path.startswith("ccflow._local_artifacts.modelthing__")) """ ) lines = output.splitlines() - assert lines[0] == "ccflow._local_artifacts" - assert lines[2] == "True" - assert lines[3] == "__main__.build..Foo" - assert lines[1].startswith("modelthing__") - assert lines[1].endswith("__0") + assert lines[0] == "True", f"__module__ should not change: {lines}" + assert lines[1] == "True", f"__qualname__ should not change: {lines}" + assert lines[2] == "True", f"__qualname__ should contain '': {lines}" + assert lines[3] == "True", f"Class should be registered in module: {lines}" + assert lines[4] == "True", f"Registered class should be the same object: {lines}" + assert lines[5] == "True", f"Import path should start with expected prefix: {lines}" + + +def test_local_basemodel_cloudpickle_cross_process(): + """Test that local-scope BaseModel subclasses work with cloudpickle cross-process. + + This is the key test for the "best of both worlds" approach: + - __qualname__ has '' so cloudpickle serializes the class definition + - __ccflow_import_path__ allows PyObjectPath validation to work + - After unpickling, __pydantic_init_subclass__ re-registers the class + """ + import os + import tempfile + + pkl_path = tempfile.mktemp(suffix=".pkl") + + try: + # Create a local-scope BaseModel in subprocess 1 and pickle it + create_result = subprocess.run( + [ + sys.executable, + "-c", + f""" +from ray.cloudpickle import dump +from ccflow import BaseModel + +def create_local_model(): + class LocalModel(BaseModel): + value: int + + return LocalModel + +LocalModel = create_local_model() + +# Verify __qualname__ has '' (enables cloudpickle serialization) +assert "" in LocalModel.__qualname__, f"Expected '' in qualname: {{LocalModel.__qualname__}}" + +# Verify __ccflow_import_path__ is set (enables PyObjectPath) +assert hasattr(LocalModel, "__ccflow_import_path__"), "Expected __ccflow_import_path__ to be set" + +# Create instance and verify type_ works (PyObjectPath validation) +instance = LocalModel(value=42) +type_path = instance.type_ +print(f"type_: {{type_path}}") + +# Pickle the instance +with open("{pkl_path}", "wb") as f: + dump(instance, f) + +print("SUCCESS: Created and pickled") +""", + ], + capture_output=True, + text=True, + ) + assert create_result.returncode == 0, f"Create subprocess failed: {create_result.stderr}" + assert "SUCCESS" in create_result.stdout, f"Create subprocess output: {create_result.stdout}" + + # Load in subprocess 2 (different process, class not pre-defined) + load_result = subprocess.run( + [ + sys.executable, + "-c", + f""" +from ray.cloudpickle import load + +with open("{pkl_path}", "rb") as f: + obj = load(f) + +# Verify the value was preserved +assert obj.value == 42, f"Expected value=42, got {{obj.value}}" + +# Verify type_ works after unpickling (class was re-registered) +type_path = obj.type_ +print(f"type_: {{type_path}}") + +# Verify the import path works +import importlib +path_parts = str(type_path).rsplit(".", 1) +module = importlib.import_module(path_parts[0]) +cls = getattr(module, path_parts[1]) +assert cls is type(obj), "Import path should resolve to the same class" + +print("SUCCESS: Loaded and verified") +""", + ], + capture_output=True, + text=True, + ) + assert load_result.returncode == 0, f"Load subprocess failed: {load_result.stderr}" + assert "SUCCESS" in load_result.stdout, f"Load subprocess output: {load_result.stdout}" + + finally: + if os.path.exists(pkl_path): + os.unlink(pkl_path) + + +# ============================================================================= +# Comprehensive tests for local persistence and PyObjectPath integration +# ============================================================================= + + +class TestLocalPersistencePreservesCloudpickle: + """Tests verifying that local persistence preserves cloudpickle behavior.""" + + def test_qualname_has_locals_for_function_defined_class(self): + """Verify that __qualname__ contains '' for classes defined in functions.""" + + def create_class(): + class Inner(BaseModel): + x: int + + return Inner + + cls = create_class() + assert "" in cls.__qualname__ + assert cls.__module__ != local_persistence.LOCAL_ARTIFACTS_MODULE_NAME + + def test_module_not_changed_to_local_artifacts(self): + """Verify that __module__ is NOT changed to _local_artifacts.""" + + def create_class(): + class Inner(ContextBase): + value: str + + return Inner + + cls = create_class() + # __module__ should be this test module, not _local_artifacts + assert cls.__module__ == "ccflow.tests.test_local_persistence" + assert cls.__module__ != local_persistence.LOCAL_ARTIFACTS_MODULE_NAME + + def test_ccflow_import_path_is_set(self): + """Verify that __ccflow_import_path__ is set for local classes.""" + + def create_class(): + class Inner(BaseModel): + y: float + + return Inner + + cls = create_class() + assert hasattr(cls, "__ccflow_import_path__") + assert cls.__ccflow_import_path__.startswith(local_persistence.LOCAL_ARTIFACTS_MODULE_NAME + ".") + + def test_class_registered_in_local_artifacts(self): + """Verify that the class is registered in _local_artifacts under import path.""" + import sys + + def create_class(): + class Inner(BaseModel): + z: bool + + return Inner + + cls = create_class() + import_path = cls.__ccflow_import_path__ + registered_name = import_path.split(".")[-1] + + artifacts_module = sys.modules[local_persistence.LOCAL_ARTIFACTS_MODULE_NAME] + assert hasattr(artifacts_module, registered_name) + assert getattr(artifacts_module, registered_name) is cls + + +class TestPyObjectPathWithImportPath: + """Tests for PyObjectPath integration with __ccflow_import_path__.""" + + def test_type_property_uses_import_path(self): + """Verify that the type_ property returns a path using __ccflow_import_path__.""" + + def create_class(): + class LocalModel(BaseModel): + value: int + + return LocalModel + + cls = create_class() + instance = cls(value=123) + type_path = str(instance.type_) + + # type_ should use the __ccflow_import_path__, not module.qualname + assert type_path == cls.__ccflow_import_path__ + assert type_path.startswith(local_persistence.LOCAL_ARTIFACTS_MODULE_NAME) + + def test_type_path_can_be_imported(self): + """Verify that the type_ path can be used to import the class.""" + import importlib + + def create_class(): + class LocalModel(BaseModel): + value: int + + return LocalModel + + cls = create_class() + instance = cls(value=456) + type_path = str(instance.type_) + + # Should be able to import using the path + parts = type_path.rsplit(".", 1) + module = importlib.import_module(parts[0]) + imported_cls = getattr(module, parts[1]) + assert imported_cls is cls + + def test_type_property_for_context_base(self): + """Verify type_ works for ContextBase subclasses.""" + + def create_class(): + class LocalContext(ContextBase): + name: str + + return LocalContext + + cls = create_class() + instance = cls(name="test") + type_path = str(instance.type_) + + assert type_path == cls.__ccflow_import_path__ + assert instance.type_.object is cls + + def test_json_serialization_includes_target(self): + """Verify JSON serialization includes _target_ using __ccflow_import_path__.""" + + def create_class(): + class LocalModel(BaseModel): + value: int + + return LocalModel + + cls = create_class() + instance = cls(value=789) + data = instance.model_dump(mode="python") + + assert "type_" in data or "_target_" in data + # The computed field should use __ccflow_import_path__ + type_value = data.get("type_") or data.get("_target_") + assert str(type_value).startswith(local_persistence.LOCAL_ARTIFACTS_MODULE_NAME) + + +class TestCloudpickleSameProcess: + """Tests for same-process cloudpickle behavior.""" + + def test_cloudpickle_class_roundtrip_same_process(self): + """Verify cloudpickle can serialize and deserialize local classes in same process.""" + from ray.cloudpickle import dumps, loads + + def create_class(): + class LocalModel(BaseModel): + value: int + + return LocalModel + + cls = create_class() + restored_cls = loads(dumps(cls)) + + # Should be the same object (cloudpickle recognizes it's in the same process) + assert restored_cls is cls + + def test_cloudpickle_instance_roundtrip_same_process(self): + """Verify cloudpickle can serialize and deserialize instances in same process.""" + from ray.cloudpickle import dumps, loads + + def create_class(): + class LocalModel(BaseModel): + value: int + + return LocalModel + + cls = create_class() + instance = cls(value=42) + restored = loads(dumps(instance)) + + assert restored.value == 42 + assert type(restored) is cls + + def test_cloudpickle_preserves_type_path(self): + """Verify type_ works after cloudpickle roundtrip in same process.""" + from ray.cloudpickle import dumps, loads + + def create_class(): + class LocalModel(BaseModel): + value: int + + return LocalModel + + cls = create_class() + instance = cls(value=100) + original_type_path = str(instance.type_) + + restored = loads(dumps(instance)) + restored_type_path = str(restored.type_) + + assert restored_type_path == original_type_path + + +class TestCloudpickleCrossProcess: + """Tests for cross-process cloudpickle behavior (subprocess tests).""" + + def test_context_base_cross_process(self): + """Test cross-process cloudpickle for ContextBase subclasses.""" + import os + import tempfile + + pkl_path = tempfile.mktemp(suffix=".pkl") + + try: + create_code = f''' +from ray.cloudpickle import dump +from ccflow import ContextBase + +def create_context(): + class LocalContext(ContextBase): + name: str + value: int + return LocalContext + +LocalContext = create_context() +assert "" in LocalContext.__qualname__ +instance = LocalContext(name="test", value=42) +_ = instance.type_ # Verify type_ works before pickle + +with open("{pkl_path}", "wb") as f: + dump(instance, f) +print("SUCCESS") +''' + create_result = subprocess.run([sys.executable, "-c", create_code], capture_output=True, text=True) + assert create_result.returncode == 0, f"Create failed: {create_result.stderr}" + + load_code = f''' +from ray.cloudpickle import load + +with open("{pkl_path}", "rb") as f: + obj = load(f) + +assert obj.name == "test" +assert obj.value == 42 +type_path = obj.type_ # Verify type_ works after unpickle +assert type_path.object is type(obj) +print("SUCCESS") +''' + load_result = subprocess.run([sys.executable, "-c", load_code], capture_output=True, text=True) + assert load_result.returncode == 0, f"Load failed: {load_result.stderr}" + + finally: + if os.path.exists(pkl_path): + os.unlink(pkl_path) + + def test_callable_model_cross_process(self): + """Test cross-process cloudpickle for CallableModel subclasses.""" + import os + import tempfile + + pkl_path = tempfile.mktemp(suffix=".pkl") + + try: + create_code = f''' +from ray.cloudpickle import dump +from ccflow import CallableModel, ContextBase, GenericResult, Flow + +def create_callable(): + class LocalContext(ContextBase): + x: int + + class LocalCallable(CallableModel): + multiplier: int = 2 + + @Flow.call + def __call__(self, context: LocalContext) -> GenericResult: + return GenericResult(value=context.x * self.multiplier) + + return LocalContext, LocalCallable + +LocalContext, LocalCallable = create_callable() +instance = LocalCallable(multiplier=3) +context = LocalContext(x=10) + +# Verify it works +result = instance(context) +assert result.value == 30 + +with open("{pkl_path}", "wb") as f: + dump((instance, context), f) +print("SUCCESS") +''' + create_result = subprocess.run([sys.executable, "-c", create_code], capture_output=True, text=True) + assert create_result.returncode == 0, f"Create failed: {create_result.stderr}" + + load_code = f''' +from ray.cloudpickle import load + +with open("{pkl_path}", "rb") as f: + instance, context = load(f) + +# Verify the callable works after unpickle +result = instance(context) +assert result.value == 30 + +# Verify type_ works +assert instance.type_.object is type(instance) +assert context.type_.object is type(context) +print("SUCCESS") +''' + load_result = subprocess.run([sys.executable, "-c", load_code], capture_output=True, text=True) + assert load_result.returncode == 0, f"Load failed: {load_result.stderr}" + + finally: + if os.path.exists(pkl_path): + os.unlink(pkl_path) + + def test_nested_local_classes_cross_process(self): + """Test cross-process cloudpickle for multiply-nested local classes.""" + import os + import tempfile + + pkl_path = tempfile.mktemp(suffix=".pkl") + + try: + create_code = f''' +from ray.cloudpickle import dump +from ccflow import BaseModel + +def outer(): + def inner(): + class DeeplyNested(BaseModel): + value: int + return DeeplyNested + return inner() + +cls = outer() +assert "" in cls.__qualname__ +assert cls.__qualname__.count("") == 2 # Two levels of nesting + +instance = cls(value=999) +with open("{pkl_path}", "wb") as f: + dump(instance, f) +print("SUCCESS") +''' + create_result = subprocess.run([sys.executable, "-c", create_code], capture_output=True, text=True) + assert create_result.returncode == 0, f"Create failed: {create_result.stderr}" + + load_code = f''' +from ray.cloudpickle import load + +with open("{pkl_path}", "rb") as f: + obj = load(f) + +assert obj.value == 999 +_ = obj.type_ # Verify type_ works +print("SUCCESS") +''' + load_result = subprocess.run([sys.executable, "-c", load_code], capture_output=True, text=True) + assert load_result.returncode == 0, f"Load failed: {load_result.stderr}" + + finally: + if os.path.exists(pkl_path): + os.unlink(pkl_path) + + def test_multiple_instances_same_local_class_cross_process(self): + """Test that multiple instances of the same local class work cross-process.""" + import os + import tempfile + + pkl_path = tempfile.mktemp(suffix=".pkl") + + try: + create_code = f''' +from ray.cloudpickle import dump +from ccflow import BaseModel + +def create_class(): + class LocalModel(BaseModel): + value: int + return LocalModel + +cls = create_class() +instances = [cls(value=i) for i in range(5)] + +with open("{pkl_path}", "wb") as f: + dump(instances, f) +print("SUCCESS") +''' + create_result = subprocess.run([sys.executable, "-c", create_code], capture_output=True, text=True) + assert create_result.returncode == 0, f"Create failed: {create_result.stderr}" + + load_code = f''' +from ray.cloudpickle import load + +with open("{pkl_path}", "rb") as f: + instances = load(f) + +# All instances should have the correct values +for i, instance in enumerate(instances): + assert instance.value == i + +# All instances should be of the same class +assert len(set(type(inst) for inst in instances)) == 1 + +# type_ should work for all +for instance in instances: + _ = instance.type_ +print("SUCCESS") +''' + load_result = subprocess.run([sys.executable, "-c", load_code], capture_output=True, text=True) + assert load_result.returncode == 0, f"Load failed: {load_result.stderr}" + + finally: + if os.path.exists(pkl_path): + os.unlink(pkl_path) + + +class TestModuleLevelClassesUnaffected: + """Tests verifying that module-level classes are not affected by local persistence.""" + + def test_module_level_class_no_import_path(self): + """Verify module-level classes don't get __ccflow_import_path__.""" + assert not hasattr(ModuleLevelModel, "__ccflow_import_path__") + assert not hasattr(ModuleLevelContext, "__ccflow_import_path__") + assert not hasattr(ModuleLevelCallable, "__ccflow_import_path__") + + def test_module_level_class_type_path_uses_qualname(self): + """Verify module-level classes use standard module.qualname for type_.""" + instance = ModuleLevelModel(value=1) + type_path = str(instance.type_) + + # Should use standard path, not _local_artifacts + assert type_path == "ccflow.tests.test_local_persistence.ModuleLevelModel" + assert local_persistence.LOCAL_ARTIFACTS_MODULE_NAME not in type_path + + def test_module_level_standard_pickle_works(self): + """Verify standard pickle works for module-level classes.""" + from pickle import dumps, loads + + instance = ModuleLevelModel(value=42) + restored = loads(dumps(instance)) + assert restored.value == 42 + assert type(restored) is ModuleLevelModel + + +class TestRayTaskWithLocalClasses: + """Tests for Ray task execution with locally-defined classes. + + These tests verify that the local persistence mechanism works correctly + when classes are serialized and sent to Ray workers (different processes). + """ + + def test_local_callable_model_ray_task(self): + """Test that locally-defined CallableModels can be sent to Ray tasks. + + This is the ultimate test of cross-process cloudpickle support: + - Local class defined in function (has in __qualname__) + - Sent to Ray worker (different process) + - Executed and returns correct result + - type_ property works after execution (PyObjectPath validation) + """ + + def create_local_callable(): + class LocalContext(ContextBase): + x: int + + class LocalCallable(CallableModel): + multiplier: int = 2 + + @Flow.call + def __call__(self, context: LocalContext) -> GenericResult: + return GenericResult(value=context.x * self.multiplier) + + return LocalContext, LocalCallable + + LocalContext, LocalCallable = create_local_callable() + + # Verify is in qualname (ensures cloudpickle serializes definition) + assert "" in LocalCallable.__qualname__ + assert "" in LocalContext.__qualname__ + + # Verify __ccflow_import_path__ is set + assert hasattr(LocalCallable, "__ccflow_import_path__") + assert hasattr(LocalContext, "__ccflow_import_path__") + + @ray.remote + def run_callable(model, context): + result = model(context) + # Verify type_ works inside the Ray task (cross-process PyObjectPath) + _ = model.type_ + _ = context.type_ + return result.value + + model = LocalCallable(multiplier=3) + context = LocalContext(x=10) + + with ray.init(num_cpus=1): + result = ray.get(run_callable.remote(model, context)) + + assert result == 30 + + def test_local_context_ray_task(self): + """Test that locally-defined ContextBase can be sent to Ray tasks.""" + + def create_local_context(): + class LocalContext(ContextBase): + name: str + value: int + + return LocalContext + + LocalContext = create_local_context() + assert "" in LocalContext.__qualname__ + + @ray.remote + def process_context(ctx): + # Access fields and type_ inside Ray task + _ = ctx.type_ + return f"{ctx.name}:{ctx.value}" + + context = LocalContext(name="test", value=42) + + with ray.init(num_cpus=1): + result = ray.get(process_context.remote(context)) + + assert result == "test:42" + + def test_local_base_model_ray_task(self): + """Test that locally-defined BaseModel can be sent to Ray tasks.""" + + def create_local_model(): + class LocalModel(BaseModel): + data: str + + return LocalModel + + LocalModel = create_local_model() + assert "" in LocalModel.__qualname__ + + @ray.remote + def process_model(m): + # Access type_ inside Ray task + type_path = str(m.type_) + return f"{m.data}|{type_path}" + + model = LocalModel(data="hello") + + with ray.init(num_cpus=1): + result = ray.get(process_model.remote(model)) + + assert result.startswith("hello|ccflow._local_artifacts.")