diff --git a/ccflow/__init__.py b/ccflow/__init__.py index 30706d7..163f275 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -12,6 +12,7 @@ from .context import * from .enums import Enum from .global_state import * +from .local_persistence import * from .models import * from .object_config import * from .publisher import * diff --git a/ccflow/base.py b/ccflow/base.py index 18e2eac..f718c21 100644 --- a/ccflow/base.py +++ b/ccflow/base.py @@ -30,6 +30,7 @@ from typing_extensions import Self from .exttypes.pyobjectpath import PyObjectPath +from .local_persistence import register_ccflow_import_path, sync_to_module log = logging.getLogger(__name__) @@ -156,6 +157,19 @@ class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta - Registration by name, and coercion from string name to allow for object re-use from the configs """ + @classmethod + def __pydantic_init_subclass__(cls, **kwargs): + super().__pydantic_init_subclass__(**kwargs) + # Register local-scope classes and __main__ classes so they're importable via PyObjectPath. + # - Local classes (<locals> in qualname) aren't importable via their qualname path + # - __main__ classes aren't importable cross-process (cloudpickle recreates them but + # doesn't add them to sys.modules["__main__"]) + # Note: Cross-process unpickle sync (when __ccflow_import_path__ is already set) happens + # lazily via sync_to_module, since cloudpickle sets class attributes + # AFTER __pydantic_init_subclass__ runs. + if "<locals>" in cls.__qualname__ or cls.__module__ == "__main__": + register_ccflow_import_path(cls) + @computed_field( alias="_target_", repr=False, @@ -165,8 +179,18 @@ class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta ) @property def type_(self) -> PyObjectPath: - """The path to the object type""" - return PyObjectPath.validate(type(self)) + """The path to the object type. + + For local classes (defined in functions), this returns the __ccflow_import_path__. 
+ For cross-process unpickle scenarios, this also ensures the class is synced to + the ccflow.local_persistence module so the import path resolves correctly. + """ + cls = type(self) + # Handle cross-process unpickle: cloudpickle sets __ccflow_import_path__ but + # the class may not be on ccflow.local_persistence yet in this process + if "__ccflow_import_path__" in cls.__dict__: + sync_to_module(cls) + return PyObjectPath.validate(cls) # We want to track under what names a model has been registered _registrations: List[Tuple["ModelRegistry", str]] = PrivateAttr(default_factory=list) diff --git a/ccflow/exttypes/pyobjectpath.py b/ccflow/exttypes/pyobjectpath.py index 9c91b2f..0612fa0 100644 --- a/ccflow/exttypes/pyobjectpath.py +++ b/ccflow/exttypes/pyobjectpath.py @@ -1,26 +1,74 @@ """This module contains extension types for pydantic.""" +import importlib from functools import cached_property, lru_cache from types import FunctionType, MethodType, ModuleType from typing import Any, Type, get_origin -from pydantic import ImportString, TypeAdapter +from pydantic import TypeAdapter from pydantic_core import core_schema from typing_extensions import Self -_import_string_adapter = TypeAdapter(ImportString) - @lru_cache(maxsize=None) -def import_string(input_string: str): - return _import_string_adapter.validate_python(input_string) +def import_string(dotted_path: str) -> Any: + """Import an object from a dotted path string. + + Handles nested class paths like 'module.OuterClass.InnerClass' by progressively + trying shorter module paths and using getattr for the remaining parts. + + This is more flexible than pydantic's ImportString which can fail on nested classes. 
+ """ + if not dotted_path: + raise ImportError("Empty path") + + parts = dotted_path.split(".") + + # Try progressively shorter module paths + # e.g., for 'a.b.C.D', try 'a.b.C.D', then 'a.b.C', then 'a.b', then 'a' + for i in range(len(parts), 0, -1): + module_path = ".".join(parts[:i]) + try: + obj = importlib.import_module(module_path) + # Successfully imported module, now getattr for remaining parts + for attr_name in parts[i:]: + obj = getattr(obj, attr_name) + return obj + except ImportError: + continue + except AttributeError: + # Module imported but attribute not found - keep trying shorter paths + continue + + raise ImportError(f"No module named '{dotted_path}'") + + +def _build_standard_import_path(obj: Any) -> str: + """Build 'module.qualname' path from an object with __module__ and __qualname__.""" + # Handle Python 2 -> 3 module name change for builtins + if obj.__module__ == "__builtin__": + module = "builtins" + else: + module = obj.__module__ + + qualname = obj.__qualname__ + # Strip generic type parameters (e.g., "MyClass[int]" -> "MyClass") + # This happens with Generic types in pydantic. Type info is lost but + # at least the base class remains importable. + # TODO: Find a way of capturing the underlying type info + if "[" in qualname: + qualname = qualname.split("[", 1)[0] + return f"{module}.{qualname}" if module else qualname class PyObjectPath(str): - """Similar to pydantic's ImportString (formerly PyObject in v1), this class represents the path to the object as a string. + """A string representing an importable Python object path (e.g., "module.ClassName"). + + Similar to pydantic's ImportString, but with consistent serialization behavior: + - ImportString deserializes to the actual object + - PyObjectPath deserializes back to the string path - In pydantic v1, PyObject could not be serialized to json, whereas in v2, ImportString can. - However, the round trip is not always consistent, i.e. 
+ Example: >>> ta = TypeAdapter(ImportString) >>> ta.validate_json(ta.dump_json("math.pi")) 3.141592653589793 @@ -28,7 +76,7 @@ class PyObjectPath(str): >>> ta.validate_json(ta.dump_json("math.pi")) 'math.pi' - Other differences are that ImportString can contain other arbitrary python values, whereas PyObjectPath is always a string + PyObjectPath also only accepts importable objects, not arbitrary values: >>> TypeAdapter(ImportString).validate_python(0) 0 >>> TypeAdapter(PyObjectPath).validate_python(0) @@ -36,7 +84,7 @@ class PyObjectPath(str): """ # TODO: It would be nice to make this also derive from Generic[T], - # where T could then by used for type checking in validate. + # where T could then be used for type checking in validate. # However, this doesn't work: https://github.com/python/typing/issues/629 @cached_property @@ -50,34 +98,43 @@ def __get_pydantic_core_schema__(cls, source_type, handler): @classmethod def _validate(cls, value: Any): + """Convert value (string path or object) to PyObjectPath, verifying it's importable.""" if isinstance(value, str): - value = cls(value) - else: # Try to construct a string from the object that can then be used to import the object + path = cls(value) + else: + # Unwrap generic types (e.g., List[int] -> list) origin = get_origin(value) if origin: value = origin - if hasattr(value, "__module__") and hasattr(value, "__qualname__"): - if value.__module__ == "__builtin__": - module = "builtins" - else: - module = value.__module__ - qualname = value.__qualname__ - if "[" in qualname: - # This happens with Generic types in pydantic. We strip out the info for now. - # TODO: Find a way of capturing the underlying type info - qualname = qualname.split("[", 1)[0] - if not module: - value = cls(qualname) - else: - value = cls(module + "." 
+ qualname) - else: - raise ValueError(f"ensure this value contains valid import path or importable object: unable to import path for {value}") + path = cls._path_from_object(value) + + # Verify the path is actually importable try: - value.object + path.object except ImportError as e: raise ValueError(f"ensure this value contains valid import path or importable object: {str(e)}") - return value + return path + + @classmethod + def _path_from_object(cls, value: Any) -> "PyObjectPath": + """Build import path from an object. + + For classes with __ccflow_import_path__ set (local classes), + uses that path. Otherwise uses the standard module.qualname path. + """ + if isinstance(value, type): + # Use __ccflow_import_path__ if set (check __dict__ to avoid inheriting from parents). + # Note: accessing .__dict__ is safe here because value is a type (class object), + # and all class objects have __dict__. Only instances of __slots__ classes lack it. + if "__ccflow_import_path__" in value.__dict__: + return cls(value.__ccflow_import_path__) + return cls(_build_standard_import_path(value)) + + if hasattr(value, "__module__") and hasattr(value, "__qualname__"): + return cls(_build_standard_import_path(value)) + + raise ValueError(f"ensure this value contains valid import path or importable object: unable to import path for {value}") @classmethod @lru_cache(maxsize=None) @@ -86,10 +143,12 @@ def _validate_cached(cls, value: str) -> Self: @classmethod def validate(cls, value) -> Self: - """Try to convert/validate an arbitrary value to a PyObjectPath.""" - if isinstance( - value, (str, type, FunctionType, ModuleType, MethodType) - ): # If the value is trivial, we cache it here to avoid the overhead of validation + """Try to convert/validate an arbitrary value to a PyObjectPath. + + Uses caching for common value types to improve performance. 
+ """ + # Cache validation for common immutable types to avoid repeated work + if isinstance(value, (str, type, FunctionType, ModuleType, MethodType)): return cls._validate_cached(value) return _TYPE_ADAPTER.validate_python(value) diff --git a/ccflow/local_persistence.py b/ccflow/local_persistence.py new file mode 100644 index 0000000..39fb9c8 --- /dev/null +++ b/ccflow/local_persistence.py @@ -0,0 +1,95 @@ +"""Register local-scope classes on a module so PyObjectPath can import them. + +Classes defined in functions (with '<locals>' in __qualname__) aren't normally importable. +We give them a unique name and register them on this module (ccflow.local_persistence). +We keep __module__ and __qualname__ unchanged so cloudpickle can still serialize the +class definition. + +This module provides: +- register_ccflow_import_path(cls): Register a local class with a unique import path +- sync_to_module(cls): Ensure a class with __ccflow_import_path__ is on the module + (used for cross-process unpickle scenarios) +- create_ccflow_model: Wrapper around pydantic.create_model that registers the created model +""" + +import re +import sys +import uuid +from typing import Any, Type + +__all__ = ("LOCAL_ARTIFACTS_MODULE_NAME", "create_ccflow_model") + +LOCAL_ARTIFACTS_MODULE_NAME = "ccflow.local_persistence" + + +def _register_on_module(cls: Type[Any], module_name: str) -> None: + """Register cls on the specified module with a unique name. + + This sets __ccflow_import_path__ on the class without modifying __module__ or + __qualname__, preserving cloudpickle's ability to serialize the class definition. + + Args: + cls: The class to register. + module_name: The fully-qualified module name to register on (must be in sys.modules). 
+ """ + # Sanitize the class name to be a valid Python identifier + name = re.sub(r"[^0-9A-Za-z_]", "_", cls.__name__ or "Model").strip("_") or "Model" + if name[0].isdigit(): + name = f"_{name}" + unique = f"_Local_{name}_{uuid.uuid4().hex[:12]}" + + setattr(sys.modules[module_name], unique, cls) + cls.__ccflow_import_path__ = f"{module_name}.{unique}" + + +def register_ccflow_import_path(cls: Type[Any]) -> None: + """Give cls a unique name and register it on ccflow.local_persistence. + + This sets __ccflow_import_path__ on the class without modifying __module__ or + __qualname__, preserving cloudpickle's ability to serialize the class definition. + """ + _register_on_module(cls, LOCAL_ARTIFACTS_MODULE_NAME) + + +def sync_to_module(cls: Type[Any]) -> None: + """Ensure cls is registered on the artifacts module in this process. + + This handles cross-process unpickle scenarios where cloudpickle recreates the class + with __ccflow_import_path__ already set (from the original process), but the class + isn't yet registered on ccflow.local_persistence in the new process. + """ + path = getattr(cls, "__ccflow_import_path__", "") + if path.startswith(LOCAL_ARTIFACTS_MODULE_NAME + "."): + name = path.rsplit(".", 1)[-1] + base = sys.modules[LOCAL_ARTIFACTS_MODULE_NAME] + if getattr(base, name, None) is not cls: + setattr(base, name, cls) + + +def create_ccflow_model(__model_name: str, *, __base__: Any = None, **field_definitions: Any) -> Type[Any]: + """Create a dynamic ccflow model and register it for PyObjectPath serialization. + + Wraps pydantic's create_model and registers the model so it can be serialized + via PyObjectPath, including across processes (e.g., with Ray). + + Example: + >>> from ccflow import ContextBase, create_ccflow_model + >>> MyContext = create_ccflow_model( + ... "MyContext", + ... __base__=ContextBase, + ... name=(str, ...), + ... value=(int, 0), + ... 
) + >>> ctx = MyContext(name="test", value=42) + """ + from pydantic import create_model as pydantic_create_model + + model = pydantic_create_model(__model_name, __base__=__base__, **field_definitions) + + # Register if it's a ccflow BaseModel subclass + from ccflow.base import BaseModel + + if isinstance(model, type) and issubclass(model, BaseModel): + register_ccflow_import_path(model) + + return model diff --git a/ccflow/tests/evaluators/test_common.py b/ccflow/tests/evaluators/test_common.py index 7b5abf0..6124ed2 100644 --- a/ccflow/tests/evaluators/test_common.py +++ b/ccflow/tests/evaluators/test_common.py @@ -5,7 +5,14 @@ import pandas as pd import pyarrow as pa -from ccflow import DateContext, DateRangeContext, Evaluator, FlowOptionsOverride, ModelEvaluationContext, NullContext +from ccflow import ( + DateContext, + DateRangeContext, + Evaluator, + FlowOptionsOverride, + ModelEvaluationContext, + NullContext, +) from ccflow.evaluators import ( FallbackEvaluator, GraphEvaluator, diff --git a/ccflow/tests/exttypes/test_pyobjectpath.py b/ccflow/tests/exttypes/test_pyobjectpath.py index d25fabb..7b3dd93 100644 --- a/ccflow/tests/exttypes/test_pyobjectpath.py +++ b/ccflow/tests/exttypes/test_pyobjectpath.py @@ -62,3 +62,20 @@ def test_pickle(self): self.assertIsNotNone(p.object) self.assertEqual(p, pickle.loads(pickle.dumps(p))) self.assertEqual(p.object, pickle.loads(pickle.dumps(p.object))) + + def test_builtin_module_alias(self): + """Test that objects with __module__ == '__builtin__' are handled correctly. + + In Python 2, built-in types had __module__ == '__builtin__', but in Python 3 + it's 'builtins'. Some C extensions or pickled objects may still report the + old module name. 
+ """ + + # Create a mock object that reports __builtin__ as its module + class MockBuiltinObject: + __module__ = "__builtin__" + __qualname__ = "int" + + p = PyObjectPath.validate(MockBuiltinObject) + self.assertEqual(p, "builtins.int") + self.assertEqual(p.object, int) diff --git a/ccflow/tests/test_base.py b/ccflow/tests/test_base.py index 3c6fc02..e53e0af 100644 --- a/ccflow/tests/test_base.py +++ b/ccflow/tests/test_base.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List -from unittest import TestCase +from unittest import TestCase, mock -from pydantic import ConfigDict, ValidationError +from pydantic import BaseModel as PydanticBaseModel, ConfigDict, ValidationError -from ccflow import BaseModel, PyObjectPath +from ccflow import BaseModel, CallableModel, ContextBase, Flow, GenericResult, NullContext, PyObjectPath +from ccflow.local_persistence import LOCAL_ARTIFACTS_MODULE_NAME class ModelA(BaseModel): @@ -106,6 +107,20 @@ def test_type_after_assignment(self): self.assertIsInstance(m.type_, PyObjectPath) self.assertEqual(m.type_, path) + def test_pyobjectpath_requires_ccflow_local_registration(self): + class PlainLocalModel(PydanticBaseModel): + value: int + + with self.assertRaises(ValueError): + PyObjectPath.validate(PlainLocalModel) + + class LocalCcflowModel(BaseModel): + value: int + + path = PyObjectPath.validate(LocalCcflowModel) + self.assertEqual(path.object, LocalCcflowModel) + self.assertTrue(str(path).startswith(f"{LOCAL_ARTIFACTS_MODULE_NAME}.")) + def test_validate(self): self.assertEqual(ModelA.model_validate({"x": "foo"}), ModelA(x="foo")) type_ = "ccflow.tests.test_base.ModelA" @@ -157,3 +172,35 @@ def test_widget(self): {"expanded": True, "root": "bar"}, ), ) + + +class TestLocalRegistration(TestCase): + def test_local_class_registered_for_base_model(self): + with mock.patch("ccflow.base.register_ccflow_import_path") as register: + + class LocalModel(BaseModel): + value: int + + # Local classes (defined in functions) should be registered 
+ calls = [args[0] for args, _ in register.call_args_list if args] + self.assertIn(LocalModel, calls) + + def test_local_class_registered_for_context(self): + with mock.patch("ccflow.base.register_ccflow_import_path") as register: + + class LocalContext(ContextBase): + value: int + + calls = [args[0] for args, _ in register.call_args_list if args] + self.assertIn(LocalContext, calls) + + def test_local_class_registered_for_callable(self): + with mock.patch("ccflow.base.register_ccflow_import_path") as register: + + class LocalCallable(CallableModel): + @Flow.call + def __call__(self, context: NullContext) -> GenericResult: + return GenericResult(value="ok") + + calls = [args[0] for args, _ in register.call_args_list if args] + self.assertIn(LocalCallable, calls) diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 0ba24ea..43f86b5 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -21,6 +21,7 @@ ResultType, WrapperModel, ) +from ccflow.local_persistence import LOCAL_ARTIFACTS_MODULE_NAME class MyContext(ContextBase): @@ -493,6 +494,38 @@ def test_union_return(self): self.assertEqual(result.a, 1) +class TestCallableModelRegistration(TestCase): + """Smoke test verifying CallableModel inherits registration from BaseModel. + + NOTE: Registration behavior is thoroughly tested at the BaseModel level in + test_local_persistence.py. This single test verifies inheritance works. 
+ """ + + def test_local_callable_smoke_test(self): + """Verify that local CallableModel classes inherit registration from BaseModel.""" + + class LocalContext(ContextBase): + value: int + + class LocalCallable(CallableModel): + @Flow.call + def __call__(self, context: LocalContext) -> GenericResult: + return GenericResult(value=context.value * 2) + + # Basic registration should work (inherits from BaseModel) + self.assertIn("<locals>", LocalCallable.__qualname__) + self.assertTrue(hasattr(LocalCallable, "__ccflow_import_path__")) + self.assertTrue(LocalCallable.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + + # type_ should work + instance = LocalCallable() + self.assertEqual(instance.type_.object, LocalCallable) + + # Callable should execute correctly + result = instance(LocalContext(value=21)) + self.assertEqual(result.value, 42) + + class TestWrapperModel(TestCase): def test_wrapper(self): md = MetaData(name="foo", description="My Foo") diff --git a/ccflow/tests/test_global_state.py b/ccflow/tests/test_global_state.py index e5b0d68..470afec 100644 --- a/ccflow/tests/test_global_state.py +++ b/ccflow/tests/test_global_state.py @@ -48,9 +48,8 @@ def test_global_state(root_registry): assert state3.open_overrides == {} -def test_global_state_pickle(): - r = ModelRegistry.root() - r.add("foo", DummyModel(name="foo")) +def test_global_state_pickle(root_registry): + root_registry.add("foo", DummyModel(name="foo")) evaluator = DummyEvaluator() with FlowOptionsOverride(options=dict(evaluator=evaluator)) as override: state = GlobalState() diff --git a/ccflow/tests/test_local_persistence.py b/ccflow/tests/test_local_persistence.py new file mode 100644 index 0000000..586b03f --- /dev/null +++ b/ccflow/tests/test_local_persistence.py @@ -0,0 +1,1454 @@ +"""Tests for local class registration in ccflow. 
+ +The local persistence module allows classes defined inside functions (with '<locals>' +in their __qualname__) to work with PyObjectPath serialization by registering them +on ccflow.local_persistence with unique names. + +Key behaviors tested: +1. Local classes get __ccflow_import_path__ set at definition time +2. Module-level classes are NOT registered (they're already importable) +3. Cross-process cloudpickle works via sync_to_module +4. UUID-based naming provides uniqueness +""" + +import re +import subprocess +import sys + +import pytest +import ray + +import ccflow.local_persistence as local_persistence +from ccflow import BaseModel, CallableModel, ContextBase, Flow, GenericResult, NullContext +from ccflow.local_persistence import create_ccflow_model + + +class ModuleLevelModel(BaseModel): + value: int + + +class ModuleLevelContext(ContextBase): + value: int + + +class ModuleLevelCallable(CallableModel): + @Flow.call + def __call__(self, context: NullContext) -> GenericResult: + return GenericResult(value="ok") + + +# ============================================================================= +# Tests for register_ccflow_import_path function +# ============================================================================= + + +def test_base_module_available_after_import(): + """Test that ccflow.local_persistence module is available after importing ccflow.""" + assert local_persistence.LOCAL_ARTIFACTS_MODULE_NAME in sys.modules + + +def test_register_preserves_module_qualname_and_sets_import_path(): + """Test that register_ccflow_import_path sets __ccflow_import_path__ without changing __module__ or __qualname__.""" + + def build(): + class Foo: + pass + + return Foo + + Foo = build() + original_module = Foo.__module__ + original_qualname = Foo.__qualname__ + + local_persistence.register_ccflow_import_path(Foo) + + # __module__ and __qualname__ should NOT change (preserves cloudpickle) + assert Foo.__module__ == original_module, "__module__ should not change" +
assert Foo.__qualname__ == original_qualname, "__qualname__ should not change" + assert "<locals>" in Foo.__qualname__, "__qualname__ should contain '<locals>'" + + # __ccflow_import_path__ should be set and point to the registered class + import_path = Foo.__ccflow_import_path__ + registered_name = import_path.split(".")[-1] + module = sys.modules[local_persistence.LOCAL_ARTIFACTS_MODULE_NAME] + assert hasattr(module, registered_name), "Class should be registered in module" + assert getattr(module, registered_name) is Foo, "Registered class should be the same object" + assert import_path.startswith("ccflow.local_persistence._Local_"), "Import path should have expected prefix" + + +def test_register_handles_class_name_starting_with_digit(): + """Test that register_ccflow_import_path handles class names starting with a digit by prefixing with underscore.""" + # Create a class with a name starting with a digit + cls = type("3DModel", (), {}) + local_persistence.register_ccflow_import_path(cls) + + import_path = cls.__ccflow_import_path__ + registered_name = import_path.split(".")[-1] + + # The registered name should start with _Local__ (underscore added for digit) + assert registered_name.startswith("_Local__"), "Registered name should start with _Local__" + assert "_3DModel_" in registered_name, "Registered name should contain _3DModel_" + + # Should be registered on ccflow.local_persistence + module = sys.modules[local_persistence.LOCAL_ARTIFACTS_MODULE_NAME] + assert getattr(module, registered_name) is cls, "Class should be registered on module" + + +def test_sync_to_module_registers_class_not_yet_on_module(): + """Test that sync_to_module registers a class that has __ccflow_import_path__ but isn't on the module yet. + + This happens in cross-process unpickle scenarios where cloudpickle recreates the class + with __ccflow_import_path__ set, but the class isn't yet on ccflow.local_persistence. 
+ """ + # Simulate a class that has __ccflow_import_path__ but isn't registered on ccflow.local_persistence + # (like what happens after cross-process cloudpickle unpickle) + cls = type("SimulatedUnpickled", (), {}) + unique_name = "_Local_SimulatedUnpickled_test123abc" + cls.__ccflow_import_path__ = f"{local_persistence.LOCAL_ARTIFACTS_MODULE_NAME}.{unique_name}" + + # Verify class is NOT on ccflow.local_persistence yet + module = sys.modules[local_persistence.LOCAL_ARTIFACTS_MODULE_NAME] + assert getattr(module, unique_name, None) is None, "Class should NOT be on module before sync" + + # Call sync_to_module + local_persistence.sync_to_module(cls) + + # Verify class IS now on ccflow.local_persistence + assert getattr(module, unique_name, None) is cls, "Class should be on module after sync" + + +# ============================================================================= +# Tests for _register_on_module (internal function) +# ============================================================================= + + +def test_register_on_module_uses_specified_module(): + """Test that _register_on_module registers on the specified module, not just LOCAL_ARTIFACTS_MODULE_NAME.""" + # Use a different module that's already in sys.modules + target_module_name = "ccflow.tests.test_local_persistence" + target_module = sys.modules[target_module_name] + + cls = type("CustomModuleClass", (), {}) + local_persistence._register_on_module(cls, target_module_name) + + # Verify import path uses the specified module name + import_path = cls.__ccflow_import_path__ + assert import_path.startswith(f"{target_module_name}._Local_"), f"Import path should start with {target_module_name}" + + # Verify class is registered on the specified module + registered_name = import_path.split(".")[-1] + assert getattr(target_module, registered_name) is cls, "Class should be on specified module" + + # Verify class is NOT on LOCAL_ARTIFACTS_MODULE_NAME + artifacts_module = 
sys.modules[local_persistence.LOCAL_ARTIFACTS_MODULE_NAME] + assert getattr(artifacts_module, registered_name, None) is None, "Class should NOT be on artifacts module" + + # Cleanup + delattr(target_module, registered_name) + + +def test_register_on_module_import_path_format(): + """Test that _register_on_module produces correctly formatted import paths.""" + target_module_name = "ccflow.tests.test_local_persistence" + target_module = sys.modules[target_module_name] + + cls = type("FormatTestClass", (), {}) + local_persistence._register_on_module(cls, target_module_name) + + import_path = cls.__ccflow_import_path__ + + # Import path should be: module_name._Local_ClassName_uuid + parts = import_path.rsplit(".", 1) + assert len(parts) == 2, "Import path should have module.name format" + assert parts[0] == target_module_name, "Module part should match target" + assert parts[1].startswith("_Local_FormatTestClass_"), "Name should follow _Local_ClassName_uuid pattern" + assert len(parts[1].split("_")[-1]) == 12, "UUID suffix should be 12 hex chars" + + # Cleanup + delattr(target_module, parts[1]) + + +# ============================================================================= +# Tests for local class registration via BaseModel +# ============================================================================= + + +class TestLocalPersistencePreservesCloudpickle: + """Tests verifying that local persistence preserves cloudpickle behavior.""" + + def test_qualname_has_locals_for_function_defined_class(self): + """Verify that __qualname__ contains '<locals>' for classes defined in functions.""" + + def create_class(): + class Inner(BaseModel): + x: int + + return Inner + + cls = create_class() + assert "<locals>" in cls.__qualname__ + assert cls.__module__ != local_persistence.LOCAL_ARTIFACTS_MODULE_NAME + + def test_module_not_changed_to_local_artifacts(self): + """Verify that __module__ is NOT changed to ccflow.local_persistence.""" + + def create_class(): + class Inner(ContextBase): + 
value: str + + return Inner + + cls = create_class() + # __module__ should be this test module, not ccflow.local_persistence + assert cls.__module__ == "ccflow.tests.test_local_persistence" + assert cls.__module__ != local_persistence.LOCAL_ARTIFACTS_MODULE_NAME + + def test_ccflow_import_path_is_set(self): + """Verify that __ccflow_import_path__ is set for local classes.""" + + def create_class(): + class Inner(BaseModel): + y: float + + return Inner + + cls = create_class() + assert hasattr(cls, "__ccflow_import_path__") + assert cls.__ccflow_import_path__.startswith(local_persistence.LOCAL_ARTIFACTS_MODULE_NAME + ".") + + def test_class_registered_in_base_module(self): + """Verify that the class is registered in ccflow.local_persistence under import path.""" + + def create_class(): + class Inner(BaseModel): + z: bool + + return Inner + + cls = create_class() + import_path = cls.__ccflow_import_path__ + registered_name = import_path.split(".")[-1] + + artifacts_module = sys.modules[local_persistence.LOCAL_ARTIFACTS_MODULE_NAME] + assert hasattr(artifacts_module, registered_name) + assert getattr(artifacts_module, registered_name) is cls + + +class TestPyObjectPathWithImportPath: + """Tests for PyObjectPath integration with __ccflow_import_path__.""" + + def test_type_property_uses_import_path(self): + """Verify that the type_ property returns a path using __ccflow_import_path__.""" + + def create_class(): + class LocalModel(BaseModel): + value: int + + return LocalModel + + cls = create_class() + instance = cls(value=123) + type_path = str(instance.type_) + + # type_ should use the __ccflow_import_path__, not module.qualname + assert type_path == cls.__ccflow_import_path__ + assert type_path.startswith(local_persistence.LOCAL_ARTIFACTS_MODULE_NAME) + + def test_type_path_can_be_imported(self): + """Verify that the type_ path can be used to import the class.""" + import importlib + + def create_class(): + class LocalModel(BaseModel): + value: int + + return 
LocalModel + + cls = create_class() + instance = cls(value=456) + type_path = str(instance.type_) + + # Should be able to import using the path + parts = type_path.rsplit(".", 1) + module = importlib.import_module(parts[0]) + imported_cls = getattr(module, parts[1]) + assert imported_cls is cls + + def test_type_property_for_context_base(self): + """Verify type_ works for ContextBase subclasses.""" + + def create_class(): + class LocalContext(ContextBase): + name: str + + return LocalContext + + cls = create_class() + instance = cls(name="test") + type_path = str(instance.type_) + + assert type_path == cls.__ccflow_import_path__ + assert instance.type_.object is cls + + def test_json_serialization_includes_target(self): + """Verify JSON serialization includes _target_ using __ccflow_import_path__.""" + + def create_class(): + class LocalModel(BaseModel): + value: int + + return LocalModel + + cls = create_class() + instance = cls(value=789) + data = instance.model_dump(mode="python") + + assert "type_" in data or "_target_" in data + # The computed field should use __ccflow_import_path__ + type_value = data.get("type_") or data.get("_target_") + assert str(type_value).startswith(local_persistence.LOCAL_ARTIFACTS_MODULE_NAME) + + +class TestCloudpickleSameProcess: + """Tests for same-process cloudpickle behavior.""" + + def test_cloudpickle_class_roundtrip_same_process(self): + """Verify cloudpickle can serialize and deserialize local classes in same process.""" + from ray.cloudpickle import dumps, loads + + def create_class(): + class LocalModel(BaseModel): + value: int + + return LocalModel + + cls = create_class() + restored_cls = loads(dumps(cls)) + + # Should be the same object (cloudpickle recognizes it's in the same process) + assert restored_cls is cls + + def test_cloudpickle_instance_roundtrip_same_process(self): + """Verify cloudpickle can serialize and deserialize instances in same process.""" + from ray.cloudpickle import dumps, loads + + def 
create_class(): + class LocalModel(BaseModel): + value: int + + return LocalModel + + cls = create_class() + instance = cls(value=42) + restored = loads(dumps(instance)) + + assert restored.value == 42 + assert type(restored) is cls + + def test_cloudpickle_preserves_type_path(self): + """Verify type_ works after cloudpickle roundtrip in same process.""" + from ray.cloudpickle import dumps, loads + + def create_class(): + class LocalModel(BaseModel): + value: int + + return LocalModel + + cls = create_class() + instance = cls(value=100) + original_type_path = str(instance.type_) + + restored = loads(dumps(instance)) + restored_type_path = str(restored.type_) + + assert restored_type_path == original_type_path + + +# ============================================================================= +# Cross-process cloudpickle tests +# ============================================================================= + + +@pytest.fixture +def pickle_file(tmp_path): + """Provide a temporary pickle file path with automatic cleanup.""" + pkl_path = tmp_path / "test.pkl" + yield str(pkl_path) + + +class TestCloudpickleCrossProcess: + """Tests for cross-process cloudpickle behavior (subprocess tests). + + These tests verify that ccflow classes (BaseModel, ContextBase, CallableModel) + with local or __main__ scope can be pickled in one process and unpickled in another, + with .type_ (PyObjectPath) working correctly after unpickle. 
+ """ + + @pytest.mark.parametrize( + "create_code,load_code", + [ + pytest.param( + # Local-scope BaseModel + """ +from ray.cloudpickle import dump +from ccflow import BaseModel + +def create_local(): + class LocalModel(BaseModel): + value: int + return LocalModel + +LocalModel = create_local() +assert "" in LocalModel.__qualname__ +assert hasattr(LocalModel, "__ccflow_import_path__") + +instance = LocalModel(value=42) +_ = instance.type_ + +with open("{pkl_path}", "wb") as f: + dump(instance, f) +print("SUCCESS") +""", + """ +from ray.cloudpickle import load + +with open("{pkl_path}", "rb") as f: + obj = load(f) + +assert obj.value == 42 +assert obj.type_.object is type(obj) +print("SUCCESS") +""", + id="local_basemodel", + ), + pytest.param( + # Local-scope ContextBase + """ +from ray.cloudpickle import dump +from ccflow import ContextBase + +def create_local(): + class LocalContext(ContextBase): + name: str + value: int + return LocalContext + +LocalContext = create_local() +assert "" in LocalContext.__qualname__ + +instance = LocalContext(name="test", value=42) +_ = instance.type_ + +with open("{pkl_path}", "wb") as f: + dump(instance, f) +print("SUCCESS") +""", + """ +from ray.cloudpickle import load + +with open("{pkl_path}", "rb") as f: + obj = load(f) + +assert obj.name == "test" +assert obj.value == 42 +assert obj.type_.object is type(obj) +print("SUCCESS") +""", + id="local_context", + ), + pytest.param( + # Local-scope CallableModel (also tests callable execution) + """ +from ray.cloudpickle import dump +from ccflow import CallableModel, ContextBase, GenericResult, Flow + +def create_local(): + class LocalContext(ContextBase): + x: int + + class LocalCallable(CallableModel): + multiplier: int = 2 + + @Flow.call + def __call__(self, context: LocalContext) -> GenericResult: + return GenericResult(value=context.x * self.multiplier) + + return LocalContext, LocalCallable + +LocalContext, LocalCallable = create_local() +model = LocalCallable(multiplier=3) 
+ctx = LocalContext(x=10) +result = model(ctx) +assert result.value == 30 + +with open("{pkl_path}", "wb") as f: + dump((model, ctx), f) +print("SUCCESS") +""", + """ +from ray.cloudpickle import load + +with open("{pkl_path}", "rb") as f: + model, ctx = load(f) + +result = model(ctx) +assert result.value == 30 +assert model.type_.object is type(model) +assert ctx.type_.object is type(ctx) +print("SUCCESS") +""", + id="local_callable", + ), + pytest.param( + # __main__ module class (not inside a function) + # cloudpickle recreates but doesn't add to sys.modules["__main__"] + """ +from ray.cloudpickle import dump +from ccflow import ContextBase + +class MainContext(ContextBase): + value: int + +assert MainContext.__module__ == "__main__" +assert hasattr(MainContext, "__ccflow_import_path__") + +instance = MainContext(value=42) +_ = instance.type_ + +with open("{pkl_path}", "wb") as f: + dump(instance, f) +print("SUCCESS") +""", + """ +from ray.cloudpickle import load + +with open("{pkl_path}", "rb") as f: + obj = load(f) + +assert obj.value == 42 +assert obj.type_.object is type(obj) +print("SUCCESS") +""", + id="main_module", + ), + ], + ) + def test_cross_process_cloudpickle(self, pickle_file, create_code, load_code): + """Test that ccflow classes work with cloudpickle across processes.""" + pkl_path = pickle_file + + create_result = subprocess.run( + [sys.executable, "-c", create_code.format(pkl_path=pkl_path)], + capture_output=True, + text=True, + ) + assert create_result.returncode == 0, f"Create failed: {create_result.stderr}" + assert "SUCCESS" in create_result.stdout + + load_result = subprocess.run( + [sys.executable, "-c", load_code.format(pkl_path=pkl_path)], + capture_output=True, + text=True, + ) + assert load_result.returncode == 0, f"Load failed: {load_result.stderr}" + assert "SUCCESS" in load_result.stdout + + +# ============================================================================= +# Module-level classes should not be affected +# 
# =============================================================================


class TestModuleLevelClassesUnaffected:
    """Tests verifying that module-level classes are not affected by local persistence."""

    def test_module_level_class_no_import_path(self):
        """Verify module-level classes don't get __ccflow_import_path__."""
        assert not hasattr(ModuleLevelModel, "__ccflow_import_path__")
        assert not hasattr(ModuleLevelContext, "__ccflow_import_path__")
        assert not hasattr(ModuleLevelCallable, "__ccflow_import_path__")

    def test_module_level_class_type_path_uses_qualname(self):
        """Verify module-level classes use standard module.qualname for type_."""
        instance = ModuleLevelModel(value=1)
        type_path = str(instance.type_)

        # Should use standard path, not ccflow.local_persistence._Local_...
        assert type_path == "ccflow.tests.test_local_persistence.ModuleLevelModel"
        assert "_Local_" not in type_path

    def test_module_level_standard_pickle_works(self):
        """Verify standard pickle works for module-level classes."""
        from pickle import dumps, loads

        instance = ModuleLevelModel(value=42)
        restored = loads(dumps(instance))
        assert restored.value == 42
        assert type(restored) is ModuleLevelModel


# =============================================================================
# Ray task tests
# =============================================================================


class TestRayTaskWithLocalClasses:
    """Tests for Ray task execution with locally-defined classes.

    Every test first asserts that the class really is function-local, i.e. its
    ``__qualname__`` contains the ``<locals>`` marker, before sending an
    instance to a Ray task.
    """

    def test_local_callable_model_ray_task(self):
        """Test that locally-defined CallableModels can be sent to Ray tasks."""

        def create_local_callable():
            class LocalContext(ContextBase):
                x: int

            class LocalCallable(CallableModel):
                multiplier: int = 2

                @Flow.call
                def __call__(self, context: LocalContext) -> GenericResult:
                    return GenericResult(value=context.x * self.multiplier)

            return LocalContext, LocalCallable

        LocalContext, LocalCallable = create_local_callable()

        # Verify <locals> is in qualname (ensures cloudpickle serializes definition)
        assert "<locals>" in LocalCallable.__qualname__
        assert "<locals>" in LocalContext.__qualname__

        # Verify __ccflow_import_path__ is set
        assert hasattr(LocalCallable, "__ccflow_import_path__")
        assert hasattr(LocalContext, "__ccflow_import_path__")

        @ray.remote
        def run_callable(model, context):
            result = model(context)
            # Verify type_ works inside the Ray task (cross-process PyObjectPath)
            _ = model.type_
            _ = context.type_
            return result.value

        model = LocalCallable(multiplier=3)
        context = LocalContext(x=10)

        with ray.init(num_cpus=1):
            result = ray.get(run_callable.remote(model, context))

        assert result == 30

    def test_local_context_ray_task(self):
        """Test that locally-defined ContextBase can be sent to Ray tasks."""

        def create_local_context():
            class LocalContext(ContextBase):
                name: str
                value: int

            return LocalContext

        LocalContext = create_local_context()
        assert "<locals>" in LocalContext.__qualname__

        @ray.remote
        def process_context(ctx):
            # Access fields and type_ inside Ray task
            _ = ctx.type_
            return f"{ctx.name}:{ctx.value}"

        context = LocalContext(name="test", value=42)

        with ray.init(num_cpus=1):
            result = ray.get(process_context.remote(context))

        assert result == "test:42"

    def test_local_base_model_ray_task(self):
        """Test that locally-defined BaseModel can be sent to Ray tasks."""

        def create_local_model():
            class LocalModel(BaseModel):
                data: str

            return LocalModel

        LocalModel = create_local_model()
        assert "<locals>" in LocalModel.__qualname__

        @ray.remote
        def process_model(m):
            # Access type_ inside Ray task
            type_path = str(m.type_)
            return f"{m.data}|{type_path}"

        model = LocalModel(data="hello")

        with ray.init(num_cpus=1):
            result = ray.get(process_model.remote(model))

        assert result.startswith("hello|ccflow.local_persistence._Local_")


# 
# =============================================================================
# UUID uniqueness tests
# =============================================================================


class TestUUIDUniqueness:
    """Tests verifying UUID-based naming provides uniqueness."""

    def test_multiple_local_classes_same_name_get_unique_paths(self):
        """Test that multiple local classes with same name get unique import paths."""

        def create_model_a():
            class SameName(BaseModel):
                value: int

            return SameName

        def create_model_b():
            class SameName(BaseModel):
                value: str

            return SameName

        ModelA = create_model_a()
        ModelB = create_model_b()

        # Both have same class name but different import paths
        assert ModelA.__name__ == ModelB.__name__ == "SameName"
        assert ModelA.__ccflow_import_path__ != ModelB.__ccflow_import_path__

        # Both should be accessible
        instance_a = ModelA(value=42)
        instance_b = ModelB(value="hello")
        assert instance_a.type_.object is ModelA
        assert instance_b.type_.object is ModelB

    def test_uuid_format_is_valid(self):
        """Test that the UUID portion of names is valid hex."""

        def create_class():
            class TestModel(BaseModel):
                x: int

            return TestModel

        Model = create_class()
        import_path = Model.__ccflow_import_path__

        # Extract UUID portion
        match = re.search(r"_Local_TestModel_([a-f0-9]+)$", import_path)
        assert match is not None, f"Import path doesn't match expected format: {import_path}"

        uuid_part = match.group(1)
        assert len(uuid_part) == 12, f"UUID should be 12 hex chars, got {len(uuid_part)}"
        assert all(c in "0123456789abcdef" for c in uuid_part)


# =============================================================================
# Nested class and inheritance tests
# =============================================================================


class OuterClass:
    """Module-level outer class for testing nested class importability."""

    class NestedModel(BaseModel):
        """A BaseModel nested inside a module-level class."""

        value: int


class TestNestedClasses:
    """Tests for classes nested inside other classes."""

    def test_nested_class_inside_module_level_class_not_registered(self):
        """Verify that a nested class inside a module-level class is NOT registered.

        Classes nested inside module-level classes (like OuterClass.NestedModel)
        have qualnames like 'OuterClass.NestedModel' without '<locals>' and ARE
        importable via the standard module.qualname path.
        """
        # The qualname has a '.' indicating nested class, but no '<locals>'
        assert "." in OuterClass.NestedModel.__qualname__
        assert OuterClass.NestedModel.__qualname__ == "OuterClass.NestedModel"
        assert "<locals>" not in OuterClass.NestedModel.__qualname__

        # Should NOT have __ccflow_import_path__
        assert "__ccflow_import_path__" not in OuterClass.NestedModel.__dict__

        # type_ should use standard path
        instance = OuterClass.NestedModel(value=42)
        type_path = str(instance.type_)
        assert type_path == "ccflow.tests.test_local_persistence.OuterClass.NestedModel"
        assert "_Local_" not in type_path
        assert instance.type_.object is OuterClass.NestedModel

    def test_nested_class_inside_function_is_registered(self):
        """Verify that a class nested inside a function-defined class IS registered."""

        def create_outer():
            class Outer:
                class Inner(BaseModel):
                    value: int

            return Outer

        Outer = create_outer()
        # The inner class has <locals> in its qualname (from the function)
        assert "<locals>" in Outer.Inner.__qualname__

        # Should be registered and have __ccflow_import_path__
        assert hasattr(Outer.Inner, "__ccflow_import_path__")


class TestInheritanceDoesNotPropagateImportPath:
    """Tests verifying that __ccflow_import_path__ is not inherited by subclasses."""

    def test_subclass_of_local_class_gets_own_registration(self):
        """Verify that subclassing a local class doesn't inherit __ccflow_import_path__."""

        def create_base():
            class LocalBase(BaseModel):
                value: int

            return LocalBase

        def create_derived(base_cls):
            class LocalDerived(base_cls):
                extra: str = "default"

            return LocalDerived

        Base = create_base()
        Derived = create_derived(Base)

        # Both should have __ccflow_import_path__ in their own __dict__
        assert "__ccflow_import_path__" in Base.__dict__
        assert "__ccflow_import_path__" in Derived.__dict__

        # They should have DIFFERENT import paths
        assert Base.__ccflow_import_path__ != Derived.__ccflow_import_path__

        # Both should be importable
        base_instance = Base(value=1)
        derived_instance = Derived(value=2, extra="test")

        assert base_instance.type_.object is Base
        assert derived_instance.type_.object is Derived

    def test_subclass_of_module_level_class_is_registered(self):
        """Verify that subclassing a module-level class inside a function creates a local class."""

        def create_subclass():
            class LocalSubclass(ModuleLevelModel):
                extra: str = "default"

            return LocalSubclass

        Subclass = create_subclass()

        # The subclass is local (defined in function), so needs registration
        assert "<locals>" in Subclass.__qualname__
        assert "__ccflow_import_path__" in Subclass.__dict__

        # But the parent should NOT have __ccflow_import_path__
        assert "__ccflow_import_path__" not in ModuleLevelModel.__dict__


# =============================================================================
# Generic types tests
# =============================================================================


class TestGenericTypes:
    """Tests for generic types and PyObjectPath."""

    def test_unparameterized_generic_type_path(self):
        """Test that unparameterized generic types work with type_."""
        from typing import Generic, TypeVar

        T = TypeVar("T")

        def create_generic():
            class GenericModel(BaseModel, Generic[T]):
                data: T  # Will be Any when unparameterized

            return GenericModel

        GenericModel = create_generic()

        # Create an unparameterized instance
        instance = GenericModel(data=42)

        # type_ should work
        type_path = str(instance.type_)
        assert "_Local_" in type_path
        assert "GenericModel" in type_path
        assert instance.type_.object is GenericModel

    def test_generic_base_class_is_registered(self):
        """Test that the unparameterized generic class is correctly registered.

        Note: Parameterized generics (e.g., GenericModel[int]) create new classes
        that lose the '<locals>' marker in their qualname due to Python/pydantic's
        generic handling. Use unparameterized generics for local class scenarios,
        or define concrete subclasses.
        """
        from typing import Generic, TypeVar

        T = TypeVar("T")

        def create_generic():
            class GenericModel(BaseModel, Generic[T]):
                data: T

            return GenericModel

        GenericModel = create_generic()

        # The unparameterized class should be registered
        assert "<locals>" in GenericModel.__qualname__
        assert hasattr(GenericModel, "__ccflow_import_path__")
        assert GenericModel.__ccflow_import_path__.startswith(local_persistence.LOCAL_ARTIFACTS_MODULE_NAME)


# =============================================================================
# Import string tests
# =============================================================================


class TestImportString:
    """Tests for the import_string function."""

    def test_import_string_handles_nested_class_path(self):
        """Verify our import_string handles nested class paths that pydantic's ImportString cannot."""
        from pydantic import ImportString, TypeAdapter

        from ccflow.exttypes.pyobjectpath import import_string

        nested_path = "ccflow.tests.test_local_persistence.OuterClass.NestedModel"

        # Pydantic's ImportString fails on nested class paths
        pydantic_adapter = TypeAdapter(ImportString)
        with pytest.raises(Exception, match="No module named"):
            pydantic_adapter.validate_python(nested_path)

        # Our import_string handles it correctly
        result = import_string(nested_path)
        assert result is OuterClass.NestedModel

    def test_import_string_still_works_for_simple_paths(self):
        """Verify import_string still works for simple module.ClassName paths."""
        from ccflow.exttypes.pyobjectpath import import_string

        # Simple class path
        result = import_string("ccflow.tests.test_local_persistence.ModuleLevelModel")
        assert result is ModuleLevelModel

        # Function inside a standard-library submodule
        result = import_string("os.path.join")
        import os.path

        assert result is os.path.join


# =============================================================================
# Registration strategy tests
# =============================================================================


class TestRegistrationStrategy:
    """Tests verifying the registration strategy for different class types."""

    def test_module_level_classes_not_registered(self):
        """Module-level classes should NOT have __ccflow_import_path__ set."""
        # ModuleLevelModel is defined at module level in this file
        assert "__ccflow_import_path__" not in ModuleLevelModel.__dict__
        assert "<locals>" not in ModuleLevelModel.__qualname__

        # Nested classes at module level also shouldn't need registration
        assert "__ccflow_import_path__" not in OuterClass.NestedModel.__dict__

    def test_local_class_registered_immediately(self):
        """Local classes (with <locals> in qualname) should be registered during definition."""
        from unittest import mock

        # Must patch where it's used (base.py), not where it's defined (local_persistence)
        with mock.patch("ccflow.base.register_ccflow_import_path") as mock_do_reg:

            def create_local():
                class LocalModel(BaseModel):
                    value: int

                return LocalModel

            LocalModel = create_local()

            # register_ccflow_import_path SHOULD be called immediately for local classes
            mock_do_reg.assert_called_once()
            # Verify it has <locals> in qualname
            assert "<locals>" in LocalModel.__qualname__


# =============================================================================
# Tests for create_ccflow_model wrapper
# 
# =============================================================================


class TestCreateCcflowModelWrapper:
    """Tests for the create_ccflow_model wrapper function."""

    def test_create_ccflow_model_basic(self):
        """Test basic create_ccflow_model usage with ContextBase."""
        DynamicContext = create_ccflow_model(
            "DynamicContext",
            __base__=ContextBase,
            name=(str, ...),
            value=(int, 0),
        )

        # Should be a valid ContextBase subclass
        assert issubclass(DynamicContext, ContextBase)

        # Should be registered
        assert hasattr(DynamicContext, "__ccflow_import_path__")
        assert DynamicContext.__ccflow_import_path__.startswith(local_persistence.LOCAL_ARTIFACTS_MODULE_NAME)

        # Should be usable
        ctx = DynamicContext(name="test", value=42)
        assert ctx.name == "test"
        assert ctx.value == 42

    def test_create_ccflow_model_with_base_model(self):
        """Test create_ccflow_model with ccflow BaseModel as base."""
        DynamicModel = create_ccflow_model(
            "DynamicModel",
            __base__=BaseModel,
            data=(str, "default"),
            count=(int, 0),
        )

        assert issubclass(DynamicModel, BaseModel)
        assert hasattr(DynamicModel, "__ccflow_import_path__")

        instance = DynamicModel(data="hello", count=5)
        assert instance.data == "hello"
        assert instance.count == 5

    def test_create_ccflow_model_type_property_works(self):
        """Test that type_ property works for dynamically created models."""
        DynamicContext = create_ccflow_model(
            "DynamicContext",
            __base__=ContextBase,
            x=(int, ...),
        )

        ctx = DynamicContext(x=10)
        type_path = str(ctx.type_)

        # type_ should use __ccflow_import_path__
        assert type_path == DynamicContext.__ccflow_import_path__
        assert "_Local_" in type_path
        assert ctx.type_.object is DynamicContext

    def test_create_ccflow_model_can_be_imported(self):
        """Test that dynamically created models can be imported via their path."""
        import importlib

        DynamicModel = create_ccflow_model(
            "ImportableModel",
            __base__=BaseModel,
            value=(int, 0),
        )

        import_path = DynamicModel.__ccflow_import_path__
        parts = import_path.rsplit(".", 1)
        module = importlib.import_module(parts[0])
        imported_cls = getattr(module, parts[1])

        assert imported_cls is DynamicModel

    def test_create_ccflow_model_with_docstring(self):
        """Test create_ccflow_model with custom docstring."""

        DynamicModel = create_ccflow_model(
            "DocumentedModel",
            __base__=BaseModel,
            __doc__="A dynamically created model for testing.",
            value=(int, 0),
        )

        assert DynamicModel.__doc__ == "A dynamically created model for testing."

    def test_create_ccflow_model_with_complex_fields(self):
        """Test create_ccflow_model with various field types."""
        from typing import List, Optional

        from pydantic import Field

        DynamicModel = create_ccflow_model(
            "ComplexModel",
            __base__=BaseModel,
            name=(str, ...),
            tags=(List[str], Field(default_factory=list)),
            description=(Optional[str], None),
            count=(int, 0),
        )

        instance = DynamicModel(name="test")
        assert instance.name == "test"
        assert instance.tags == []
        assert instance.description is None
        assert instance.count == 0

        instance2 = DynamicModel(name="test2", tags=["a", "b"], description="desc", count=5)
        assert instance2.tags == ["a", "b"]
        assert instance2.description == "desc"
        assert instance2.count == 5

    def test_create_ccflow_model_multiple_unique_names(self):
        """Test that multiple models with same name get unique registration paths."""

        Model1 = create_ccflow_model("SameName", __base__=BaseModel, value=(int, 0))
        Model2 = create_ccflow_model("SameName", __base__=BaseModel, value=(str, ""))

        # Both should be registered with different paths
        assert Model1.__ccflow_import_path__ != Model2.__ccflow_import_path__

        # Both should have the same __name__
        assert Model1.__name__ == Model2.__name__ == "SameName"

        # Both should be accessible via their own paths
        assert Model1(value=42).type_.object is Model1
        assert Model2(value="test").type_.object is Model2

    def test_create_ccflow_model_inherits_from_context_base(self):
        """Test that models inheriting from ContextBase have frozen config."""
        from pydantic import ValidationError

        DynamicContext = create_ccflow_model(
            "FrozenContext",
            __base__=ContextBase,
            value=(int, 0),
        )

        ctx = DynamicContext(value=42)

        # ContextBase subclasses should be frozen. Expect pydantic's
        # ValidationError specifically so an unrelated failure (e.g. a typo
        # raising AttributeError) cannot make this test pass.
        with pytest.raises(ValidationError):
            ctx.value = 100


class TestCreateCcflowModelCloudpickleSameProcess:
    """Tests for cloudpickle with dynamically created models in the same process."""

    def test_cloudpickle_instance_roundtrip(self):
        """Test cloudpickle roundtrip for instances of dynamically created models."""
        from ray.cloudpickle import dumps, loads

        DynamicModel = create_ccflow_model(
            "PickleTestModel",
            __base__=BaseModel,
            value=(int, 0),
        )

        instance = DynamicModel(value=123)
        restored = loads(dumps(instance))

        assert restored.value == 123
        assert type(restored) is DynamicModel

    def test_cloudpickle_class_roundtrip(self):
        """Test cloudpickle roundtrip for dynamically created model classes."""
        from ray.cloudpickle import dumps, loads

        DynamicModel = create_ccflow_model(
            "ClassPickleModel",
            __base__=BaseModel,
            name=(str, ""),
        )

        restored_cls = loads(dumps(DynamicModel))
        assert restored_cls is DynamicModel

    def test_cloudpickle_preserves_type_path(self):
        """Test that type_ path is preserved after cloudpickle roundtrip."""
        from ray.cloudpickle import dumps, loads

        DynamicModel = create_ccflow_model(
            "TypePathModel",
            __base__=BaseModel,
            value=(int, 0),
        )

        instance = DynamicModel(value=42)
        original_path = str(instance.type_)

        restored = loads(dumps(instance))
        restored_path = str(restored.type_)

        assert restored_path == original_path


class TestCreateCcflowModelCloudpickleCrossProcess:
    """Tests for cross-process cloudpickle with dynamically created models (via create_ccflow_model).

    Each parameter set is a pair of script templates. The ``{pkl_path!r}``
    placeholder is filled with ``str.format`` so the path is embedded as a
    properly escaped Python string literal (backslashes and quotes included).
    """

    @pytest.mark.parametrize(
        "create_code,load_code",
        [
            pytest.param(
                # Context only
                """
from ray.cloudpickle import dump
from ccflow import ContextBase
from ccflow.local_persistence import create_ccflow_model

DynamicContext = create_ccflow_model(
    "CrossProcessContext",
    __base__=ContextBase,
    name=(str, ...),
    value=(int, 0),
)

assert hasattr(DynamicContext, "__ccflow_import_path__")

ctx = DynamicContext(name="test", value=42)
assert "_Local_" in str(ctx.type_)

with open({pkl_path!r}, "wb") as f:
    dump(ctx, f)
print("SUCCESS")
""",
                """
from ray.cloudpickle import load

with open({pkl_path!r}, "rb") as f:
    ctx = load(f)

assert ctx.name == "test"
assert ctx.value == 42
assert "_Local_" in str(ctx.type_)
assert ctx.type_.object is type(ctx)
print("SUCCESS")
""",
                id="context_only",
            ),
            pytest.param(
                # Dynamic context with CallableModel
                """
from ray.cloudpickle import dump
from ccflow import CallableModel, ContextBase, GenericResult, Flow
from ccflow.local_persistence import create_ccflow_model

DynamicContext = create_ccflow_model(
    "CallableModelContext",
    __base__=ContextBase,
    x=(int, ...),
    multiplier=(int, 2),
)

def create_callable():
    class DynamicCallable(CallableModel):
        @Flow.call
        def __call__(self, context: DynamicContext) -> GenericResult:
            return GenericResult(value=context.x * context.multiplier)
    return DynamicCallable

DynamicCallable = create_callable()
model = DynamicCallable()
ctx = DynamicContext(x=10, multiplier=3)
result = model(ctx)
assert result.value == 30

with open({pkl_path!r}, "wb") as f:
    dump((model, ctx), f)
print("SUCCESS")
""",
                """
from ray.cloudpickle import load

with open({pkl_path!r}, "rb") as f:
    model, ctx = load(f)

result = model(ctx)
assert result.value == 30
assert ctx.type_.object is type(ctx)
assert model.type_.object is type(model)
print("SUCCESS")
""",
                id="with_callable",
            ),
        ],
    )
    def test_create_ccflow_model_cross_process(self, pickle_file, create_code, load_code):
        """Test that dynamically created models work across processes.

        Runs ``create_code`` in one interpreter to write the pickle, then
        ``load_code`` in a second interpreter to read it back. Each script
        must exit with status 0 and print ``SUCCESS``.
        """
        pkl_path = pickle_file

        create_result = subprocess.run(
            [sys.executable, "-c", create_code.format(pkl_path=pkl_path)],
            capture_output=True,
            text=True,
        )
        assert create_result.returncode == 0, f"Create failed: {create_result.stderr}"
        assert "SUCCESS" in create_result.stdout

        load_result = subprocess.run(
            [sys.executable, "-c", load_code.format(pkl_path=pkl_path)],
            capture_output=True,
            text=True,
        )
        assert load_result.returncode == 0, f"Load failed: {load_result.stderr}"
        assert "SUCCESS" in load_result.stdout


class TestCreateCcflowModelRayTask:
    """Tests for Ray task execution with dynamically created models."""

    def test_create_ccflow_model_ray_task(self):
        """Test that dynamically created models work in Ray tasks."""

        DynamicContext = create_ccflow_model(
            "RayTaskContext",
            __base__=ContextBase,
            name=(str, ...),
            value=(int, 0),
        )

        @ray.remote
        def process_context(ctx):
            # Access fields and type_ inside Ray task
            _ = ctx.type_
            return f"{ctx.name}:{ctx.value}"

        ctx = DynamicContext(name="ray_test", value=99)

        with ray.init(num_cpus=1):
            result = ray.get(process_context.remote(ctx))

        assert result == "ray_test:99"

    def test_create_ccflow_model_callable_model_ray_task(self):
        """Test CallableModel with dynamically created context in Ray tasks."""

        DynamicContext = create_ccflow_model(
            "RayCallableContext",
            __base__=ContextBase,
            x=(int, ...),
        )

        class RayCallable(CallableModel):
            factor: int = 2

            @Flow.call
            def __call__(self, context: DynamicContext) -> GenericResult:
                return GenericResult(value=context.x * self.factor)

        @ray.remote
        def run_callable(model, ctx):
            result = model(ctx)
            # Verify type_ works in Ray worker
            _ = model.type_
            _ = ctx.type_
            return result.value

        model = RayCallable(factor=5)
        ctx = DynamicContext(x=10)

        with ray.init(num_cpus=1):
            result = ray.get(run_callable.remote(model, ctx))

        assert result == 50


class TestCreateCcflowModelEdgeCases:
    """Tests for edge cases in create_ccflow_model wrapper."""

    def test_create_ccflow_model_no_fields(self):
        """Test create_ccflow_model with no custom fields."""

        EmptyModel = create_ccflow_model("EmptyModel", __base__=BaseModel)

        assert issubclass(EmptyModel, BaseModel)
        assert hasattr(EmptyModel, "__ccflow_import_path__")

        instance = EmptyModel()
        assert instance.type_.object is EmptyModel

    def test_create_ccflow_model_with_module_override(self):
        """Test create_ccflow_model with __module__ parameter."""

        CustomModuleModel = create_ccflow_model(
            "CustomModuleModel",
            __base__=BaseModel,
            __module__="custom.module.path",
            value=(int, 0),
        )

        # Module should be overridden
        assert CustomModuleModel.__module__ == "custom.module.path"

        # But should still be registered since it's not actually importable
        assert hasattr(CustomModuleModel, "__ccflow_import_path__")

    def test_create_ccflow_model_inheritance_from_custom_base(self):
        """Test create_ccflow_model inheriting from a custom ccflow class."""

        # First create a base class
        class CustomBase(ContextBase):
            base_field: str = "base"

        DerivedModel = create_ccflow_model(
            "DerivedModel",
            __base__=CustomBase,
            derived_field=(int, 0),
        )

        assert issubclass(DerivedModel, CustomBase)
        assert issubclass(DerivedModel, ContextBase)

        instance = DerivedModel(derived_field=42)
        assert instance.base_field == "base"
        assert instance.derived_field == 42

    def test_create_ccflow_model_special_characters_in_name(self):
        """Test create_ccflow_model handles special characters in name."""

        # Names with special characters should still work
        SpecialModel = create_ccflow_model(
            "Model-With-Dashes",
            __base__=BaseModel,
            value=(int, 0),
        )

        assert hasattr(SpecialModel, "__ccflow_import_path__")

        # The registered name should be sanitized
        import_path = SpecialModel.__ccflow_import_path__
        registered_name = import_path.split(".")[-1]
        # Should have sanitized the dashes
        assert "-" not in registered_name

    def test_create_ccflow_model_returns_correct_type(self):
        """Test that create_ccflow_model returns the model class, not an instance."""

        result = create_ccflow_model(
            "TypeCheckModel",
            __base__=BaseModel,
            value=(int, 0),
        )

        assert isinstance(result, type)
        assert issubclass(result, BaseModel)

    def test_create_ccflow_model_import_at_top_level(self):
        """Test that create_ccflow_model can be imported from ccflow."""
        from ccflow import create_ccflow_model as ccflow_create_model
        from ccflow.local_persistence import create_ccflow_model as lp_create_model

        # Both should be the same function
        assert ccflow_create_model is lp_create_model

        # And both should work
        Model = ccflow_create_model("TopLevelImportModel", __base__=BaseModel, value=(int, 0))
        assert hasattr(Model, "__ccflow_import_path__")