diff --git a/ccflow/__init__.py b/ccflow/__init__.py index 163f275..c8d2259 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -10,7 +10,9 @@ from .compose import * from .callable import * from .context import * +from .dep import * from .enums import Enum +from .flow_model import FlowAPI, BoundModel, Lazy from .global_state import * from .local_persistence import * from .models import * diff --git a/ccflow/callable.py b/ccflow/callable.py index 748759c..1aa7189 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -14,6 +14,7 @@ import abc import inspect import logging +from contextvars import ContextVar from functools import lru_cache, wraps from inspect import Signature, isclass, signature from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin @@ -28,6 +29,7 @@ ResultBase, ResultType, ) +from .dep import Dep, extract_dep from .local_persistence import create_ccflow_model from .validators import str_to_log_level @@ -46,6 +48,8 @@ "EvaluatorBase", "Evaluator", "WrapperModel", + # Note: resolve() is intentionally not in __all__ to avoid namespace pollution. + # Users who need it can import explicitly: from ccflow.callable import resolve ) log = logging.getLogger(__name__) @@ -128,7 +132,7 @@ def _check_result_type(cls, result_type): @model_validator(mode="after") def _check_signature(self): sig_call = _cached_signature(self.__class__.__call__) - if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: # ("self", "context") + if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: raise ValueError("__call__ method must take a single argument, named 'context'") sig_deps = _cached_signature(self.__class__.__deps__) @@ -195,6 +199,176 @@ def _get_logging_evaluator(log_level): return LoggingEvaluator(log_level=log_level) +def _get_dep_fields(model_class) -> Dict[str, Dep]: + """Analyze class fields to find Dep-annotated fields. + + Returns a dict mapping field name to Dep instance for fields that need resolution. + """ + dep_fields = {} + + # Get type hints from the class + hints = {} + for cls in model_class.__mro__: + if hasattr(cls, "__annotations__"): + for name, annotation in cls.__annotations__.items(): + if name not in hints: # Don't override child class annotations + hints[name] = annotation + + for name, annotation in hints.items(): + base_type, dep = extract_dep(annotation) + if dep is not None: + dep_fields[name] = dep + + return dep_fields + + +def _wrap_with_dep_resolution(fn): + """Wrap a function to auto-resolve DepOf fields before calling. + + For each Dep-annotated field on the model that contains a CallableModel, + resolves it using __deps__ and temporarily sets the resolved value on self. + + Note: This wrapper is only applied at runtime when the function is called, + not during decoration. This avoids issues with functools.wraps flattening + the __wrapped__ chain. 
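+
+    A minimal sketch of where resolution actually lands (illustrative; the
+    model and field names here are hypothetical):
+
+        class Downstream(CallableModel):
+            data: DepOf[..., GenericResult[int]]
+
+            @Flow.call
+            def __call__(self, context) -> GenericResult[int]:
+                # self.data is resolved via _resolve_deps_and_call, not here
+                return GenericResult(value=resolve(self.data).value + 1)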
+ + Args: + fn: The original function + + Returns: + The original function unchanged - dep resolution happens at the call site + """ + # Don't modify the function - dep resolution is handled in ModelEvaluationContext + return fn + + +# Context variable for storing resolved dependency values during __call__ +# Maps id(callable_model) -> resolved_value +_resolved_deps: ContextVar[Dict[int, Any]] = ContextVar("resolved_deps", default={}) + +# TypeVar for resolve() function to enable proper type inference +_T = TypeVar("_T") + + +def resolve(dep: Union[_T, "_CallableModel"]) -> _T: + """Access the resolved value of a DepOf dependency during __call__. + + This function is used inside a CallableModel's __call__ method to get + the resolved value of a dependency field. It provides proper type inference - + if the field is `DepOf[..., GenericResult[int]]`, this returns `GenericResult[int]`. + + Args: + dep: The dependency field value (either a CallableModel or already-resolved value) + + Returns: + The resolved value. If dep is already a resolved value (not a CallableModel), + returns it unchanged. + + Raises: + RuntimeError: If called outside of __call__ or if the dependency wasn't resolved. + + Example: + class MyModel(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: MyContext) -> GenericResult[int]: + # resolve() provides proper type inference + data = resolve(self.data) # type: GenericResult[int] + return GenericResult(value=data.value + 1) + """ + # If it's not a CallableModel, it's already a resolved value - pass through + if not isinstance(dep, _CallableModel): + return dep # type: ignore[return-value] + + # Look up in context var + store = _resolved_deps.get() + dep_id = id(dep) + if dep_id not in store: + raise RuntimeError( + "resolve() can only be used inside __call__ for DepOf fields. Make sure the field is annotated with DepOf and contains a CallableModel." + ) + return store[dep_id] + + +def _resolve_deps_and_call(model, context, fn): + """Resolve DepOf fields and call the function. + + This is called from ModelEvaluationContext.__call__ to handle dep resolution. + Resolved values are stored in a context variable and accessed via resolve(). 
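+
+    Roughly, for a model with a single DepOf field, the sequence is
+    (a sketch, not the exact code below):
+
+        deps = model.__deps__(context)        # [(dep_model, [ctx]), ...]
+        value = dep_model(ctx)                # evaluate the dependency
+        _resolved_deps.set({id(dep_model): value})   # visible to resolve()
+        fn(model, context)                    # user code runs here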
+ + Args: + model: The CallableModel instance + context: The context to pass to the function + fn: The function to call + + Returns: + The result of calling fn(model, context) + """ + # Don't resolve deps for __deps__ method + if fn.__name__ == "__deps__": + return fn(model, context) + + # Get Dep-annotated fields for this model class + dep_fields = _get_dep_fields(model.__class__) + + # Check if model has custom deps (from @func.deps decorator) + has_custom_deps = getattr(model.__class__, "__has_custom_deps__", False) + + if not dep_fields and not has_custom_deps: + return fn(model, context) + + # Get dependencies from __deps__ + deps_result = model.__deps__(context) + # Build a map from model instance id to (model, contexts) for lookup + dep_map = {} + for dep_model, contexts in deps_result: + dep_map[id(dep_model)] = (dep_model, contexts) + + # Resolve dependencies and store in context var + resolved_values = {} + + # If custom deps, resolve ALL CallableModel fields from dep_map + if has_custom_deps: + for dep_model, contexts in deps_result: + resolved = dep_model(contexts[0]) if contexts else dep_model(context) + # Unwrap GenericResult if present (consistent with auto-detected deps) + if hasattr(resolved, 'value'): + resolved = resolved.value + resolved_values[id(dep_model)] = resolved + else: + # Standard path: iterate over Dep-annotated fields + for field_name, dep in dep_fields.items(): + field_value = getattr(model, field_name, None) + if field_value is None: + continue + + # Check if field is a CallableModel that needs resolution + if not isinstance(field_value, _CallableModel): + continue # Already a resolved value, skip + + # Check if this field is in __deps__ (for custom transforms) + if id(field_value) in dep_map: + dep_model, contexts = dep_map[id(field_value)] + # Call dependency with the (transformed) context + resolved = dep_model(contexts[0]) if contexts else dep_model(context) + else: + # Not in __deps__, use Dep annotation transform directly + transformed_ctx = dep.apply(context) + resolved = field_value(transformed_ctx) + + resolved_values[id(field_value)] = resolved + + # Store in context var and call function + current_store = _resolved_deps.get() + new_store = {**current_store, **resolved_values} + token = _resolved_deps.set(new_store) + try: + return fn(model, context) + finally: + _resolved_deps.reset(token) + + class FlowOptions(BaseModel): """Options for Flow evaluation. @@ -246,6 +420,9 @@ def get_evaluator(self, model: CallableModelType) -> "EvaluatorBase": return self._get_evaluator_from_options(options) def __call__(self, fn): + # Wrap function with dependency resolution for DepOf fields + fn = _wrap_with_dep_resolution(fn) + # Used for building a graph of model evaluation contexts without evaluating def get_evaluation_context(model: CallableModelType, context: ContextType, as_dict: bool = False, *, _options: Optional[FlowOptions] = None): # Create the evaluation context. @@ -451,6 +628,33 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult: # The generated context inherits from DateContext, so it's compatible # with infrastructure expecting DateContext instances. + + Auto-Resolve Dependencies Example: + When __call__ has parameters beyond 'self' and 'context' that match field + names annotated with DepOf/Dep, those dependencies are automatically resolved + using __deps__ (if defined) or auto-generated from Dep annotations. 
+ + class MyModel(CallableModel): + data: Annotated[GenericResult[dict], Dep(transform=my_transform)] + + @Flow.call + def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]: + # data is automatically resolved - no manual calling needed + return GenericResult(value=process(data.value)) + + For transforms that need access to instance fields, define __deps__ manually: + + class MyModel(CallableModel): + data: DepOf[..., GenericResult[dict]] + window: int = 7 + + def __deps__(self, context): + # Can access self.window here + return [(self.data, [context.with_lookback(self.window)])] + + @Flow.call + def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]: + return GenericResult(value=process(data.value)) """ # Extract auto_context option (not part of FlowOptions) # Can be: False, True, or a ContextBase subclass @@ -502,6 +706,78 @@ def deps(*args, **kwargs): # Note that the code below is executed only once return FlowOptionsDeps(**kwargs) + @staticmethod + def model(*args, **kwargs): + """Decorator that generates a CallableModel class from a plain Python function. + + This is syntactic sugar over CallableModel. The decorator generates a real + CallableModel class with proper __call__ and __deps__ methods, so all existing + features (caching, evaluation, registry, serialization) work unchanged. + + Args: + context_args: List of parameter names that come from context (for unpacked mode) + cacheable: Enable caching of results (default: False) + volatile: Mark as volatile (default: False) + log_level: Logging verbosity (default: logging.DEBUG) + validate_result: Validate return type (default: True) + verbose: Verbose logging output (default: True) + evaluator: Custom evaluator (default: None) + + Two Context Modes: + + Mode 1 - Explicit context parameter: + Function has a 'context' parameter annotated with a ContextBase subclass. + + @Flow.model + def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]: + return GenericResult(value=query_db(source, context.start_date, context.end_date)) + + Mode 2 - Unpacked context_args: + Context fields are unpacked into function parameters. 
+ + @Flow.model(context_args=["start_date", "end_date"]) + def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: + return GenericResult(value=query_db(source, start_date, end_date)) + + Dependencies: + Use Dep() or DepOf to mark parameters that can accept CallableModel dependencies: + + from ccflow import Dep, DepOf + from typing import Annotated + + @Flow.model + def compute_returns( + context: DateRangeContext, + prices: Annotated[GenericResult[pl.DataFrame], Dep( + transform=lambda ctx: ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) + )] + ) -> GenericResult[pl.DataFrame]: + return GenericResult(value=prices.value.pct_change()) + + # Or use DepOf shorthand for no transform: + @Flow.model + def compute_stats( + context: DateRangeContext, + data: DepOf[..., GenericResult[pl.DataFrame]] + ) -> GenericResult[pl.DataFrame]: + return GenericResult(value=data.value.describe()) + + Usage: + # Create model instances + loader = load_prices(source="prod_db") + returns = compute_returns(prices=loader) + + # Execute + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = returns(ctx) + + Returns: + A factory function that creates CallableModel instances + """ + from .flow_model import flow_model + + return flow_model(*args, **kwargs) + # ***************************************************************************** # Define "Evaluators" and associated types @@ -555,7 +831,8 @@ def _context_validator(cls, values, handler, info): def __call__(self) -> ResultType: fn = getattr(self.model, self.fn) if hasattr(fn, "__wrapped__"): - result = fn.__wrapped__(self.model, self.context) + # Call through _resolve_deps_and_call to handle DepOf field resolution + result = _resolve_deps_and_call(self.model, self.context, fn.__wrapped__) # If it's a callable model, then we can validate the result if self.options.get("validate_result", True): if fn.__name__ == "__deps__": diff --git a/ccflow/context.py b/ccflow/context.py index cf17d24..62ce0f7 100644 --- a/ccflow/context.py +++ b/ccflow/context.py @@ -2,10 +2,10 @@ import warnings from datetime import date, datetime -from typing import Generic, Hashable, Optional, Sequence, Set, TypeVar +from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar from deprecated import deprecated -from pydantic import field_validator, model_validator +from pydantic import ConfigDict, field_validator, model_validator from .base import ContextBase from .exttypes import Frequency @@ -15,6 +15,7 @@ __all__ = ( + "FlowContext", "NullContext", "GenericContext", "DateContext", @@ -93,6 +94,42 @@ # Starting 0.8.0 Nullcontext is an alias to ContextBase NullContext = ContextBase + +class FlowContext(ContextBase): + """Universal context for @Flow.model functions. + + Instead of generating a new ContextBase subclass for each @Flow.model, + this single class with extra="allow" serves as the universal carrier. + Validation happens via TypedDict + TypeAdapter at compute() time. + + This design avoids: + - Proliferation of dynamic _funcname_Context classes + - Class registration overhead for serialization + - Pickling issues with Ray/distributed computing + + Fields are stored in __pydantic_extra__ and accessed via __getattr__. 
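+
+    Example (illustrative):
+
+        ctx = FlowContext(start_date=date(2024, 1, 1), source="db")
+        ctx.start_date   # read from __pydantic_extra__ via __getattr__
+        ctx.source       # "db"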
+ """ + + model_config = ConfigDict(extra="allow", frozen=True) + + def __getattr__(self, name: str) -> Any: + """Access fields stored in __pydantic_extra__.""" + # Use object.__getattribute__ to avoid infinite recursion + try: + extra = object.__getattribute__(self, "__pydantic_extra__") + if extra is not None and name in extra: + return extra[name] + except AttributeError: + pass + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __repr__(self) -> str: + """Show all fields including extra fields.""" + extra = object.__getattribute__(self, "__pydantic_extra__") or {} + fields = ", ".join(f"{k}={v!r}" for k, v in extra.items()) + return f"FlowContext({fields})" + + C = TypeVar("C", bound=Hashable) diff --git a/ccflow/dep.py b/ccflow/dep.py new file mode 100644 index 0000000..b57261e --- /dev/null +++ b/ccflow/dep.py @@ -0,0 +1,278 @@ +"""Dependency annotation markers for Flow.model. + +This module provides: +- Dep: Annotation marker for dependency parameters that can accept CallableModel +- DepOf: Shorthand for Annotated[Union[T, CallableModel], Dep()] +""" + +from typing import TYPE_CHECKING, Annotated, Callable, Optional, Type, TypeVar, Union, get_args, get_origin + +from .base import ContextBase + +if TYPE_CHECKING: + from .callable import CallableModel + +__all__ = ("Dep", "DepOf") + +T = TypeVar("T") + +# Lazy reference to CallableModel to avoid circular import +_CallableModel = None + + +def _get_callable_model(): + """Lazily import CallableModel to avoid circular imports.""" + global _CallableModel + if _CallableModel is None: + from .callable import CallableModel + + _CallableModel = CallableModel + return _CallableModel + + +class _DepOfMeta(type): + """Metaclass that makes DepOf[ContextType, ResultType] work.""" + + def __getitem__(cls, item): + if not isinstance(item, tuple) or len(item) != 2: + raise TypeError( + "DepOf requires 2 type arguments: DepOf[ContextType, ResultType]. " + "Use ... for ContextType to inherit from parent: DepOf[..., ResultType]" + ) + context_type, result_type = item + CallableModel = _get_callable_model() + + if context_type is ...: + # DepOf[..., ResultType] - inherit context from parent + return Annotated[Union[result_type, CallableModel], Dep()] + else: + # DepOf[ContextType, ResultType] - explicit context type + return Annotated[Union[result_type, CallableModel], Dep(context_type=context_type)] + + +class DepOf(metaclass=_DepOfMeta): + """ + Shorthand for Annotated[Union[ResultType, CallableModel], Dep(context_type=...)]. + + Follows Callable convention: DepOf[InputContext, OutputResult] + + For class fields, accepts either: + - The result type directly (pre-computed value) + - A CallableModel that produces the result type (resolved at call time) + + Usage: + # Inherit context type from parent model (most common) + data: DepOf[..., GenericResult[dict]] + + # Explicit context type validation + data: DepOf[DateRangeContext, GenericResult[dict]] + + At call time, if the field contains a CallableModel, it will be automatically + resolved using __deps__ and the resolved value will be accessible via self.field_name. + + For dependencies with transforms, define them in __deps__: + def __deps__(self, context): + transformed_ctx = context.model_copy(update={...}) + return [(self.data, [transformed_ctx])] + """ + + pass + + +def _is_compatible_type(actual: Type, expected: Type) -> bool: + """Check if actual type is compatible with expected type. 
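+
+    For example, the expected behavior (a sketch of typical inputs):
+
+        _is_compatible_type(GenericResult[int], GenericResult[int])  # True
+        _is_compatible_type(GenericResult[int], GenericResult[str])  # False
+        _is_compatible_type(dict, GenericResult[dict])               # False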
+ + Handles generic types like GenericResult[pl.DataFrame] where issubclass + would raise TypeError. + + Args: + actual: The actual type to check + expected: The expected type to match against + + Returns: + True if actual is compatible with expected + """ + # Handle None/empty types + if actual is None or expected is None: + return actual is expected + + # Get origins for generic types + actual_origin = get_origin(actual) or actual + expected_origin = get_origin(expected) or expected + + # Check if origins are compatible + try: + if not (isinstance(actual_origin, type) and isinstance(expected_origin, type)): + return False + if not issubclass(actual_origin, expected_origin): + return False + except TypeError: + # issubclass can fail for certain types + return False + + # Check generic args if present + actual_args = get_args(actual) + expected_args = get_args(expected) + + if expected_args and actual_args: + if len(actual_args) != len(expected_args): + return False + return all(_is_compatible_type(a, e) for a, e in zip(actual_args, expected_args)) + + return True + + +class Dep: + """ + Annotation marker for dependency parameters. + + Marks a parameter as accepting either the declared type or a CallableModel + that produces that type. Supports optional context transform and + construction-time type validation. + + Usage: + # No transform, no explicit validation (uses parent's context_type) + prices: Annotated[GenericResult[pl.DataFrame], Dep()] + + # With transform + prices: Annotated[GenericResult[pl.DataFrame], Dep( + transform=lambda ctx: ctx.model_copy(update={"start": ctx.start - timedelta(days=1)}) + )] + + # With explicit context_type validation + prices: Annotated[GenericResult[pl.DataFrame], Dep( + context_type=DateRangeContext, + transform=lambda ctx: ctx.model_copy(update={"start": ctx.start - timedelta(days=1)}) + )] + + # Cross-context dependency (transform changes context type) + sim_data: Annotated[GenericResult[pl.DataFrame], Dep( + context_type=SimulationContext, + transform=date_to_simulation_context + )] + """ + + def __init__( + self, + transform: Optional[Callable[..., ContextBase]] = None, + context_type: Optional[Type[ContextBase]] = None, + ): + """ + Args: + transform: Optional function to transform context before calling dependency. + Signature: (context) -> transformed_context + context_type: Expected context_type of the dependency CallableModel. + If None, defaults to the parent model's context_type. + Validated at construction time when a CallableModel is passed. + """ + self.transform = transform + self.context_type = context_type + + def apply(self, context: ContextBase) -> ContextBase: + """Apply the transform to a context, or return unchanged if no transform.""" + if self.transform is not None: + return self.transform(context) + return context + + def validate_dependency( + self, + value: "CallableModel", # noqa: F821 + expected_result_type: Type, + parent_context_type: Type[ContextBase], + param_name: str, + ) -> None: + """ + Validate a CallableModel dependency at construction time. 
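+
+        For example, a mismatch this is meant to catch (names hypothetical):
+
+            # parent declares prices: DepOf[..., GenericResult[int]]
+            # str_producer is a CallableModel with result_type GenericResult[str]
+            Parent(prices=str_producer)   # raises TypeError at construction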
+ + Args: + value: The CallableModel being passed as a dependency + expected_result_type: The result type from the Annotated type hint + parent_context_type: The context_type of the parent model + param_name: Name of the parameter (for error messages) + + Raises: + TypeError: If context_type or result_type don't match + """ + # Import here to avoid circular import + from .callable import CallableModel + + if not isinstance(value, CallableModel): + return # Not a CallableModel, skip validation + + # Determine expected context type + expected_ctx = self.context_type if self.context_type is not None else parent_context_type + + # Validate context_type - the dependency's context_type should be compatible + # with what we'll pass to it (expected_ctx) + dep_context_type = value.context_type + try: + if not issubclass(expected_ctx, dep_context_type): + raise TypeError( + f"Dependency '{param_name}': expected context_type compatible with " + f"{dep_context_type.__name__}, but will pass {expected_ctx.__name__}" + ) + except TypeError: + # issubclass can fail for certain types, try alternate check + if expected_ctx != dep_context_type: + raise TypeError(f"Dependency '{param_name}': context_type mismatch - expected {dep_context_type}, got {expected_ctx}") + + # Validate result_type using the generic-safe comparison + # If expected_result_type is Union[T, CallableModel], extract T for validation + dep_result_type = value.result_type + actual_expected_type = expected_result_type + + # Handle Union[T, CallableModel] from DepOf expansion + if get_origin(expected_result_type) is Union: + union_args = get_args(expected_result_type) + # Filter out CallableModel from the union + non_callable_types = [t for t in union_args if t is not CallableModel] + if non_callable_types: + actual_expected_type = non_callable_types[0] + + if not _is_compatible_type(dep_result_type, actual_expected_type): + raise TypeError( + f"Dependency '{param_name}': expected result_type compatible with " + f"{actual_expected_type}, but got CallableModel with result_type {dep_result_type}" + ) + + def __repr__(self): + parts = [] + if self.transform is not None: + parts.append(f"transform={self.transform}") + if self.context_type is not None: + parts.append(f"context_type={self.context_type.__name__}") + return f"Dep({', '.join(parts)})" if parts else "Dep()" + + def __eq__(self, other): + if not isinstance(other, Dep): + return False + return self.transform == other.transform and self.context_type == other.context_type + + def __hash__(self): + # Make Dep hashable for use in sets/dicts + return hash((id(self.transform), self.context_type)) + + +def extract_dep(annotation) -> tuple: + """Extract Dep from Annotated[T, Dep(...)] or DepOf[ContextType, T]. + + When multiple Dep annotations exist (e.g., from nested Annotated that flattens), + returns the LAST one, which represents the outermost user annotation. 
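+
+    For example:
+
+        extract_dep(Annotated[int, Dep()])   # -> (int, Dep())
+        extract_dep(int)                     # -> (int, None)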
+ + Args: + annotation: A type annotation, possibly Annotated with Dep + + Returns: + Tuple of (base_type, Dep instance or None) + """ + if get_origin(annotation) is Annotated: + args = get_args(annotation) + base_type = args[0] + # Find the LAST Dep - nested Annotated flattens, so outer annotation comes last + last_dep = None + for metadata in args[1:]: + if isinstance(metadata, Dep): + last_dep = metadata + if last_dep is not None: + return base_type, last_dep + return annotation, None diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py new file mode 100644 index 0000000..2d3ab3a --- /dev/null +++ b/ccflow/flow_model.py @@ -0,0 +1,745 @@ +"""Flow.model decorator implementation. + +This module provides the Flow.model decorator that generates CallableModel classes +from plain Python functions, reducing boilerplate while maintaining full compatibility +with existing ccflow infrastructure. + +Key design: Uses TypedDict + TypeAdapter for context schema validation instead of +generating dynamic ContextBase subclasses. This avoids class registration overhead +and enables clean pickling for distributed computing (e.g., Ray). +""" + +import inspect +import logging +from functools import wraps +from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Type, Union, get_origin + +from pydantic import Field, TypeAdapter +from typing_extensions import TypedDict + +from .base import ContextBase, ResultBase +from .context import FlowContext +from .dep import Dep, extract_dep + +__all__ = ("flow_model", "FlowAPI", "BoundModel", "Lazy") + +log = logging.getLogger(__name__) + + +class FlowAPI: + """API namespace for deferred computation operations. + + Provides methods for executing models and transforming contexts. + Accessed via model.flow property. + """ + + def __init__(self, model: "CallableModel"): # noqa: F821 + self._model = model + + def compute(self, **kwargs) -> Any: + """Execute the model with the provided context arguments. + + Validates kwargs against the model's context schema using TypeAdapter, + then wraps in FlowContext and calls the model. + + Args: + **kwargs: Context arguments (e.g., start_date, end_date) + + Returns: + The model's result, unwrapped from GenericResult if applicable. + """ + # Get validator from model (lazily created if needed after unpickling) + validator = self._model._get_context_validator() + + # Validate and coerce kwargs via TypeAdapter + validated = validator.validate_python(kwargs) + + # Wrap in FlowContext (single class, always) + ctx = FlowContext(**validated) + + # Call the model + result = self._model(ctx) + + # Unwrap GenericResult if present + if hasattr(result, "value"): + return result.value + return result + + @property + def unbound_inputs(self) -> Dict[str, Type]: + """Return the context schema (field name -> type). + + In deferred mode, this is everything NOT provided at construction. 
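+
+        For example (a sketch, assuming context_args=["start_date", "end_date"]
+        and a bound source field):
+
+            model = load_data(source="api")   # source bound at construction
+            model.flow.unbound_inputs         # {"start_date": date, "end_date": date}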
+ """ + all_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) + bound_fields = getattr(self._model, "_bound_fields", set()) + + # If explicit context_args was provided, use _context_schema + explicit_args = getattr(self._model.__class__, "__flow_model_explicit_context_args__", None) + if explicit_args is not None: + return self._model._context_schema.copy() + + # Otherwise, unbound = all params - bound + return {name: typ for name, typ in all_param_types.items() if name not in bound_fields} + + @property + def bound_inputs(self) -> Dict[str, Any]: + """Return the config values bound at construction time.""" + bound_fields = getattr(self._model, "_bound_fields", set()) + result = {} + for name in bound_fields: + if hasattr(self._model, name): + result[name] = getattr(self._model, name) + return result + + def with_inputs(self, **transforms) -> "BoundModel": + """Create a version of this model with transformed context inputs. + + Args: + **transforms: Mapping of field name to either: + - A callable (ctx) -> value for dynamic transforms + - A static value to bind + + Returns: + A BoundModel that applies the transforms before calling. + """ + return BoundModel(model=self._model, input_transforms=transforms) + + +class BoundModel: + """A model with context transforms applied. + + Created by model.flow.with_inputs(). Applies transforms to context + before delegating to the underlying model. + """ + + def __init__(self, model: "CallableModel", input_transforms: Dict[str, Any]): # noqa: F821 + self._model = model + self._input_transforms = input_transforms + + def __call__(self, context: ContextBase) -> Any: + """Call the model with transformed context.""" + # Build new context dict with transforms applied + ctx_dict = {} + + # Get fields from context + if hasattr(context, "__pydantic_extra__") and context.__pydantic_extra__: + ctx_dict.update(context.__pydantic_extra__) + for field in context.__class__.model_fields: + ctx_dict[field] = getattr(context, field) + + # Apply transforms + for name, transform in self._input_transforms.items(): + if callable(transform): + ctx_dict[name] = transform(context) + else: + ctx_dict[name] = transform + + # Create new context and call model + new_ctx = FlowContext(**ctx_dict) + return self._model(new_ctx) + + @property + def flow(self) -> FlowAPI: + """Access the flow API.""" + return FlowAPI(self._model) + + +class Lazy: + """Deferred model execution with runtime context overrides. + + Wraps a CallableModel to allow context fields to be determined at + runtime rather than at construction time. Use in with_inputs() when + you need values that aren't available until execution. + + Example: + # Create a model that needs runtime-determined context + market_data = load_market_data(symbols=["AAPL"]) + + # Use Lazy to defer the start_date calculation + lookback_data = market_data.flow.with_inputs( + start_date=Lazy(market_data)(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) + ) + + # More commonly, use Lazy for self-referential transforms: + adjusted_model = model.flow.with_inputs( + value=Lazy(other_model)(multiplier=2) # Call other_model with multiplier=2 + ) + + The __call__ method returns a callable that, when invoked with a context, + calls the wrapped model with the specified overrides applied. + """ + + def __init__(self, model: "CallableModel"): # noqa: F821 + """Wrap a model for deferred execution. 
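+
+        A minimal usage sketch (names hypothetical):
+
+            fn = Lazy(other_model)(offset=3)   # callable (context) -> result
+            fn(FlowContext(x=1))               # calls other_model with a context
+                                               # carrying x=1 and offset=3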
+ + Args: + model: The CallableModel to wrap + """ + self._model = model + + def __call__(self, **overrides) -> Callable[[ContextBase], Any]: + """Create a callable that applies overrides to context before execution. + + Args: + **overrides: Context field overrides. Values can be: + - Static values (applied directly) + - Callables (ctx) -> value (called with context at runtime) + + Returns: + A callable (context) -> result that applies overrides and calls the model + """ + model = self._model + + def execute_with_overrides(context: ContextBase) -> Any: + # Build context dict from incoming context + ctx_dict = {} + if hasattr(context, "__pydantic_extra__") and context.__pydantic_extra__: + ctx_dict.update(context.__pydantic_extra__) + for field in context.__class__.model_fields: + ctx_dict[field] = getattr(context, field) + + # Apply overrides + for name, value in overrides.items(): + if callable(value): + ctx_dict[name] = value(context) + else: + ctx_dict[name] = value + + # Call model with modified context + new_ctx = FlowContext(**ctx_dict) + return model(new_ctx) + + return execute_with_overrides + + @property + def model(self) -> "CallableModel": # noqa: F821 + """Access the wrapped model.""" + return self._model + + +def _build_context_schema( + context_args: List[str], func: Callable, sig: inspect.Signature +) -> Tuple[Dict[str, Type], Type, Optional[Type[ContextBase]]]: + """Build context schema from context_args parameter names. + + Instead of creating a dynamic ContextBase subclass, this builds: + - A schema dict mapping field names to types + - A TypedDict for Pydantic TypeAdapter validation + - Optionally, a matched existing ContextBase type for compatibility + + Args: + context_args: List of parameter names that come from context + func: The decorated function + sig: The function signature + + Returns: + Tuple of (schema_dict, TypedDict type, optional matched ContextBase type) + """ + # Build schema dict from parameter annotations + schema = {} + for name in context_args: + if name not in sig.parameters: + raise ValueError(f"context_arg '{name}' not found in function parameters") + param = sig.parameters[name] + if param.annotation is inspect.Parameter.empty: + raise ValueError(f"context_arg '{name}' must have a type annotation") + schema[name] = param.annotation + + # Try to match common context types for compatibility + matched_context_type = None + from .context import DateRangeContext + + if set(context_args) == {"start_date", "end_date"}: + from datetime import date + + if all( + sig.parameters[name].annotation in (date, "date") + or (isinstance(sig.parameters[name].annotation, type) and sig.parameters[name].annotation is date) + for name in context_args + ): + matched_context_type = DateRangeContext + + # Create TypedDict for validation (not registered anywhere!) + context_td = TypedDict(f"{func.__name__}Inputs", schema) + + return schema, context_td, matched_context_type + + +def _get_dep_info(annotation) -> Tuple[Type, Optional[Dep]]: + """Extract dependency info from an annotation. 
+ + Returns: + Tuple of (base_type, Dep instance or None) + """ + return extract_dep(annotation) + + +def flow_model( + func: Callable = None, + *, + # Context handling + context_args: Optional[List[str]] = None, + # Flow.call options (passed to generated __call__) + cacheable: bool = False, + volatile: bool = False, + log_level: int = logging.DEBUG, + validate_result: bool = True, + verbose: bool = True, + evaluator: Optional[Any] = None, +) -> Callable: + """Decorator that generates a CallableModel class from a plain Python function. + + This is syntactic sugar over CallableModel. The decorator generates a real + CallableModel class with proper __call__ and __deps__ methods, so all existing + features (caching, evaluation, registry, serialization) work unchanged. + + Args: + func: The function to decorate + context_args: List of parameter names that come from context (for unpacked mode) + cacheable: Enable caching of results + volatile: Mark as volatile (always re-execute) + log_level: Logging verbosity + validate_result: Validate return type + verbose: Verbose logging output + evaluator: Custom evaluator + + Two Context Modes: + 1. Explicit context parameter: Function has a 'context' parameter annotated + with a ContextBase subclass. + + @Flow.model + def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]: + ... + + 2. Unpacked context_args: Context fields are unpacked into function parameters. + + @Flow.model(context_args=["start_date", "end_date"]) + def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: + ... + + Returns: + A factory function that creates CallableModel instances + """ + + def decorator(fn: Callable) -> Callable: + # Import here to avoid circular imports + from .callable import CallableModel, Flow, GraphDepList + + sig = inspect.signature(fn) + params = sig.parameters + + # Validate return type + return_type = sig.return_annotation + if return_type is inspect.Signature.empty: + raise TypeError(f"Function {fn.__name__} must have a return type annotation") + # Check that return type is a ResultBase subclass + return_origin = get_origin(return_type) or return_type + if not (isinstance(return_origin, type) and issubclass(return_origin, ResultBase)): + raise TypeError(f"Function {fn.__name__} must return a ResultBase subclass, got {return_type}") + + # Determine context mode + if "context" in params or "_" in params: + # Mode 1: Explicit context parameter (named 'context' or '_' for unused) + context_param_name = "context" if "context" in params else "_" + context_param = params[context_param_name] + if context_param.annotation is inspect.Parameter.empty: + raise TypeError(f"Function {fn.__name__}: '{context_param_name}' parameter must have a type annotation") + context_type = context_param.annotation + if not (isinstance(context_type, type) and issubclass(context_type, ContextBase)): + raise TypeError(f"Function {fn.__name__}: '{context_param_name}' must be annotated with a ContextBase subclass") + model_field_params = {name: param for name, param in params.items() if name not in (context_param_name, "self")} + use_context_args = False + explicit_context_args = None + elif context_args is not None: + # Mode 2: Explicit context_args - specified params come from context + context_param_name = "context" + # Build context schema early to determine matched_context_type + context_schema_early, _, matched_type = _build_context_schema(context_args, fn, sig) + # Use matched type if available (e.g., 
DateRangeContext), else FlowContext + context_type = matched_type if matched_type is not None else FlowContext + # Exclude context_args from model fields + model_field_params = {name: param for name, param in params.items() if name not in context_args and name != "self"} + use_context_args = True + explicit_context_args = context_args + else: + # Mode 3: Dynamic deferred mode - ALL params are potential context or config + # What's provided at construction = config/deps + # What's NOT provided = comes from context at runtime + context_param_name = "context" + context_type = FlowContext + model_field_params = {name: param for name, param in params.items() if name != "self"} + use_context_args = True + explicit_context_args = None # Dynamic - determined at construction + + # Analyze parameters to find dependencies and regular fields + dep_fields: Dict[str, Tuple[Type, Dep]] = {} # name -> (base_type, Dep) + model_fields: Dict[str, Tuple[Type, Any]] = {} # name -> (type, default) + + # In dynamic deferred mode (no explicit context_args), all fields are optional + # because values not provided at construction come from context at runtime + dynamic_deferred_mode = use_context_args and explicit_context_args is None + + for name, param in model_field_params.items(): + if param.annotation is inspect.Parameter.empty: + raise TypeError(f"Parameter '{name}' must have a type annotation") + + base_type, dep = _get_dep_info(param.annotation) + if param.default is not inspect.Parameter.empty: + default = param.default + elif dynamic_deferred_mode: + # In dynamic mode, params without defaults are optional (come from context) + default = None + else: + # In explicit mode, params without defaults are required + default = ... + + if dep is not None: + # This is an explicit dependency parameter (DepOf annotation) + dep_fields[name] = (base_type, dep) + # Use Annotated so _resolve_deps_and_call in callable.py can find the Dep + model_fields[name] = (Annotated[Union[base_type, CallableModel], dep], default) + else: + # Regular model field - use Any for auto-detection of CallableModels. + # We can't use Union[T, CallableModel] because Pydantic tries to generate + # schema for T, which fails for arbitrary types like pl.DataFrame. + # Using Any allows any value; we do runtime isinstance checks in __call__. + model_fields[name] = (Any, default) + + # Capture variables for closures + ctx_param_name = context_param_name if not use_context_args else "context" + all_param_names = list(model_fields.keys()) # All non-context params (model fields) + all_param_types = {name: param.annotation for name, param in model_field_params.items()} + # For explicit context_args mode, we also need the list of context arg names + ctx_args_for_closure = context_args if context_args is not None else [] + is_dynamic_mode = use_context_args and explicit_context_args is None + + # Create the __call__ method + def make_call_impl(): + def __call__(self, context): + # Import here (inside function) to avoid pickling issues with ContextVar + from .callable import _resolved_deps + + # Check if this model has custom deps (from @func.deps decorator) + has_custom_deps = getattr(self.__class__, "__has_custom_deps__", False) + + def resolve_callable_model(name, value, store): + """Resolve a CallableModel field. + + When has_custom_deps is True and the value is NOT in the store, + it means the custom deps function chose not to include this dep. + In that case, we return None (the field's default) instead of + calling the CallableModel directly. 
+ """ + if id(value) in store: + return store[id(value)] + elif has_custom_deps: + # Custom deps excluded this field - use None + return None + else: + # Auto-detection fallback: call directly + resolved = value(context) + if hasattr(resolved, 'value'): + return resolved.value + return resolved + + # Build kwargs for the original function + fn_kwargs = {} + store = _resolved_deps.get() + + if not use_context_args: + # Mode 1: Explicit context param - pass context directly + fn_kwargs[ctx_param_name] = context + # Add model fields + for name in all_param_names: + value = getattr(self, name) + if isinstance(value, CallableModel): + fn_kwargs[name] = resolve_callable_model(name, value, store) + else: + fn_kwargs[name] = value + elif not is_dynamic_mode: + # Mode 2: Explicit context_args - get those from context, rest from self + for name in ctx_args_for_closure: + fn_kwargs[name] = getattr(context, name) + # Add model fields + for name in all_param_names: + value = getattr(self, name) + if isinstance(value, CallableModel): + fn_kwargs[name] = resolve_callable_model(name, value, store) + else: + fn_kwargs[name] = value + else: + # Mode 3: Dynamic deferred mode - unbound from context, bound from self + bound_fields = getattr(self, "_bound_fields", set()) + + for name in all_param_names: + if name in bound_fields: + # Bound at construction - get from self + value = getattr(self, name) + if isinstance(value, CallableModel): + fn_kwargs[name] = resolve_callable_model(name, value, store) + else: + fn_kwargs[name] = value + else: + # Unbound - get from context + fn_kwargs[name] = getattr(context, name) + + return fn(**fn_kwargs) + + # Set proper signature for CallableModel validation + __call__.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), + ], + return_annotation=return_type, + ) + return __call__ + + call_impl = make_call_impl() + + # Apply Flow.call decorator + flow_options = { + "cacheable": cacheable, + "volatile": volatile, + "log_level": log_level, + "validate_result": validate_result, + "verbose": verbose, + } + if evaluator is not None: + flow_options["evaluator"] = evaluator + + decorated_call = Flow.call(**flow_options)(call_impl) + + # Create the __deps__ method + def make_deps_impl(): + def __deps__(self, context) -> GraphDepList: + deps = [] + # Check ALL fields for CallableModels (auto-detection) + for name in model_fields: + value = getattr(self, name) + if isinstance(value, CallableModel): + if name in dep_fields: + # Explicit DepOf with transform (backwards compat) + _, dep_obj = dep_fields[name] + transformed_ctx = dep_obj.apply(context) + deps.append((value, [transformed_ctx])) + else: + # Auto-detected dependency - use context as-is + deps.append((value, [context])) + return deps + + # Set proper signature + __deps__.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), + ], + return_annotation=GraphDepList, + ) + return __deps__ + + deps_impl = make_deps_impl() + decorated_deps = Flow.deps(deps_impl) + + # Build pydantic field annotations for the class + annotations = {} + + namespace = { + "__module__": fn.__module__, + "__qualname__": f"_{fn.__name__}_Model", + "__call__": decorated_call, + "__deps__": decorated_deps, + } + + for name, (typ, 
default) in model_fields.items(): + annotations[name] = typ + if default is not ...: + namespace[name] = default + else: + # For required fields, use Field(...) + namespace[name] = Field(...) + + namespace["__annotations__"] = annotations + + # Add model validator for dependency validation if we have dep fields + if dep_fields: + from pydantic import model_validator + + # Create validator function that captures dep_fields and context_type + def make_dep_validator(d_fields, ctx_type): + @model_validator(mode="after") + def __validate_deps__(self): + from .callable import CallableModel + + for dep_name, (base_type, dep_obj) in d_fields.items(): + value = getattr(self, dep_name) + if isinstance(value, CallableModel): + dep_obj.validate_dependency(value, base_type, ctx_type, dep_name) + return self + + return __validate_deps__ + + namespace["__validate_deps__"] = make_dep_validator(dep_fields, context_type) + + # Create the class using type() + GeneratedModel = type(f"_{fn.__name__}_Model", (CallableModel,), namespace) + + # Set class-level attributes after class creation (to avoid pydantic processing) + GeneratedModel.__flow_model_context_type__ = context_type + GeneratedModel.__flow_model_return_type__ = return_type + GeneratedModel.__flow_model_func__ = fn + GeneratedModel.__flow_model_dep_fields__ = dep_fields + GeneratedModel.__flow_model_use_context_args__ = use_context_args + GeneratedModel.__flow_model_explicit_context_args__ = explicit_context_args + GeneratedModel.__flow_model_all_param_types__ = all_param_types # All param name -> type + + # Build context_schema and matched_context_type + context_schema: Dict[str, Type] = {} + context_td = None + matched_context_type: Optional[Type[ContextBase]] = None + + if explicit_context_args is not None: + # Explicit context_args provided - use early-computed schema + # (matched_context_type was already used to set context_type above) + context_schema, context_td, matched_context_type = _build_context_schema(explicit_context_args, fn, sig) + elif not use_context_args: + # Explicit context mode - schema comes from the context type's fields + if hasattr(context_type, "model_fields"): + context_schema = {name: info.annotation for name, info in context_type.model_fields.items()} + # For dynamic mode (is_dynamic_mode), _context_schema remains empty + # and schema is built dynamically from _bound_fields at runtime + + # Store context schema for TypedDict-based validation (picklable!) + GeneratedModel._context_schema = context_schema + GeneratedModel._context_td = context_td + GeneratedModel._matched_context_type = matched_context_type + # Validator is created lazily to survive pickling + GeneratedModel._cached_context_validator = None + + # Method to get/create context validator (lazy for pickling support) + def _get_context_validator(self) -> TypeAdapter: + """Get or create the context validator. + + For dynamic deferred mode, builds schema from unbound fields. + For explicit context_args or explicit context mode, uses cached schema. 
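+
+            Sketch of the explicit-context_args path, assuming the model's
+            context schema is just {"start_date": date}:
+
+                validator = model._get_context_validator()
+                validator.validate_python({"start_date": "2024-01-01"})
+                # -> {"start_date": datetime.date(2024, 1, 1)}, coerced by TypeAdapter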
+ """ + cls = self.__class__ + explicit_args = getattr(cls, "__flow_model_explicit_context_args__", None) + + # For explicit context_args or explicit context mode, use cached validator + if explicit_args is not None or not getattr(cls, "__flow_model_use_context_args__", True): + if cls._cached_context_validator is None: + if cls._context_td is not None: + cls._cached_context_validator = TypeAdapter(cls._context_td) + elif cls._context_schema: + td = TypedDict(f"{cls.__name__}Inputs", cls._context_schema) + cls._cached_context_validator = TypeAdapter(td) + else: + cls._cached_context_validator = TypeAdapter(cls.__flow_model_context_type__) + return cls._cached_context_validator + + # Dynamic mode: build schema from unbound fields (instance-specific) + # Cache on instance since bound_fields varies per instance + if not hasattr(self, "_instance_context_validator"): + all_param_types = getattr(cls, "__flow_model_all_param_types__", {}) + bound_fields = getattr(self, "_bound_fields", set()) + unbound_schema = {name: typ for name, typ in all_param_types.items() if name not in bound_fields} + if unbound_schema: + td = TypedDict(f"{cls.__name__}Inputs", unbound_schema) + object.__setattr__(self, "_instance_context_validator", TypeAdapter(td)) + else: + # No unbound fields - empty validator + object.__setattr__(self, "_instance_context_validator", TypeAdapter(dict)) + return self._instance_context_validator + + GeneratedModel._get_context_validator = _get_context_validator + + # Override context_type property after class creation + @property + def context_type_getter(self) -> Type[ContextBase]: + return self.__class__.__flow_model_context_type__ + + # Override result_type property after class creation + @property + def result_type_getter(self) -> Type[ResultBase]: + return self.__class__.__flow_model_return_type__ + + # Add .flow property for the new API + @property + def flow_getter(self) -> FlowAPI: + return FlowAPI(self) + + GeneratedModel.context_type = context_type_getter + GeneratedModel.result_type = result_type_getter + GeneratedModel.flow = flow_getter + + # Register the MODEL class for serialization (needed for model_dump/_target_). + # Note: We do NOT register dynamic context classes anymore - context handling + # uses FlowContext + TypedDict instead, which don't need registration. + from .local_persistence import register_ccflow_import_path + + register_ccflow_import_path(GeneratedModel) + + # Rebuild the model to process annotations properly + GeneratedModel.model_rebuild() + + # Create factory function that returns model instances + @wraps(fn) + def factory(**kwargs) -> GeneratedModel: + instance = GeneratedModel(**kwargs) + # Track which fields were explicitly provided at construction + # These are "bound" - everything else comes from context at runtime + object.__setattr__(instance, "_bound_fields", set(kwargs.keys())) + return instance + + # Preserve useful attributes on factory + factory._generated_model = GeneratedModel + factory.__doc__ = fn.__doc__ + + # Add .deps decorator for customizing __deps__ + def deps_decorator(deps_fn): + """Decorator to customize the __deps__ method. + + Usage: + @Flow.model + def my_func(start_date: date, prices: dict) -> GenericResult[...]: + ... 
+ + @my_func.deps + def _(self, context): + # Custom context transform + lookback_ctx = FlowContext( + start_date=context.start_date - timedelta(days=30), + end_date=context.end_date, + ) + return [(self.prices, [lookback_ctx])] + """ + from .callable import GraphDepList + + # Rename the function to __deps__ so Flow.deps accepts it + deps_fn.__name__ = "__deps__" + deps_fn.__qualname__ = f"{GeneratedModel.__qualname__}.__deps__" + # Set proper signature to match __call__'s context type + deps_fn.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), + ], + return_annotation=GraphDepList, + ) + # Wrap with Flow.deps and replace on the class + decorated = Flow.deps(deps_fn) + GeneratedModel.__deps__ = decorated + # Mark that this model has custom deps (so _resolve_deps_and_call will call it) + GeneratedModel.__has_custom_deps__ = True + return factory # Return factory for chaining + + factory.deps = deps_decorator + + return factory + + # Handle both @Flow.model and @Flow.model(...) syntax + if func is not None: + return decorator(func) + return decorator diff --git a/ccflow/tests/config/conf_flow.yaml b/ccflow/tests/config/conf_flow.yaml new file mode 100644 index 0000000..781bd24 --- /dev/null +++ b/ccflow/tests/config/conf_flow.yaml @@ -0,0 +1,80 @@ +# Flow.model configurations for Hydra integration tests +# This file is separate from conf.yaml to avoid affecting existing tests + +# Basic Flow.model +flow_loader: + _target_: ccflow.tests.test_flow_model.basic_loader + source: test_source + multiplier: 5 + +flow_processor: + _target_: ccflow.tests.test_flow_model.string_processor + prefix: "value=" + suffix: "!" 
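+
+# Note: values like "flow_source" in the entries below are registry name
+# references, not inline objects; the loader resolves each name to the model
+# instance registered under it, so shared dependencies (e.g. the diamond
+# pattern further down) point at the same underlying object.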
+ +# Pipeline with dependencies (uses registry name references for same instance) +flow_source: + _target_: ccflow.tests.test_flow_model.data_source + base_value: 100 + +flow_transformer: + _target_: ccflow.tests.test_flow_model.data_transformer + source: flow_source + factor: 3 + +# Three-stage pipeline +flow_stage1: + _target_: ccflow.tests.test_flow_model.pipeline_stage1 + initial: 10 + +flow_stage2: + _target_: ccflow.tests.test_flow_model.pipeline_stage2 + stage1_output: flow_stage1 + multiplier: 2 + +flow_stage3: + _target_: ccflow.tests.test_flow_model.pipeline_stage3 + stage2_output: flow_stage2 + offset: 50 + +# Diamond dependency pattern +diamond_source: + _target_: ccflow.tests.test_flow_model.data_source + base_value: 10 + +diamond_branch_a: + _target_: ccflow.tests.test_flow_model.data_transformer + source: diamond_source + factor: 2 + +diamond_branch_b: + _target_: ccflow.tests.test_flow_model.data_transformer + source: diamond_source + factor: 5 + +diamond_aggregator: + _target_: ccflow.tests.test_flow_model.data_aggregator + input_a: diamond_branch_a + input_b: diamond_branch_b + operation: add + +# DateRangeContext with transform +flow_date_loader: + _target_: ccflow.tests.test_flow_model.date_range_loader + source: market_data + include_weekends: false + +flow_date_processor: + _target_: ccflow.tests.test_flow_model.date_range_processor + raw_data: flow_date_loader + normalize: true + +# context_args models (auto-unpacked context parameters) +ctx_args_loader: + _target_: ccflow.tests.test_flow_model.context_args_loader + source: data_source + +ctx_args_processor: + _target_: ccflow.tests.test_flow_model.context_args_processor + data: ctx_args_loader + prefix: "output" diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index a748765..29f4524 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -462,6 +462,7 @@ def test_types(self): error = "__call__ method must take a single argument, named 'context'" self.assertRaisesRegex(ValueError, error, BadModelMissingContextArg) + # BadModelDoubleContextArg also fails with the same error since extra params aren't allowed error = "__call__ method must take a single argument, named 'context'" self.assertRaisesRegex(ValueError, error, BadModelDoubleContextArg) diff --git a/ccflow/tests/test_context.py b/ccflow/tests/test_context.py index ad98bd9..64d71e8 100644 --- a/ccflow/tests/test_context.py +++ b/ccflow/tests/test_context.py @@ -275,8 +275,13 @@ def split_camel(name: str): def test_inheritance(self): """Test that if a context has a superset of fields of another context, it is a subclass of that context.""" - for parent_name, parent_class in self.classes.items(): - for child_name, child_class in self.classes.items(): + # Exclude FlowContext from this test - it's a special universal carrier with no + # declared fields (uses extra="allow"), so the "superset implies subclass" logic + # doesn't apply to it. + classes_to_check = {name: cls for name, cls in self.classes.items() if name != "FlowContext"} + + for parent_name, parent_class in classes_to_check.items(): + for child_name, child_class in classes_to_check.items(): if parent_class is child_class: continue diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py new file mode 100644 index 0000000..70af8b2 --- /dev/null +++ b/ccflow/tests/test_flow_context.py @@ -0,0 +1,467 @@ +"""Tests for FlowContext, FlowAPI, and TypedDict-based context validation. 
+ +These tests verify the new deferred computation API that uses: +- FlowContext: Universal context carrier with extra="allow" +- TypedDict + TypeAdapter: Schema validation without dynamic class registration +- FlowAPI: The .flow namespace for compute/with_inputs/etc. +""" + +import pickle +from datetime import date, timedelta + +import cloudpickle +import pytest + +from ccflow import Flow, FlowAPI, FlowContext, GenericResult +from ccflow.context import DateRangeContext + + +class TestFlowContext: + """Tests for the FlowContext universal carrier.""" + + def test_flow_context_basic(self): + """FlowContext accepts arbitrary fields.""" + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + assert ctx.start_date == date(2024, 1, 1) + assert ctx.end_date == date(2024, 1, 31) + + def test_flow_context_extra_fields(self): + """FlowContext stores fields in __pydantic_extra__.""" + ctx = FlowContext(x=1, y="hello", z=[1, 2, 3]) + assert ctx.x == 1 + assert ctx.y == "hello" + assert ctx.z == [1, 2, 3] + assert ctx.__pydantic_extra__ == {"x": 1, "y": "hello", "z": [1, 2, 3]} + + def test_flow_context_frozen(self): + """FlowContext is immutable (frozen).""" + ctx = FlowContext(value=42) + with pytest.raises(Exception): # ValidationError for frozen model + ctx.value = 100 + + def test_flow_context_repr(self): + """FlowContext has a useful repr.""" + ctx = FlowContext(a=1, b=2) + repr_str = repr(ctx) + assert "FlowContext" in repr_str + assert "a=1" in repr_str + assert "b=2" in repr_str + + def test_flow_context_attribute_error(self): + """FlowContext raises AttributeError for missing fields.""" + ctx = FlowContext(x=1) + with pytest.raises(AttributeError, match="no attribute 'missing'"): + _ = ctx.missing + + def test_flow_context_model_dump(self): + """FlowContext can be dumped (includes extra fields).""" + ctx = FlowContext(start_date=date(2024, 1, 1), value=42) + dumped = ctx.model_dump() + assert dumped["start_date"] == date(2024, 1, 1) + assert dumped["value"] == 42 + + def test_flow_context_pickle(self): + """FlowContext pickles cleanly.""" + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + pickled = pickle.dumps(ctx) + unpickled = pickle.loads(pickled) + assert unpickled.start_date == date(2024, 1, 1) + assert unpickled.end_date == date(2024, 1, 31) + + def test_flow_context_cloudpickle(self): + """FlowContext works with cloudpickle (for Ray).""" + ctx = FlowContext(data=[1, 2, 3], name="test") + pickled = cloudpickle.dumps(ctx) + unpickled = cloudpickle.loads(pickled) + assert unpickled.data == [1, 2, 3] + assert unpickled.name == "test" + + +class TestFlowAPI: + """Tests for the FlowAPI (.flow namespace).""" + + def test_flow_compute_basic(self): + """FlowAPI.compute() validates and executes.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date, "source": source}) + + model = load_data(source="api") + result = model.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + + assert result["start"] == date(2024, 1, 1) + assert result["end"] == date(2024, 1, 31) + assert result["source"] == "api" + + def test_flow_compute_type_coercion(self): + """FlowAPI.compute() coerces types via TypeAdapter.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": 
start_date, "end": end_date}) + + model = load_data() + # Pass strings - should be coerced to dates + result = model.flow.compute(start_date="2024-01-01", end_date="2024-01-31") + + assert result["start"] == date(2024, 1, 1) + assert result["end"] == date(2024, 1, 31) + + def test_flow_compute_validation_error(self): + """FlowAPI.compute() raises on missing required args.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data() + with pytest.raises(Exception): # ValidationError + model.flow.compute(start_date=date(2024, 1, 1)) # Missing end_date + + def test_flow_unbound_inputs(self): + """FlowAPI.unbound_inputs returns the context schema.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data(source="api") + unbound = model.flow.unbound_inputs + + assert "start_date" in unbound + assert "end_date" in unbound + assert unbound["start_date"] == date + assert unbound["end_date"] == date + # source is not unbound (it has a default/is bound) + assert "source" not in unbound + + def test_flow_bound_inputs(self): + """FlowAPI.bound_inputs returns config values.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data(source="api") + bound = model.flow.bound_inputs + + assert "source" in bound + assert bound["source"] == "api" + # Context args are not in bound_inputs + assert "start_date" not in bound + assert "end_date" not in bound + + +class TestBoundModel: + """Tests for BoundModel (created via .flow.with_inputs()).""" + + def test_with_inputs_static_value(self): + """with_inputs can bind static values.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + bound = model.flow.with_inputs(start_date=date(2024, 1, 1)) + + # Call with just end_date (start_date is bound) + ctx = FlowContext(end_date=date(2024, 1, 31)) + result = bound(ctx) + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + + def test_with_inputs_transform_function(self): + """with_inputs can use transform functions.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + # Lookback: start_date is 7 days before the context's start_date + bound = model.flow.with_inputs(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) + + ctx = FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 31)) + result = bound(ctx) + assert result.value["start"] == date(2024, 1, 1) # 7 days before + assert result.value["end"] == date(2024, 1, 31) + + def test_with_inputs_multiple_transforms(self): + """with_inputs can apply multiple transforms.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + bound = model.flow.with_inputs( + 
start_date=lambda ctx: ctx.start_date - timedelta(days=7),
+ end_date=lambda ctx: ctx.end_date + timedelta(days=1),
+ )
+
+ ctx = FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 30))
+ result = bound(ctx)
+ assert result.value["start"] == date(2024, 1, 1)
+ assert result.value["end"] == date(2024, 1, 31)
+
+ def test_bound_model_has_flow_property(self):
+ """BoundModel has a .flow property."""
+
+ @Flow.model(context_args=["x"])
+ def compute(x: int) -> GenericResult[int]:
+ return GenericResult(value=x * 2)
+
+ model = compute()
+ bound = model.flow.with_inputs(x=42)
+ assert isinstance(bound.flow, FlowAPI)
+
+
+class TestTypedDictValidation:
+ """Tests for TypedDict-based context validation."""
+
+ def test_schema_stored_on_model(self):
+ """Model stores _context_schema for validation."""
+
+ @Flow.model(context_args=["start_date", "end_date"])
+ def load_data(start_date: date, end_date: date) -> GenericResult[dict]:
+ return GenericResult(value={})
+
+ model = load_data()
+ assert hasattr(model, "_context_schema")
+ assert model._context_schema == {"start_date": date, "end_date": date}
+
+ def test_validator_created_lazily(self):
+ """TypeAdapter validator is created lazily."""
+
+ @Flow.model(context_args=["x"])
+ def compute(x: int) -> GenericResult[int]:
+ return GenericResult(value=x)
+
+ model = compute()
+ # Initially None
+ assert model.__class__._cached_context_validator is None
+
+ # After getting validator, it's cached
+ validator = model._get_context_validator()
+ assert validator is not None
+ assert model.__class__._cached_context_validator is validator
+
+ def test_matched_context_type(self):
+ """DateRangeContext pattern is matched for compatibility."""
+
+ @Flow.model(context_args=["start_date", "end_date"])
+ def load_data(start_date: date, end_date: date) -> GenericResult[dict]:
+ return GenericResult(value={})
+
+ model = load_data()
+ # Should match DateRangeContext
+ assert model.context_type == DateRangeContext
+
+
+class TestPicklingSupport:
+ """Tests for pickling support (important for Ray).
+
+ Note: regular pickle cannot pickle locally-defined classes (such as functions
+ decorated inside test methods); cloudpickle can, which is why Ray uses it.
+ The model tests below therefore use cloudpickle to match Ray's behavior, while
+ FlowContext, which is importable at module scope, also round-trips through
+ standard pickle. 
+ """ + + def test_model_cloudpickle_roundtrip(self): + """Model works with cloudpickle (for Ray).""" + + @Flow.model(context_args=["x", "y"]) + def compute(x: int, y: int, multiplier: int = 2) -> GenericResult[int]: + return GenericResult(value=(x + y) * multiplier) + + model = compute(multiplier=3) + + # cloudpickle roundtrip (what Ray uses) + pickled = cloudpickle.dumps(model) + unpickled = cloudpickle.loads(pickled) + + # Should work after unpickling + result = unpickled.flow.compute(x=1, y=2) + assert result == 9 # (1 + 2) * 3 + + def test_model_cloudpickle_simple(self): + """Simple model cloudpickle test.""" + + @Flow.model(context_args=["value"]) + def double(value: int) -> GenericResult[int]: + return GenericResult(value=value * 2) + + model = double() + + pickled = cloudpickle.dumps(model) + unpickled = cloudpickle.loads(pickled) + + result = unpickled.flow.compute(value=21) + assert result == 42 + + def test_validator_recreated_after_cloudpickle(self): + """TypeAdapter validator is recreated after cloudpickling.""" + + @Flow.model(context_args=["x"]) + def compute(x: int) -> GenericResult[int]: + return GenericResult(value=x) + + model = compute() + # Warm up the validator cache + _ = model._get_context_validator() + assert model.__class__._cached_context_validator is not None + + # cloudpickle and unpickle + pickled = cloudpickle.dumps(model) + unpickled = cloudpickle.loads(pickled) + + # Validator should still work (may be lazily recreated) + result = unpickled.flow.compute(x=42) + assert result == 42 + + def test_flow_context_pickle_standard(self): + """FlowContext works with standard pickle.""" + ctx = FlowContext(x=1, y=2, z="test") + + pickled = pickle.dumps(ctx) + unpickled = pickle.loads(pickled) + + assert unpickled.x == 1 + assert unpickled.y == 2 + assert unpickled.z == "test" + + +class TestIntegrationWithExistingContextTypes: + """Tests for integration with existing ContextBase subclasses.""" + + def test_explicit_context_still_works(self): + """Explicit context parameter mode still works.""" + + @Flow.model + def load_data(context: DateRangeContext, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={"start": context.start_date, "end": context.end_date, "source": source}) + + model = load_data(source="api") + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = model(ctx) + + assert result.value["start"] == date(2024, 1, 1) + assert result.value["source"] == "api" + + def test_flow_context_coerces_to_date_range(self): + """FlowContext can be used with models expecting DateRangeContext.""" + + @Flow.model + def load_data(context: DateRangeContext) -> GenericResult[dict]: + return GenericResult(value={"start": context.start_date, "end": context.end_date}) + + model = load_data() + # Use FlowContext - should coerce to DateRangeContext + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = model(ctx) + + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + + def test_flow_api_with_explicit_context(self): + """FlowAPI.compute works with explicit context mode.""" + + @Flow.model + def load_data(context: DateRangeContext, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={"start": context.start_date, "end": context.end_date}) + + model = load_data(source="api") + result = model.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + + assert result["start"] == date(2024, 1, 1) + assert 
result["end"] == date(2024, 1, 31) + + +class TestLazy: + """Tests for Lazy (deferred execution with context overrides).""" + + def test_lazy_basic(self): + """Lazy wraps a model for deferred execution.""" + from ccflow import Lazy + + @Flow.model(context_args=["value"]) + def compute(value: int, multiplier: int = 2) -> GenericResult[int]: + return GenericResult(value=value * multiplier) + + model = compute(multiplier=3) + lazy = Lazy(model) + + assert lazy.model is model + + def test_lazy_call_with_static_override(self): + """Lazy.__call__ with static override values.""" + from ccflow import Lazy + + @Flow.model(context_args=["x", "y"]) + def add(x: int, y: int) -> GenericResult[int]: + return GenericResult(value=x + y) + + model = add() + lazy_fn = Lazy(model)(y=100) # Override y to 100 + + ctx = FlowContext(x=5, y=10) # Original y=10 + result = lazy_fn(ctx) + assert result.value == 105 # x=5 + y=100 (overridden) + + def test_lazy_call_with_callable_override(self): + """Lazy.__call__ with callable override (computed at runtime).""" + from ccflow import Lazy + + @Flow.model(context_args=["value"]) + def double(value: int) -> GenericResult[int]: + return GenericResult(value=value * 2) + + model = double() + # Override value to be original value + 10 + lazy_fn = Lazy(model)(value=lambda ctx: ctx.value + 10) + + ctx = FlowContext(value=5) + result = lazy_fn(ctx) + assert result.value == 30 # (5 + 10) * 2 = 30 + + def test_lazy_with_date_transforms(self): + """Lazy works with date transforms.""" + from ccflow import Lazy + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + + # Use Lazy to create a transform that shifts dates + lazy_fn = Lazy(model)( + start_date=lambda ctx: ctx.start_date - timedelta(days=7), + end_date=lambda ctx: ctx.end_date + ) + + ctx = FlowContext(start_date=date(2024, 1, 15), end_date=date(2024, 1, 31)) + result = lazy_fn(ctx) + + assert result.value["start"] == date(2024, 1, 8) # 7 days before + assert result.value["end"] == date(2024, 1, 31) + + def test_lazy_multiple_overrides(self): + """Lazy supports multiple overrides at once.""" + from ccflow import Lazy + + @Flow.model(context_args=["a", "b", "c"]) + def compute(a: int, b: int, c: int) -> GenericResult[int]: + return GenericResult(value=a + b + c) + + model = compute() + lazy_fn = Lazy(model)( + a=10, # Static + b=lambda ctx: ctx.b * 2, # Transform + # c not overridden, uses context value + ) + + ctx = FlowContext(a=1, b=5, c=100) + result = lazy_fn(ctx) + assert result.value == 10 + 10 + 100 # a=10, b=5*2=10, c=100 diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py new file mode 100644 index 0000000..b283a2b --- /dev/null +++ b/ccflow/tests/test_flow_model.py @@ -0,0 +1,1561 @@ +"""Tests for Flow.model decorator.""" + +from datetime import date, timedelta +from typing import Annotated +from unittest import TestCase + +from pydantic import ValidationError +from ray.cloudpickle import dumps as rcpdumps, loads as rcploads + +from ccflow import ( + CallableModel, + ContextBase, + DateRangeContext, + Dep, + DepOf, + Flow, + GenericResult, + ModelRegistry, + ResultBase, +) +from ccflow.callable import resolve + + +class SimpleContext(ContextBase): + """Simple context for testing.""" + + value: int + + +class ExtendedContext(ContextBase): + """Extended context with multiple fields.""" + + x: int + y: str = "default" + + 
+class MyResult(ResultBase): + """Custom result type for testing.""" + + data: str + + +# ============================================================================= +# Basic Flow.model Tests +# ============================================================================= + + +class TestFlowModelBasic(TestCase): + """Basic Flow.model functionality tests.""" + + def test_simple_model_explicit_context(self): + """Test Flow.model with explicit context parameter.""" + + @Flow.model + def simple_loader(context: SimpleContext, multiplier: int) -> GenericResult[int]: + return GenericResult(value=context.value * multiplier) + + # Create model instance + loader = simple_loader(multiplier=3) + + # Should be a CallableModel + self.assertIsInstance(loader, CallableModel) + + # Execute + ctx = SimpleContext(value=10) + result = loader(ctx) + + self.assertIsInstance(result, GenericResult) + self.assertEqual(result.value, 30) + + def test_model_with_default_params(self): + """Test Flow.model with default parameter values.""" + + @Flow.model + def loader_with_defaults(context: SimpleContext, multiplier: int = 2, prefix: str = "result") -> GenericResult[str]: + return GenericResult(value=f"{prefix}:{context.value * multiplier}") + + # Create with defaults + loader = loader_with_defaults() + result = loader(SimpleContext(value=5)) + self.assertEqual(result.value, "result:10") + + # Create with custom values + loader2 = loader_with_defaults(multiplier=3, prefix="custom") + result2 = loader2(SimpleContext(value=5)) + self.assertEqual(result2.value, "custom:15") + + def test_model_context_type_property(self): + """Test that generated model has correct context_type.""" + + @Flow.model + def typed_model(context: ExtendedContext, factor: int) -> GenericResult[int]: + return GenericResult(value=context.x * factor) + + model = typed_model(factor=2) + self.assertEqual(model.context_type, ExtendedContext) + + def test_model_result_type_property(self): + """Test that generated model has correct result_type.""" + + @Flow.model + def custom_result_model(context: SimpleContext) -> MyResult: + return MyResult(data=f"value={context.value}") + + model = custom_result_model() + self.assertEqual(model.result_type, MyResult) + + def test_model_with_no_extra_params(self): + """Test Flow.model with only context parameter.""" + + @Flow.model + def identity_model(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + model = identity_model() + result = model(SimpleContext(value=42)) + self.assertEqual(result.value, 42) + + def test_model_with_flow_options(self): + """Test Flow.model with Flow.call options.""" + + @Flow.model(cacheable=True, validate_result=True) + def cached_model(context: SimpleContext, value: int) -> GenericResult[int]: + return GenericResult(value=value + context.value) + + model = cached_model(value=10) + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 15) + + def test_model_with_underscore_context(self): + """Test Flow.model with '_' as context parameter (unused context convention).""" + + @Flow.model + def loader(context: SimpleContext, base: int) -> GenericResult[int]: + return GenericResult(value=context.value + base) + + @Flow.model + def consumer(_: SimpleContext, data: DepOf[..., GenericResult[int]]) -> GenericResult[int]: + # Context not used directly, just passed to dependency + return GenericResult(value=data.value * 2) + + load = loader(base=100) + consume = consumer(data=load) + + result = consume(SimpleContext(value=10)) + # loader: 
10 + 100 = 110, consumer: 110 * 2 = 220 + self.assertEqual(result.value, 220) + + # Verify context_type is still correct + self.assertEqual(consume.context_type, SimpleContext) + + +# ============================================================================= +# context_args Mode Tests +# ============================================================================= + + +class TestFlowModelContextArgs(TestCase): + """Tests for Flow.model with context_args (unpacked context).""" + + def test_context_args_basic(self): + """Test basic context_args usage.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def date_range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date} to {end_date}") + + loader = date_range_loader(source="db") + + # Should use DateRangeContext + self.assertEqual(loader.context_type, DateRangeContext) + + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = loader(ctx) + self.assertEqual(result.value, "db:2024-01-01 to 2024-01-31") + + def test_context_args_custom_context(self): + """Test context_args with custom context type.""" + + @Flow.model(context_args=["x", "y"]) + def unpacked_model(x: int, y: str, multiplier: int = 1) -> GenericResult[str]: + return GenericResult(value=f"{y}:{x * multiplier}") + + model = unpacked_model(multiplier=2) + + # Create context with generated type + ctx_type = model.context_type + ctx = ctx_type(x=5, y="test") + + result = model(ctx) + self.assertEqual(result.value, "test:10") + + def test_context_args_with_defaults(self): + """Test context_args where context fields have defaults.""" + + @Flow.model(context_args=["value"]) + def model_with_ctx_default(value: int = 42, extra: str = "foo") -> GenericResult[str]: + return GenericResult(value=f"{extra}:{value}") + + model = model_with_ctx_default() + + # Create context - the generated context should allow default + ctx_type = model.context_type + ctx = ctx_type(value=100) + + result = model(ctx) + self.assertEqual(result.value, "foo:100") + + +# ============================================================================= +# Dependency Tests +# ============================================================================= + + +class TestFlowModelDependencies(TestCase): + """Tests for Flow.model with dependencies.""" + + def test_simple_dependency_with_depof(self): + """Test simple dependency using DepOf shorthand.""" + + @Flow.model + def loader(context: SimpleContext, value: int) -> GenericResult[int]: + return GenericResult(value=value + context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + multiplier: int = 1, + ) -> GenericResult[int]: + return GenericResult(value=data.value * multiplier) + + # Create pipeline + load = loader(value=10) + consume = consumer(data=load, multiplier=2) + + ctx = SimpleContext(value=5) + result = consume(ctx) + + # loader returns 10 + 5 = 15, consumer multiplies by 2 = 30 + self.assertEqual(result.value, 30) + + def test_dependency_with_explicit_dep(self): + """Test dependency using explicit Dep() annotation.""" + + @Flow.model + def loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 2) + + @Flow.model + def consumer( + context: SimpleContext, + data: Annotated[GenericResult[int], Dep()], + ) -> GenericResult[int]: + return GenericResult(value=data.value + 100) + + load = loader() + consume = consumer(data=load) + + result = 
consume(SimpleContext(value=10)) + # loader: 10 * 2 = 20, consumer: 20 + 100 = 120 + self.assertEqual(result.value, 120) + + def test_dependency_with_direct_value(self): + """Test that Dep fields can accept direct values (not CallableModel).""" + + @Flow.model + def consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value + context.value) + + # Pass direct value instead of CallableModel + consume = consumer(data=GenericResult(value=100)) + + result = consume(SimpleContext(value=5)) + self.assertEqual(result.value, 105) + + def test_deps_method_generation(self): + """Test that __deps__ method is correctly generated.""" + + @Flow.model + def loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value) + + load = loader() + consume = consumer(data=load) + + ctx = SimpleContext(value=10) + deps = consume.__deps__(ctx) + + # Should have one dependency + self.assertEqual(len(deps), 1) + self.assertEqual(deps[0][0], load) + self.assertEqual(deps[0][1], [ctx]) + + def test_no_deps_when_direct_value(self): + """Test that __deps__ returns empty when direct values used.""" + + @Flow.model + def consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value) + + consume = consumer(data=GenericResult(value=100)) + + deps = consume.__deps__(SimpleContext(value=10)) + self.assertEqual(len(deps), 0) + + +# ============================================================================= +# Transform Tests +# ============================================================================= + + +class TestFlowModelTransforms(TestCase): + """Tests for Flow.model with context transforms.""" + + def test_transform_in_dep(self): + """Test dependency with context transform.""" + + @Flow.model + def loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: Annotated[ + GenericResult[int], + Dep(transform=lambda ctx: ctx.model_copy(update={"value": ctx.value + 10})), + ], + ) -> GenericResult[int]: + return GenericResult(value=data.value * 2) + + load = loader() + consume = consumer(data=load) + + ctx = SimpleContext(value=5) + result = consume(ctx) + + # Transform adds 10 to context.value: 5 + 10 = 15 + # Loader returns that: 15 + # Consumer multiplies by 2: 30 + self.assertEqual(result.value, 30) + + def test_transform_in_deps_method(self): + """Test that transform is applied in __deps__ method.""" + + def transform_fn(ctx): + return ctx.model_copy(update={"value": ctx.value * 3}) + + @Flow.model + def loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: Annotated[GenericResult[int], Dep(transform=transform_fn)], + ) -> GenericResult[int]: + return GenericResult(value=data.value) + + load = loader() + consume = consumer(data=load) + + ctx = SimpleContext(value=7) + deps = consume.__deps__(ctx) + + # Transform should be applied + self.assertEqual(len(deps), 1) + transformed_ctx = deps[0][1][0] + self.assertEqual(transformed_ctx.value, 21) # 7 * 3 + + def test_date_range_transform(self): + """Test transform pattern with date ranges using 
context_args.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date}") + + def lookback_transform(ctx: DateRangeContext) -> DateRangeContext: + return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) + + @Flow.model(context_args=["start_date", "end_date"]) + def range_processor( + start_date: date, + end_date: date, + data: Annotated[GenericResult[str], Dep(transform=lookback_transform)], + ) -> GenericResult[str]: + return GenericResult(value=f"processed:{data.value}") + + loader = range_loader(source="db") + processor = range_processor(data=loader) + + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + result = processor(ctx) + + # Transform should shift start_date back by 1 day + self.assertEqual(result.value, "processed:db:2024-01-09") + + +# ============================================================================= +# Pipeline Tests +# ============================================================================= + + +class TestFlowModelPipeline(TestCase): + """Tests for multi-stage pipelines with Flow.model.""" + + def test_three_stage_pipeline(self): + """Test a three-stage computation pipeline.""" + + @Flow.model + def stage1(context: SimpleContext, base: int) -> GenericResult[int]: + return GenericResult(value=context.value + base) + + @Flow.model + def stage2( + context: SimpleContext, + input_data: DepOf[..., GenericResult[int]], + multiplier: int, + ) -> GenericResult[int]: + return GenericResult(value=input_data.value * multiplier) + + @Flow.model + def stage3( + context: SimpleContext, + input_data: DepOf[..., GenericResult[int]], + offset: int = 0, + ) -> GenericResult[int]: + return GenericResult(value=input_data.value + offset) + + # Build pipeline + s1 = stage1(base=100) + s2 = stage2(input_data=s1, multiplier=2) + s3 = stage3(input_data=s2, offset=50) + + ctx = SimpleContext(value=10) + result = s3(ctx) + + # s1: 10 + 100 = 110 + # s2: 110 * 2 = 220 + # s3: 220 + 50 = 270 + self.assertEqual(result.value, 270) + + def test_diamond_dependency_pattern(self): + """Test diamond-shaped dependency pattern.""" + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def branch_a( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value * 2) + + @Flow.model + def branch_b( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value + 100) + + @Flow.model + def merger( + context: SimpleContext, + a: DepOf[..., GenericResult[int]], + b: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=a.value + b.value) + + src = source() + a = branch_a(data=src) + b = branch_b(data=src) + merge = merger(a=a, b=b) + + ctx = SimpleContext(value=10) + result = merge(ctx) + + # source: 10 + # branch_a: 10 * 2 = 20 + # branch_b: 10 + 100 = 110 + # merger: 20 + 110 = 130 + self.assertEqual(result.value, 130) + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +class TestFlowModelIntegration(TestCase): + """Integration tests for Flow.model with ccflow infrastructure.""" + + def 
test_registry_integration(self):
+ """Test that Flow.model models work with ModelRegistry."""
+
+ @Flow.model
+ def registrable_model(context: SimpleContext, value: int) -> GenericResult[int]:
+ return GenericResult(value=context.value + value)
+
+ model = registrable_model(value=100)
+
+ registry = ModelRegistry.root().clear()
+ registry.add("test_model", model)
+
+ retrieved = registry["test_model"]
+ self.assertEqual(retrieved, model)
+
+ result = retrieved(SimpleContext(value=10))
+ self.assertEqual(result.value, 110)
+
+ def test_serialization_dump(self):
+ """Test that generated models can be serialized."""
+
+ @Flow.model
+ def serializable_model(context: SimpleContext, value: int = 42) -> GenericResult[int]:
+ return GenericResult(value=value)
+
+ model = serializable_model(value=100)
+ dumped = model.model_dump(mode="python")
+
+ self.assertIn("value", dumped)
+ self.assertEqual(dumped["value"], 100)
+ self.assertIn("type_", dumped)
+
+ def test_pickle_roundtrip(self):
+ """Test cloudpickle serialization of generated models."""
+
+ @Flow.model
+ def pickleable_model(context: SimpleContext, factor: int) -> GenericResult[int]:
+ return GenericResult(value=context.value * factor)
+
+ model = pickleable_model(factor=3)
+
+ # Cloudpickle roundtrip (standard pickle won't work for local classes)
+ pickled = rcpdumps(model, protocol=5)
+ restored = rcploads(pickled)
+
+ result = restored(SimpleContext(value=10))
+ self.assertEqual(result.value, 30)
+
+ def test_mix_with_manual_callable_model(self):
+ """Test mixing Flow.model with manually defined CallableModel."""
+
+ class ManualModel(CallableModel):
+ offset: int
+
+ @Flow.call
+ def __call__(self, context: SimpleContext) -> GenericResult[int]:
+ return GenericResult(value=context.value + self.offset)
+
+ @Flow.model
+ def generated_consumer(
+ context: SimpleContext,
+ data: DepOf[..., GenericResult[int]],
+ multiplier: int,
+ ) -> GenericResult[int]:
+ return GenericResult(value=data.value * multiplier)
+
+ manual = ManualModel(offset=50)
+ generated = generated_consumer(data=manual, multiplier=2)
+
+ result = generated(SimpleContext(value=10))
+ # manual: 10 + 50 = 60
+ # generated: 60 * 2 = 120
+ self.assertEqual(result.value, 120)
+
+
+# =============================================================================
+# Error Case Tests
+# =============================================================================
+
+
+class TestFlowModelErrors(TestCase):
+ """Error cases and construction-mode edge cases for Flow.model."""
+
+ def test_missing_return_type(self):
+ """Test error when return type annotation is missing."""
+ with self.assertRaises(TypeError) as cm:
+
+ @Flow.model
+ def no_return(context: SimpleContext):
+ return GenericResult(value=1)
+
+ self.assertIn("return type annotation", str(cm.exception))
+
+ def test_non_result_return_type(self):
+ """Test error when return type is not ResultBase subclass."""
+ with self.assertRaises(TypeError) as cm:
+
+ @Flow.model
+ def bad_return(context: SimpleContext) -> int:
+ return 42
+
+ self.assertIn("ResultBase", str(cm.exception))
+
+ def test_dynamic_deferred_mode(self):
+ """Test dynamic deferred mode: arguments provided at construction are bound; the rest come from the context."""
+ from ccflow import FlowContext
+
+ @Flow.model
+ def dynamic_model(value: int, multiplier: int) -> GenericResult[int]:
+ return GenericResult(value=value * multiplier)
+
+ # Provide 'multiplier' at construction -> it's bound
+ # Don't provide 'value' -> comes from context
+ model = dynamic_model(multiplier=3)
+
+ # Check bound vs unbound
+ 
self.assertEqual(model.flow.bound_inputs, {"multiplier": 3})
+ self.assertEqual(model.flow.unbound_inputs, {"value": int})
+
+ # Call with context providing 'value'
+ ctx = FlowContext(value=10)
+ result = model(ctx)
+ self.assertEqual(result.value, 30) # 10 * 3
+
+ def test_all_defaults_is_valid(self):
+ """Test that a function whose parameters all have defaults is valid (values come from defaults or the context)."""
+ from ccflow import FlowContext
+
+ @Flow.model
+ def all_defaults(value: int = 1, other: str = "x") -> GenericResult[str]:
+ return GenericResult(value=f"{value}-{other}")
+
+ # No args provided -> everything comes from defaults or context
+ model = all_defaults()
+
+ # All params are unbound (not provided at construction)
+ self.assertEqual(model.flow.unbound_inputs, {"value": int, "other": str})
+
+ # Call with context - context values override defaults
+ ctx = FlowContext(value=5, other="y")
+ result = model(ctx)
+ self.assertEqual(result.value, "5-y")
+
+ def test_invalid_context_arg(self):
+ """Test error when context_args refers to non-existent parameter."""
+ with self.assertRaises(ValueError) as cm:
+
+ @Flow.model(context_args=["nonexistent"])
+ def bad_context_args(x: int) -> GenericResult[int]:
+ return GenericResult(value=x)
+
+ self.assertIn("nonexistent", str(cm.exception))
+
+ def test_context_arg_without_annotation(self):
+ """Test error when context_arg parameter lacks type annotation."""
+ with self.assertRaises(ValueError) as cm:
+
+ @Flow.model(context_args=["x"])
+ def untyped_context_arg(x) -> GenericResult[int]:
+ return GenericResult(value=x)
+
+ self.assertIn("type annotation", str(cm.exception))
+
+
+# =============================================================================
+# Dep and DepOf Tests
+# =============================================================================
+
+
+class TestDepAndDepOf(TestCase):
+ """Tests for Dep and DepOf classes."""
+
+ def test_depof_creates_annotated(self):
+ """Test that DepOf[..., T] creates Annotated[Union[T, CallableModel], Dep()]."""
+ from typing import Union as TypingUnion, get_args, get_origin
+
+ annotation = DepOf[..., GenericResult[int]]
+ self.assertEqual(get_origin(annotation), Annotated)
+
+ args = get_args(annotation)
+ # First arg is Union[ResultType, CallableModel]
+ self.assertEqual(get_origin(args[0]), TypingUnion)
+ union_args = get_args(args[0])
+ self.assertIn(GenericResult[int], union_args)
+ self.assertIn(CallableModel, union_args)
+ # Second arg is Dep()
+ self.assertIsInstance(args[1], Dep)
+ self.assertIsNone(args[1].context_type) # ... 
means inherit from parent + + def test_depof_with_generic_type(self): + """Test DepOf with nested generic types.""" + from typing import List as TypingList, Union as TypingUnion, get_args, get_origin + + annotation = DepOf[..., GenericResult[TypingList[str]]] + self.assertEqual(get_origin(annotation), Annotated) + + args = get_args(annotation) + # First arg is Union[ResultType, CallableModel] + self.assertEqual(get_origin(args[0]), TypingUnion) + union_args = get_args(args[0]) + self.assertIn(GenericResult[TypingList[str]], union_args) + self.assertIn(CallableModel, union_args) + + def test_depof_with_context_type(self): + """Test DepOf[ContextType, ResultType] syntax.""" + from typing import Union as TypingUnion, get_args, get_origin + + annotation = DepOf[SimpleContext, GenericResult[int]] + self.assertEqual(get_origin(annotation), Annotated) + + args = get_args(annotation) + # First arg is Union[ResultType, CallableModel] + self.assertEqual(get_origin(args[0]), TypingUnion) + union_args = get_args(args[0]) + self.assertIn(GenericResult[int], union_args) + self.assertIn(CallableModel, union_args) + # Second arg is Dep with context_type + self.assertIsInstance(args[1], Dep) + self.assertEqual(args[1].context_type, SimpleContext) + + def test_extract_dep_with_annotated(self): + """Test extract_dep with Annotated type.""" + from ccflow.dep import extract_dep + + dep = Dep(context_type=SimpleContext) + annotation = Annotated[GenericResult[int], dep] + + base_type, extracted_dep = extract_dep(annotation) + self.assertEqual(base_type, GenericResult[int]) + self.assertEqual(extracted_dep, dep) + + def test_extract_dep_with_depof(self): + """Test extract_dep with DepOf type.""" + from typing import Union as TypingUnion, get_args, get_origin + + from ccflow.dep import extract_dep + + annotation = DepOf[..., GenericResult[str]] + base_type, extracted_dep = extract_dep(annotation) + + # base_type is Union[ResultType, CallableModel] + self.assertEqual(get_origin(base_type), TypingUnion) + union_args = get_args(base_type) + self.assertIn(GenericResult[str], union_args) + self.assertIn(CallableModel, union_args) + self.assertIsInstance(extracted_dep, Dep) + + def test_extract_dep_without_dep(self): + """Test extract_dep with regular type (no Dep).""" + from ccflow.dep import extract_dep + + base_type, extracted_dep = extract_dep(int) + self.assertEqual(base_type, int) + self.assertIsNone(extracted_dep) + + def test_extract_dep_annotated_without_dep(self): + """Test extract_dep with Annotated but no Dep marker.""" + from ccflow.dep import extract_dep + + annotation = Annotated[int, "some metadata"] + base_type, extracted_dep = extract_dep(annotation) + + # When no Dep marker is found, returns original annotation unchanged + self.assertEqual(base_type, annotation) + self.assertIsNone(extracted_dep) + + def test_is_compatible_type_simple(self): + """Test _is_compatible_type with simple types.""" + from ccflow.dep import _is_compatible_type + + self.assertTrue(_is_compatible_type(int, int)) + self.assertFalse(_is_compatible_type(int, str)) + self.assertTrue(_is_compatible_type(bool, int)) # bool is subclass of int + + def test_is_compatible_type_generic(self): + """Test _is_compatible_type with generic types.""" + from ccflow.dep import _is_compatible_type + + self.assertTrue(_is_compatible_type(GenericResult[int], GenericResult[int])) + self.assertFalse(_is_compatible_type(GenericResult[int], GenericResult[str])) + self.assertTrue(_is_compatible_type(GenericResult, GenericResult)) + + def 
test_is_compatible_type_none(self): + """Test _is_compatible_type with None.""" + from ccflow.dep import _is_compatible_type + + self.assertTrue(_is_compatible_type(None, None)) + self.assertFalse(_is_compatible_type(None, int)) + self.assertFalse(_is_compatible_type(int, None)) + + def test_is_compatible_type_subclass(self): + """Test _is_compatible_type with subclasses.""" + from ccflow.dep import _is_compatible_type + + self.assertTrue(_is_compatible_type(MyResult, ResultBase)) + self.assertFalse(_is_compatible_type(ResultBase, MyResult)) + + def test_dep_validate_dependency_success(self): + """Test Dep.validate_dependency with valid dependency.""" + + @Flow.model + def valid_dep(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + dep = Dep() + model = valid_dep() + + # Should not raise + dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") + + def test_dep_validate_dependency_context_mismatch(self): + """Test Dep.validate_dependency with context type mismatch.""" + + class OtherContext(ContextBase): + other: str + + @Flow.model + def other_dep(context: OtherContext) -> GenericResult[int]: + return GenericResult(value=42) + + dep = Dep(context_type=SimpleContext) + model = other_dep() + + with self.assertRaises(TypeError) as cm: + dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") + + self.assertIn("context_type", str(cm.exception)) + + def test_dep_validate_dependency_result_mismatch(self): + """Test Dep.validate_dependency with result type mismatch.""" + + @Flow.model + def wrong_result(context: SimpleContext) -> MyResult: + return MyResult(data="test") + + dep = Dep() + model = wrong_result() + + with self.assertRaises(TypeError) as cm: + dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") + + self.assertIn("result_type", str(cm.exception)) + + def test_dep_validate_dependency_non_callable(self): + """Test Dep.validate_dependency with non-CallableModel value.""" + dep = Dep() + # Should not raise for non-CallableModel values + dep.validate_dependency(GenericResult(value=42), GenericResult[int], SimpleContext, "data") + dep.validate_dependency("string", GenericResult[int], SimpleContext, "data") + dep.validate_dependency(123, GenericResult[int], SimpleContext, "data") + + def test_dep_hash(self): + """Test Dep is hashable for use in sets/dicts.""" + dep1 = Dep() + dep2 = Dep(context_type=SimpleContext) + + # Should be hashable + dep_set = {dep1, dep2} + self.assertEqual(len(dep_set), 2) + + dep_dict = {dep1: "value1", dep2: "value2"} + self.assertEqual(dep_dict[dep1], "value1") + self.assertEqual(dep_dict[dep2], "value2") + + def test_dep_apply_with_transform(self): + """Test Dep.apply with transform function.""" + + def transform(ctx): + return ctx.model_copy(update={"value": ctx.value * 2}) + + dep = Dep(transform=transform) + + ctx = SimpleContext(value=10) + result = dep.apply(ctx) + + self.assertEqual(result.value, 20) + + def test_dep_apply_without_transform(self): + """Test Dep.apply without transform (identity).""" + dep = Dep() + + ctx = SimpleContext(value=10) + result = dep.apply(ctx) + + self.assertIs(result, ctx) + + def test_dep_repr(self): + """Test Dep string representation.""" + dep1 = Dep() + self.assertEqual(repr(dep1), "Dep()") + + dep2 = Dep(context_type=SimpleContext) + self.assertIn("SimpleContext", repr(dep2)) + + dep3 = Dep(transform=lambda x: x) + self.assertIn("transform=", repr(dep3)) + + def test_dep_equality(self): + """Test Dep equality 
comparison.""" + dep1 = Dep() + dep2 = Dep() + dep3 = Dep(context_type=SimpleContext) + + # Note: Two Dep() instances with no arguments are equal + self.assertEqual(dep1, dep2) + self.assertNotEqual(dep1, dep3) + + +# ============================================================================= +# Validation Tests +# ============================================================================= + + +class TestFlowModelValidation(TestCase): + """Tests for dependency validation in Flow.model.""" + + def test_context_type_validation(self): + """Test that context_type mismatch is detected.""" + + class OtherContext(ContextBase): + other: str + + @Flow.model + def simple_loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def other_loader(context: OtherContext) -> GenericResult[int]: + return GenericResult(value=42) + + @Flow.model + def consumer( + context: SimpleContext, + data: Annotated[GenericResult[int], Dep(context_type=SimpleContext)], + ) -> GenericResult[int]: + return GenericResult(value=data.value) + + # Should work with matching context + load1 = simple_loader() + consume1 = consumer(data=load1) + self.assertIsNotNone(consume1) + + # Should fail with mismatched context + load2 = other_loader() + with self.assertRaises((TypeError, ValidationError)): + consumer(data=load2) + + +# ============================================================================= +# Hydra Integration Tests +# ============================================================================= + + +# Define Flow.model functions at module level for Hydra to find them +@Flow.model +def hydra_basic_model(context: SimpleContext, value: int, name: str = "default") -> GenericResult[str]: + """Module-level model for Hydra testing.""" + return GenericResult(value=f"{name}:{context.value + value}") + + +# --- Additional module-level fixtures for Hydra YAML tests --- + + +@Flow.model +def basic_loader(context: SimpleContext, source: str, multiplier: int = 1) -> GenericResult[int]: + """Basic loader that multiplies context value by multiplier.""" + return GenericResult(value=context.value * multiplier) + + +@Flow.model +def string_processor(context: SimpleContext, prefix: str, suffix: str = "") -> GenericResult[str]: + """Process context value into a string with prefix and suffix.""" + return GenericResult(value=f"{prefix}{context.value}{suffix}") + + +@Flow.model +def data_source(context: SimpleContext, base_value: int) -> GenericResult[int]: + """Source that provides base data.""" + return GenericResult(value=context.value + base_value) + + +@Flow.model +def data_transformer( + context: SimpleContext, + source: DepOf[..., GenericResult[int]], + factor: int = 2, +) -> GenericResult[int]: + """Transform data by multiplying with factor.""" + return GenericResult(value=source.value * factor) + + +@Flow.model +def data_aggregator( + context: SimpleContext, + input_a: DepOf[..., GenericResult[int]], + input_b: DepOf[..., GenericResult[int]], + operation: str = "add", +) -> GenericResult[int]: + """Aggregate two inputs.""" + if operation == "add": + return GenericResult(value=input_a.value + input_b.value) + elif operation == "multiply": + return GenericResult(value=input_a.value * input_b.value) + else: + return GenericResult(value=input_a.value - input_b.value) + + +@Flow.model +def pipeline_stage1(context: SimpleContext, initial: int) -> GenericResult[int]: + """First stage of pipeline.""" + return GenericResult(value=context.value + initial) + + +@Flow.model +def 
pipeline_stage2( + context: SimpleContext, + stage1_output: DepOf[..., GenericResult[int]], + multiplier: int = 2, +) -> GenericResult[int]: + """Second stage of pipeline.""" + return GenericResult(value=stage1_output.value * multiplier) + + +@Flow.model +def pipeline_stage3( + context: SimpleContext, + stage2_output: DepOf[..., GenericResult[int]], + offset: int = 0, +) -> GenericResult[int]: + """Third stage of pipeline.""" + return GenericResult(value=stage2_output.value + offset) + + +def lookback_one_day(ctx: DateRangeContext) -> DateRangeContext: + """Transform that extends start_date back by one day.""" + return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) + + +@Flow.model +def date_range_loader( + context: DateRangeContext, + source: str, + include_weekends: bool = True, +) -> GenericResult[str]: + """Load data for a date range.""" + return GenericResult(value=f"{source}:{context.start_date} to {context.end_date}") + + +@Flow.model +def date_range_processor( + context: DateRangeContext, + raw_data: Annotated[GenericResult[str], Dep(transform=lookback_one_day)], + normalize: bool = False, +) -> GenericResult[str]: + """Process date range data with lookback.""" + prefix = "normalized:" if normalize else "raw:" + return GenericResult(value=f"{prefix}{raw_data.value}") + + +@Flow.model +def hydra_default_model(context: SimpleContext, value: int = 42) -> GenericResult[int]: + """Module-level model with defaults for Hydra testing.""" + return GenericResult(value=context.value + value) + + +@Flow.model +def hydra_source_model(context: SimpleContext, base: int) -> GenericResult[int]: + """Source model for dependency testing.""" + return GenericResult(value=context.value * base) + + +@Flow.model +def hydra_consumer_model( + context: SimpleContext, + source: DepOf[..., GenericResult[int]], + factor: int = 1, +) -> GenericResult[int]: + """Consumer model for dependency testing.""" + return GenericResult(value=source.value * factor) + + +# --- context_args fixtures for Hydra testing --- + + +@Flow.model(context_args=["start_date", "end_date"]) +def context_args_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: + """Loader using context_args with DateRangeContext.""" + return GenericResult(value=f"{source}:{start_date} to {end_date}") + + +@Flow.model(context_args=["start_date", "end_date"]) +def context_args_processor( + start_date: date, + end_date: date, + data: DepOf[..., GenericResult[str]], + prefix: str = "processed", +) -> GenericResult[str]: + """Processor using context_args with dependency.""" + return GenericResult(value=f"{prefix}:{data.value}") + + +class TestFlowModelHydra(TestCase): + """Tests for Flow.model with Hydra configuration.""" + + def test_hydra_instantiate_basic(self): + """Test that Flow.model factory can be instantiated via Hydra.""" + from hydra.utils import instantiate + from omegaconf import OmegaConf + + # Create config that references the factory function by module path + cfg = OmegaConf.create( + { + "_target_": "ccflow.tests.test_flow_model.hydra_basic_model", + "value": 100, + "name": "test", + } + ) + + # Instantiate via Hydra + model = instantiate(cfg) + + self.assertIsInstance(model, CallableModel) + result = model(SimpleContext(value=10)) + self.assertEqual(result.value, "test:110") + + def test_hydra_instantiate_with_defaults(self): + """Test Hydra instantiation using default parameter values.""" + from hydra.utils import instantiate + from omegaconf import OmegaConf + + cfg = OmegaConf.create( + { 
+ "_target_": "ccflow.tests.test_flow_model.hydra_default_model", + # Not specifying value, should use default + } + ) + + model = instantiate(cfg) + result = model(SimpleContext(value=8)) + self.assertEqual(result.value, 50) + + def test_hydra_instantiate_with_dependency(self): + """Test Hydra instantiation with dependencies.""" + from hydra.utils import instantiate + from omegaconf import OmegaConf + + # Create nested config + cfg = OmegaConf.create( + { + "_target_": "ccflow.tests.test_flow_model.hydra_consumer_model", + "source": { + "_target_": "ccflow.tests.test_flow_model.hydra_source_model", + "base": 10, + }, + "factor": 2, + } + ) + + model = instantiate(cfg) + + result = model(SimpleContext(value=5)) + # source: 5 * 10 = 50, consumer: 50 * 2 = 100 + self.assertEqual(result.value, 100) + + +# ============================================================================= +# Class-based CallableModel with Auto-Resolution Tests +# ============================================================================= + + +class TestClassBasedDepResolution(TestCase): + """Tests for auto-resolution of DepOf fields in class-based CallableModels. + + Key pattern: Fields use DepOf annotation, __call__ only takes context, + and resolved values are accessed via self.field_name during __call__. + """ + + def test_class_based_auto_resolve_basic(self): + """Test that DepOf fields are auto-resolved and accessible via resolve().""" + + @Flow.model + def data_source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + class Consumer(CallableModel): + # DepOf expands to Annotated[Union[ResultType, CallableModel], Dep()] + source: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + # Access resolved value via resolve() + return GenericResult(value=resolve(self.source).value + 1) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.source, [context])] + + src = data_source() + consumer = Consumer(source=src) + + result = consumer(SimpleContext(value=5)) + # source: 5 * 10 = 50, consumer: 50 + 1 = 51 + self.assertEqual(result.value, 51) + + def test_class_based_with_custom_transform(self): + """Test that custom __deps__ transform is used.""" + + @Flow.model + def data_source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + class Consumer(CallableModel): + source: DepOf[..., GenericResult[int]] + offset: int = 100 + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=resolve(self.source).value + self.offset) + + @Flow.deps + def __deps__(self, context: SimpleContext): + # Apply custom transform + transformed_ctx = SimpleContext(value=context.value + 5) + return [(self.source, [transformed_ctx])] + + src = data_source() + consumer = Consumer(source=src, offset=1) + + result = consumer(SimpleContext(value=5)) + # transformed context: 5 + 5 = 10 + # source: 10 * 10 = 100 + # consumer: 100 + 1 = 101 + self.assertEqual(result.value, 101) + + def test_class_based_with_annotated_transform(self): + """Test that Dep transform is used when field not in __deps__.""" + + @Flow.model + def data_source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + def double_value(ctx: SimpleContext) -> SimpleContext: + return SimpleContext(value=ctx.value * 2) + + class Consumer(CallableModel): + source: Annotated[DepOf[..., GenericResult[int]], 
Dep(transform=double_value)] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=resolve(self.source).value + 1) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [] # Empty - uses Dep annotation transform from field + + src = data_source() + consumer = Consumer(source=src) + + result = consumer(SimpleContext(value=5)) + # transform: 5 * 2 = 10 + # source: 10 * 10 = 100 + # consumer: 100 + 1 = 101 + self.assertEqual(result.value, 101) + + def test_class_based_multiple_deps(self): + """Test auto-resolution with multiple dependencies.""" + + @Flow.model + def source_a(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def source_b(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 2) + + class Aggregator(CallableModel): + a: DepOf[..., GenericResult[int]] + b: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=resolve(self.a).value + resolve(self.b).value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.a, [context]), (self.b, [context])] + + agg = Aggregator(a=source_a(), b=source_b()) + + result = agg(SimpleContext(value=10)) + # a: 10, b: 20, aggregator: 30 + self.assertEqual(result.value, 30) + + def test_class_based_deps_with_instance_field_access(self): + """Test that __deps__ can access instance fields for configurable transforms. + + This is the key advantage of class-based models over @Flow.model: + transforms can use instance fields like window size. + """ + + @Flow.model + def data_source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + class Consumer(CallableModel): + source: DepOf[..., GenericResult[int]] + lookback: int = 5 # Configurable instance field + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=resolve(self.source).value * 2) + + @Flow.deps + def __deps__(self, context: SimpleContext): + # Access self.lookback in transform - this is why we use class-based! + transformed = SimpleContext(value=context.value + self.lookback) + return [(self.source, [transformed])] + + src = data_source() + consumer = Consumer(source=src, lookback=10) + + result = consumer(SimpleContext(value=5)) + # transformed: 5 + 10 = 15 + # source: 15 + # consumer: 15 * 2 = 30 + self.assertEqual(result.value, 30) + + def test_class_based_with_direct_value(self): + """Test that DepOf fields can accept pre-resolved values.""" + + class Consumer(CallableModel): + source: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + # resolve() passes through non-CallableModel values unchanged + return GenericResult(value=resolve(self.source).value + context.value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + # No deps when source is already resolved + return [] + + # Pass direct value instead of CallableModel + consumer = Consumer(source=GenericResult(value=100)) + + result = consumer(SimpleContext(value=5)) + self.assertEqual(result.value, 105) + + def test_class_based_no_double_call(self): + """Test that dependencies are not called twice during DepOf resolution. + + This verifies that the auto-resolution mechanism doesn't accidentally + evaluate the same dependency multiple times. 
+ """ + call_counts = {"source": 0} + + @Flow.model + def counting_source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + class Consumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=resolve(self.data).value + 1) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.data, [context])] + + src = counting_source() + consumer = Consumer(data=src) + + # Call consumer - source should only be called once + result = consumer(SimpleContext(value=5)) + + self.assertEqual(result.value, 51) # 5 * 10 + 1 + self.assertEqual(call_counts["source"], 1, "Source should only be called once") + + def test_class_based_nested_depof_no_double_call(self): + """Test nested DepOf chain (A -> B -> C) has no double-calls at any layer. + + This tests a 3-layer dependency chain where: + - layer_c is the leaf (no dependencies) + - layer_b depends on layer_c + - layer_a depends on layer_b + + Each layer should be called exactly once. + """ + call_counts = {"layer_a": 0, "layer_b": 0, "layer_c": 0} + + # Layer C: leaf node (no dependencies) + @Flow.model + def layer_c(context: SimpleContext) -> GenericResult[int]: + call_counts["layer_c"] += 1 + return GenericResult(value=context.value) + + # Layer B: depends on layer_c + class LayerB(CallableModel): + source: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + call_counts["layer_b"] += 1 + return GenericResult(value=resolve(self.source).value * 10) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.source, [context])] + + # Layer A: depends on layer_b + class LayerA(CallableModel): + source: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + call_counts["layer_a"] += 1 + return GenericResult(value=resolve(self.source).value + 1) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.source, [context])] + + # Build the chain: A -> B -> C + c = layer_c() + b = LayerB(source=c) + a = LayerA(source=b) + + # Call layer_a - each layer should be called exactly once + result = a(SimpleContext(value=5)) + + # Verify result: C returns 5, B returns 5*10=50, A returns 50+1=51 + self.assertEqual(result.value, 51) + + # Verify each layer called exactly once + self.assertEqual(call_counts["layer_c"], 1, "layer_c should be called exactly once") + self.assertEqual(call_counts["layer_b"], 1, "layer_b should be called exactly once") + self.assertEqual(call_counts["layer_a"], 1, "layer_a should be called exactly once") + + def test_resolve_direct_value_passthrough(self): + """Test that resolve() passes through non-CallableModel values unchanged.""" + + class Consumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + # resolve() should return the GenericResult directly (pass-through) + resolved = resolve(self.data) + # Verify it's the actual GenericResult, not a CallableModel + assert isinstance(resolved, GenericResult) + return GenericResult(value=resolved.value * 2) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [] + + # Pass a direct value, not a CallableModel + direct_result = GenericResult(value=42) + consumer = Consumer(data=direct_result) + + result = 
consumer(SimpleContext(value=5)) + self.assertEqual(result.value, 84) # 42 * 2 + + def test_resolve_outside_call_raises_error(self): + """Test that resolve() raises RuntimeError when called outside __call__.""" + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + class Consumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=resolve(self.data).value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.data, [context])] + + src = source() + consumer = Consumer(data=src) + + # Calling resolve() outside of __call__ should raise RuntimeError + with self.assertRaises(RuntimeError) as cm: + resolve(consumer.data) + + self.assertIn("resolve() can only be used inside __call__", str(cm.exception)) + + def test_flow_model_uses_unified_resolution_path(self): + """Test that @Flow.model uses the same resolution path as class-based CallableModel. + + This verifies the consolidation of resolution logic - both @Flow.model and + class-based models should use _resolve_deps_and_call in callable.py. + """ + call_counts = {"source": 0, "decorator_model": 0, "class_model": 0} + + @Flow.model + def shared_source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 2) + + # @Flow.model consumer + @Flow.model + def decorator_consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + call_counts["decorator_model"] += 1 + return GenericResult(value=data.value + 100) + + # Class-based consumer (same logic) + class ClassConsumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + call_counts["class_model"] += 1 + return GenericResult(value=resolve(self.data).value + 100) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.data, [context])] + + # Test both consumers with the same source + src = shared_source() + dec_consumer = decorator_consumer(data=src) + cls_consumer = ClassConsumer(data=src) + + ctx = SimpleContext(value=10) + + # Both should produce the same result + dec_result = dec_consumer(ctx) + cls_result = cls_consumer(ctx) + + self.assertEqual(dec_result.value, cls_result.value) + self.assertEqual(dec_result.value, 120) # 10 * 2 + 100 + + # Source should be called exactly twice (once per consumer) + self.assertEqual(call_counts["source"], 2) + self.assertEqual(call_counts["decorator_model"], 1) + self.assertEqual(call_counts["class_model"], 1) + + +if __name__ == "__main__": + import unittest + + unittest.main() diff --git a/ccflow/tests/test_flow_model_hydra.py b/ccflow/tests/test_flow_model_hydra.py new file mode 100644 index 0000000..661ac4f --- /dev/null +++ b/ccflow/tests/test_flow_model_hydra.py @@ -0,0 +1,437 @@ +"""Hydra integration tests for Flow.model. + +These tests verify that Flow.model decorated functions work correctly when +loaded from YAML configuration files using ModelRegistry.load_config_from_path(). + +Key feature: Registry name references (e.g., `source: flow_source`) ensure the same +object instance is shared across all consumers. 
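+
+For orientation, the configs exercised here look roughly like the sketch below
+(the actual fixtures live in config/conf_flow.yaml; the values shown are
+inferred from the expected results in these tests):
+
+    flow_source:
+      _target_: ccflow.tests.test_flow_model.data_source
+      base_value: 100
+    flow_transformer:
+      _target_: ccflow.tests.test_flow_model.data_transformer
+      source: flow_source   # registry name reference -> shared instance
+      factor: 3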
+""" + +from datetime import date +from pathlib import Path +from unittest import TestCase + +from omegaconf import OmegaConf + +from ccflow import CallableModel, DateRangeContext, GenericResult, ModelRegistry + +from .test_flow_model import SimpleContext + +CONFIG_PATH = str(Path(__file__).parent / "config" / "conf_flow.yaml") + + +class TestFlowModelHydraYAML(TestCase): + """Tests loading Flow.model from YAML config files using ModelRegistry.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_basic_loader_from_yaml(self): + """Test basic model instantiation from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_loader"] + + self.assertIsInstance(loader, CallableModel) + + ctx = SimpleContext(value=10) + result = loader(ctx) + self.assertEqual(result.value, 50) # 10 * 5 + + def test_string_processor_from_yaml(self): + """Test string processor model from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["flow_processor"] + + ctx = SimpleContext(value=42) + result = processor(ctx) + self.assertEqual(result.value, "value=42!") + + def test_two_stage_pipeline_from_yaml(self): + """Test two-stage pipeline from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + transformer = r["flow_transformer"] + + self.assertIsInstance(transformer, CallableModel) + + ctx = SimpleContext(value=5) + result = transformer(ctx) + # flow_source: 5 + 100 = 105 + # flow_transformer: 105 * 3 = 315 + self.assertEqual(result.value, 315) + + def test_three_stage_pipeline_from_yaml(self): + """Test three-stage pipeline from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + stage3 = r["flow_stage3"] + + ctx = SimpleContext(value=10) + result = stage3(ctx) + # stage1: 10 + 10 = 20 + # stage2: 20 * 2 = 40 + # stage3: 40 + 50 = 90 + self.assertEqual(result.value, 90) + + def test_diamond_dependency_from_yaml(self): + """Test diamond dependency pattern from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + aggregator = r["diamond_aggregator"] + + ctx = SimpleContext(value=10) + result = aggregator(ctx) + # source: 10 + 10 = 20 + # branch_a: 20 * 2 = 40 + # branch_b: 20 * 5 = 100 + # aggregator: 40 + 100 = 140 + self.assertEqual(result.value, 140) + + def test_date_range_pipeline_from_yaml(self): + """Test DateRangeContext pipeline with transforms from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["flow_date_processor"] + + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + result = processor(ctx) + + # The transform extends start_date back by one day + self.assertIn("2024-01-09", result.value) + self.assertIn("normalized:", result.value) + + def test_context_args_from_yaml(self): + """Test context_args model from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["ctx_args_loader"] + + self.assertIsInstance(loader, CallableModel) + # context_args models use DateRangeContext + self.assertEqual(loader.context_type, DateRangeContext) + + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = loader(ctx) + self.assertEqual(result.value, "data_source:2024-01-01 to 2024-01-31") + + def test_context_args_pipeline_from_yaml(self): + """Test context_args pipeline with dependencies from YAML.""" + r 
= ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["ctx_args_processor"] + + ctx = DateRangeContext(start_date=date(2024, 3, 1), end_date=date(2024, 3, 31)) + result = processor(ctx) + # loader: "data_source:2024-03-01 to 2024-03-31" + # processor: "output:data_source:2024-03-01 to 2024-03-31" + self.assertEqual(result.value, "output:data_source:2024-03-01 to 2024-03-31") + + def test_context_args_shares_instance(self): + """Test that context_args pipeline shares dependency instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["ctx_args_loader"] + processor = r["ctx_args_processor"] + + self.assertIs(processor.data, loader) + + +class TestFlowModelHydraInstanceSharing(TestCase): + """Tests that registry name references share the same object instance.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_pipeline_shares_instance(self): + """Test that pipeline stages share the same dependency instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + transformer = r["flow_transformer"] + source = r["flow_source"] + + self.assertIs(transformer.source, source) + + def test_three_stage_pipeline_shares_instances(self): + """Test that three-stage pipeline shares instances correctly.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + stage1 = r["flow_stage1"] + stage2 = r["flow_stage2"] + stage3 = r["flow_stage3"] + + self.assertIs(stage2.stage1_output, stage1) + self.assertIs(stage3.stage2_output, stage2) + + def test_diamond_pattern_shares_source_instance(self): + """Test that diamond pattern branches share the same source instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + source = r["diamond_source"] + branch_a = r["diamond_branch_a"] + branch_b = r["diamond_branch_b"] + aggregator = r["diamond_aggregator"] + + # Both branches should share the SAME source instance + self.assertIs(branch_a.source, source) + self.assertIs(branch_b.source, source) + self.assertIs(branch_a.source, branch_b.source) + + self.assertIs(aggregator.input_a, branch_a) + self.assertIs(aggregator.input_b, branch_b) + + def test_date_range_shares_instance(self): + """Test that date range pipeline shares dependency instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_date_loader"] + processor = r["flow_date_processor"] + + self.assertIs(processor.raw_data, loader) + + +class TestFlowModelHydraOmegaConf(TestCase): + """Tests using OmegaConf.create for dynamic config creation.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_instantiate_with_omegaconf(self): + """Test instantiation using OmegaConf.create via ModelRegistry.""" + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "dynamic_source", + "multiplier": 7, + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + loader = r["loader"] + + ctx = SimpleContext(value=3) + result = loader(ctx) + self.assertEqual(result.value, 21) # 3 * 7 + + def test_nested_deps_with_omegaconf(self): + """Test nested dependencies using OmegaConf with registry names.""" + cfg = OmegaConf.create( + { + "source": { + "_target_": "ccflow.tests.test_flow_model.data_source", + "base_value": 50, + }, + "transformer": { + "_target_": 
"ccflow.tests.test_flow_model.data_transformer", + "source": "source", + "factor": 4, + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + transformer = r["transformer"] + + ctx = SimpleContext(value=10) + result = transformer(ctx) + # source: 10 + 50 = 60 + # transformer: 60 * 4 = 240 + self.assertEqual(result.value, 240) + + self.assertIs(transformer.source, r["source"]) + + def test_diamond_with_omegaconf(self): + """Test diamond pattern with OmegaConf using registry names.""" + cfg = OmegaConf.create( + { + "source": { + "_target_": "ccflow.tests.test_flow_model.data_source", + "base_value": 10, + }, + "branch_a": { + "_target_": "ccflow.tests.test_flow_model.data_transformer", + "source": "source", + "factor": 2, + }, + "branch_b": { + "_target_": "ccflow.tests.test_flow_model.data_transformer", + "source": "source", + "factor": 3, + }, + "aggregator": { + "_target_": "ccflow.tests.test_flow_model.data_aggregator", + "input_a": "branch_a", + "input_b": "branch_b", + "operation": "multiply", + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + aggregator = r["aggregator"] + + ctx = SimpleContext(value=5) + result = aggregator(ctx) + # source: 5 + 10 = 15 + # branch_a: 15 * 2 = 30 + # branch_b: 15 * 3 = 45 + # aggregator: 30 * 45 = 1350 + self.assertEqual(result.value, 1350) + + # Verify SAME source instance is shared + self.assertIs(r["branch_a"].source, r["source"]) + self.assertIs(r["branch_b"].source, r["source"]) + + +class TestFlowModelHydraDefaults(TestCase): + """Tests that default parameter values work with Hydra.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_defaults_used_when_not_specified(self): + """Test that default values are used when not in config.""" + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "test", + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + loader = r["loader"] + + ctx = SimpleContext(value=10) + result = loader(ctx) + self.assertEqual(result.value, 10) # 10 * 1 (default) + + def test_defaults_can_be_overridden(self): + """Test that defaults can be overridden in config.""" + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "test", + "multiplier": 100, + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + loader = r["loader"] + + ctx = SimpleContext(value=10) + result = loader(ctx) + self.assertEqual(result.value, 1000) # 10 * 100 + + +class TestFlowModelHydraModelProperties(TestCase): + """Tests that model properties are correct after Hydra instantiation.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_context_type_property(self): + """Test that context_type is correct.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_loader"] + self.assertEqual(loader.context_type, SimpleContext) + + def test_result_type_property(self): + """Test that result_type is correct.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_loader"] + self.assertEqual(loader.result_type, GenericResult[int]) + + def test_deps_method_works(self): + """Test that __deps__ method works after Hydra instantiation.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + transformer = r["flow_transformer"] + + ctx = SimpleContext(value=5) + 
deps = transformer.__deps__(ctx) + + self.assertEqual(len(deps), 1) + self.assertIsInstance(deps[0][0], CallableModel) + self.assertEqual(deps[0][1], [ctx]) + self.assertIs(deps[0][0], r["flow_source"]) + + +class TestFlowModelHydraDateRangeTransforms(TestCase): + """Tests transforms with DateRangeContext from Hydra config.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_transform_applied_from_yaml(self): + """Test that transform is applied when loaded from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["flow_date_processor"] + + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + deps = processor.__deps__(ctx) + + self.assertEqual(len(deps), 1) + dep_model, dep_contexts = deps[0] + + # The transform should extend start_date back by one day + transformed_ctx = dep_contexts[0] + self.assertEqual(transformed_ctx.start_date, date(2024, 1, 9)) + self.assertEqual(transformed_ctx.end_date, date(2024, 1, 31)) + + self.assertIs(dep_model, r["flow_date_loader"]) + + +if __name__ == "__main__": + import unittest + + unittest.main() diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md new file mode 100644 index 0000000..76d0eb7 --- /dev/null +++ b/docs/design/flow_model_design.md @@ -0,0 +1,440 @@ +# Flow.model and DepOf: Dependency Injection for CallableModel + +## Overview + +This document describes the `@Flow.model` decorator and `DepOf` annotation system for reducing boilerplate when creating `CallableModel` pipelines with dependencies. + +**Key features:** +- `@Flow.model` - Decorator that generates `CallableModel` classes from plain functions +- `DepOf[ContextType, ResultType]` - Type annotation for dependency fields +- `resolve()` - Function to access resolved dependency values in class-based models + +## Quick Start + +### Pattern 1: `@Flow.model` (Recommended for Simple Cases) + +```python +from datetime import date, timedelta +from typing import Annotated + +from ccflow import Flow, DateRangeContext, GenericResult, DepOf + +@Flow.model +def load_records(context: DateRangeContext, source: str) -> GenericResult[dict]: + return GenericResult(value={"count": 100, "date": str(context.start_date)}) + +@Flow.model +def compute_stats( + context: DateRangeContext, + records: DepOf[..., GenericResult[dict]], # Dependency field +) -> GenericResult[float]: + # records is already resolved - just use it directly + return GenericResult(value=records.value["count"] * 0.05) + +# Build pipeline +loader = load_records(source="main_db") +stats = compute_stats(records=loader) + +# Execute +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = stats(ctx) +``` + +### Pattern 2: Class-Based (For Complex Cases) + +Use class-based when you need **configurable transforms** that depend on instance fields: + +```python +from datetime import timedelta + +from ccflow import CallableModel, DateRangeContext, Flow, GenericResult, DepOf +from ccflow.callable import resolve # Import resolve for class-based models + +class AggregateWithWindow(CallableModel): + """Aggregate records with configurable lookback window.""" + + records: DepOf[..., GenericResult[dict]] + window: int = 7 # Configurable instance field + + @Flow.call + def __call__(self, context: DateRangeContext) -> GenericResult[float]: + # Use resolve() to get the resolved value + records = resolve(self.records) + return 
GenericResult(value=records.value["count"] / self.window) + + @Flow.deps + def __deps__(self, context: DateRangeContext): + # Transform uses self.window - this is why we need class-based! + lookback_ctx = context.model_copy( + update={"start_date": context.start_date - timedelta(days=self.window)} + ) + return [(self.records, [lookback_ctx])] + +# Usage - different window sizes, same source +loader = load_records(source="main_db") +agg_7 = AggregateWithWindow(records=loader, window=7) +agg_30 = AggregateWithWindow(records=loader, window=30) +``` + +## When to Use Which Pattern + +| Use `@Flow.model` when... | Use Class-Based when... | +|--------------------------------|--------------------------------------| +| Simple transformations | Transforms depend on instance fields | +| Fixed context transforms | Need `self.field` in `__deps__` | +| Less boilerplate is priority | Full control over resolution | +| No custom `__deps__` logic | Complex dependency patterns | + +## Core Concepts + +### `DepOf[ContextType, ResultType]` + +Shorthand for declaring dependency fields that can accept either: +- A pre-computed value of `ResultType` +- A `CallableModel` that produces `ResultType` + +```python +# Inherit context type from parent model +data: DepOf[..., GenericResult[dict]] + +# Explicit context type +data: DepOf[DateRangeContext, GenericResult[dict]] + +# Equivalent to: +data: Annotated[Union[GenericResult[dict], CallableModel], Dep()] +``` + +### `Dep(transform=..., context_type=...)` + +For transforms, use the full `Annotated` form: + +```python +from ccflow import Dep + +@Flow.model +def compute_stats( + context: DateRangeContext, + records: Annotated[GenericResult[dict], Dep( + transform=lambda ctx: ctx.model_copy( + update={"start_date": ctx.start_date - timedelta(days=1)} + ) + )], +) -> GenericResult[float]: + return GenericResult(value=records.value["count"] * 0.05) +``` + +### `resolve()` Function + +**Only needed for class-based models.** Accesses the resolved value of a `DepOf` field during `__call__`. + +```python +from ccflow.callable import resolve + +class MyModel(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: MyContext) -> GenericResult[int]: + # resolve() returns the GenericResult, not the CallableModel + result = resolve(self.data) + return GenericResult(value=result.value + 1) +``` + +**Behavior:** +- Inside `__call__`: Returns the resolved value +- With direct values (not CallableModel): Returns unchanged (no-op) +- Outside `__call__`: Raises `RuntimeError` +- In `@Flow.model`: Not needed - values are passed as function arguments + +**Type inference:** +```python +data: DepOf[..., GenericResult[int]] +resolved = resolve(self.data) # Type: GenericResult[int] +``` + +## How Resolution Works + +### `@Flow.model` Resolution Flow + +1. User calls `model(context)` +2. Generated `__call__` invokes `_resolve_deps_and_call()` +3. For each `DepOf` field containing a `CallableModel`: + - Apply transform (if any) + - Call the dependency + - Store resolved value in context variable +4. Generated `__call__` retrieves resolved values via `resolve()` +5. Original function receives resolved values as arguments + +### Class-Based Resolution Flow + +1. User calls `model(context)` +2. `_resolve_deps_and_call()` runs +3. For each `DepOf` field containing a `CallableModel`: + - Check `__deps__` for custom transforms + - Call the dependency + - Store resolved value in context variable +4. 
+4. User's `__call__` accesses values via `resolve(self.field)`
+
+**Important:** Resolution uses a context variable (`contextvars.ContextVar`), making it thread-safe and async-safe.
+
+## Design Decisions
+
+### Decision 1: `resolve()` Instead of Temporary Mutation
+
+**What we chose:** Explicit `resolve()` function with context variables.
+
+**Alternative considered:** Temporarily mutate `self.field` during `__call__` to hold the resolved value, then restore it afterwards.
+
+**Why we chose this:**
+- No mutation of model state
+- Thread/async-safe via contextvars
+- Explicit about what's happening
+- Easier to debug - `self.field` always shows the original value
+
+**Trade-off:** Slightly more verbose (`resolve(self.data).value` vs `self.data.value`).
+
+### Decision 2: Unified Resolution Path
+
+**What we chose:** Both `@Flow.model` and class-based models use the same `_resolve_deps_and_call()` function.
+
+**Why:**
+- Single source of truth for resolution logic
+- Easier to maintain
+- Consistent behavior across patterns
+
+### Decision 3: `resolve()` Not in Top-Level `__all__`
+
+**What we chose:** `resolve` must be imported explicitly: `from ccflow.callable import resolve`
+
+**Why:**
+- Only needed for class-based models with `DepOf`
+- Keeps the top-level namespace clean
+- Users who need it can find it easily
+
+### Decision 4: No Auto-Wrapping Return Values
+
+**What we chose:** Functions must explicitly return a `ResultBase` subclass.
+
+**Why:**
+- Type annotations remain honest
+- Consistent with the existing `CallableModel` contract
+- `GenericResult(value=x)` is minimal overhead
+
+### Decision 5: Generated Classes Are Real CallableModels
+
+**What we chose:** Generate actual `CallableModel` subclasses using `type()`.
+
+**Why:**
+- Full compatibility with existing infrastructure
+- Caching, registry, serialization work unchanged
+- Can mix with hand-written classes
+
+## Pitfalls and Limitations
+
+### Pitfall 1: Forgetting `resolve()` in Class-Based Models
+
+```python
+class MyModel(CallableModel):
+    data: DepOf[..., GenericResult[int]]
+
+    @Flow.call
+    def __call__(self, context):
+        # WRONG - self.data is still the CallableModel!
+        # return GenericResult(value=self.data.value + 1)
+
+        # CORRECT - resolve() returns the dependency's result
+        return GenericResult(value=resolve(self.data).value + 1)
+```
+
+**Error you'll see with the wrong version:** `AttributeError: '_SomeModel' object has no attribute 'value'`
+
+### Pitfall 2: Calling `resolve()` Outside `__call__`
+
+```python
+model = MyModel(data=some_source())
+resolve(model.data)  # RuntimeError!
+```
+
+`resolve()` only works during `__call__` execution.
+
+### Pitfall 3: Lambda Transforms Don't Serialize
+
+```python
+# Won't serialize - lambdas can't be pickled
+Dep(transform=lambda ctx: ctx.model_copy(...))
+
+# Will serialize - use named functions
+def shift_start(ctx):
+    return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)})
+
+Dep(transform=shift_start)
+```
+
+### Pitfall 4: GraphEvaluator Requires Caching
+
+When using `GraphEvaluator` with `DepOf`, dependencies may be called twice (once by the GraphEvaluator, once by resolution) unless caching is enabled.
+ +```python +# Use with caching +from ccflow.evaluators import GraphEvaluator, CachingEvaluator, MultiEvaluator + +evaluator = MultiEvaluator(evaluators=[ + CachingEvaluator(), + GraphEvaluator(), +]) +``` + +### Pitfall 5: Two Mental Models + +Users need to remember: +- `@Flow.model`: Use dependency values directly as function arguments +- Class-based: Use `resolve(self.field)` to access values + +### Limitation: `__deps__` Still Required for Class-Based + +Even without transforms, class-based models need `__deps__`: + +```python +class Consumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context): + return GenericResult(value=resolve(self.data).value) + + @Flow.deps + def __deps__(self, context): + return [(self.data, [context])] # Boilerplate, but required +``` + +## Complete Example: Multi-Stage Pipeline + +```python +from datetime import date, timedelta +from typing import Annotated + +from ccflow import ( + CallableModel, DateRangeContext, Dep, DepOf, + Flow, GenericResult +) +from ccflow.callable import resolve + + +# Stage 1: Data loader (simple, use @Flow.model) +@Flow.model +def load_events(context: DateRangeContext, source: str) -> GenericResult[list]: + print(f"Loading from {source} for {context.start_date} to {context.end_date}") + return GenericResult(value=[ + {"date": str(context.start_date), "count": 100 + i} + for i in range(5) + ]) + + +# Stage 2: Transform with fixed lookback (use @Flow.model with Dep transform) +@Flow.model +def compute_daily_totals( + context: DateRangeContext, + events: Annotated[GenericResult[list], Dep( + transform=lambda ctx: ctx.model_copy( + update={"start_date": ctx.start_date - timedelta(days=1)} + ) + )], +) -> GenericResult[float]: + values = [e["count"] for e in events.value] + total = sum(values) / len(values) if values else 0 + return GenericResult(value=total) + + +# Stage 3: Configurable window (use class-based) +class ComputeRollingSummary(CallableModel): + """Summary with configurable lookback window.""" + + totals: DepOf[..., GenericResult[float]] + window: int = 20 + + @Flow.call + def __call__(self, context: DateRangeContext) -> GenericResult[float]: + totals = resolve(self.totals) + # Scale by window size + summary = totals.value * (self.window ** 0.5) + return GenericResult(value=summary) + + @Flow.deps + def __deps__(self, context: DateRangeContext): + lookback = context.model_copy( + update={"start_date": context.start_date - timedelta(days=self.window)} + ) + return [(self.totals, [lookback])] + + +# Build pipeline +events = load_events(source="main_db") +totals = compute_daily_totals(events=events) +summary_20 = ComputeRollingSummary(totals=totals, window=20) +summary_60 = ComputeRollingSummary(totals=totals, window=60) + +# Execute +ctx = DateRangeContext(start_date=date(2024, 1, 15), end_date=date(2024, 1, 31)) +print(f"20-day summary: {summary_20(ctx).value}") +print(f"60-day summary: {summary_60(ctx).value}") +``` + +## API Reference + +### `@Flow.model` + +```python +@Flow.model( + context_args: list[str] = None, # Unpack context fields as function args + cacheable: bool = False, + volatile: bool = False, + log_level: int = logging.DEBUG, + validate_result: bool = True, + verbose: bool = True, + evaluator: EvaluatorBase = None, +) +def my_function(context: ContextType, ...) -> ResultType: + ... 
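+
+# Illustrative sketch (not part of the signature above): combining options.
+# Assumes SimpleContext and GenericResult as used elsewhere in this document.
+#
+#   @Flow.model(cacheable=True, log_level=logging.INFO)
+#   def scaled(context: SimpleContext, factor: int = 2) -> GenericResult[int]:
+#       return GenericResult(value=context.value * factor)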
+```
+
+### `DepOf[ContextType, ResultType]`
+
+```python
+# Inherit context from parent
+field: DepOf[..., GenericResult[int]]
+
+# Explicit context type
+field: DepOf[DateRangeContext, GenericResult[int]]
+```
+
+### `Dep(transform=..., context_type=...)`
+
+```python
+field: Annotated[GenericResult[int], Dep(
+    transform=my_transform_func,    # Optional: (context) -> transformed_context
+    context_type=DateRangeContext,  # Optional: Expected context type
+)]
+```
+
+### `resolve(dep)`
+
+```python
+from ccflow.callable import resolve
+
+# Inside __call__ of a class-based CallableModel:
+resolved_value = resolve(self.dep_field)
+
+# Type signature:
+def resolve(dep: Union[T, CallableModel]) -> T: ...
+```
+
+## File Structure
+
+```
+ccflow/
+├── callable.py     # CallableModel, Flow, resolve(), _resolve_deps_and_call()
+├── dep.py          # Dep, DepOf, extract_dep()
+├── flow_model.py   # @Flow.model implementation
+└── tests/
+    ├── test_flow_model.py        # Core Flow.model and DepOf tests
+    └── test_flow_model_hydra.py  # Hydra/YAML integration tests
+```
diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md
index 616e3d8..a89d8f8 100644
--- a/docs/wiki/Key-Features.md
+++ b/docs/wiki/Key-Features.md
@@ -22,6 +22,121 @@ The naming was inspired by the open source library [Pydantic](https://docs.pydan
 `CallableModel`'s are called with a context (something that derives from `ContextBase`) and returns a result (something that derives from `ResultBase`). As an example, you may have a `SQLReader` callable model that when called with a `DateRangeContext` returns a `ArrowResult` (wrapper around a Arrow table) with data in the date range defined by the context by querying some SQL database.
+### Flow.model Decorator
+
+The `@Flow.model` decorator provides a simpler way to define `CallableModel`s using plain Python functions instead of classes. It automatically generates a `CallableModel` class with proper `__call__` and `__deps__` methods.
+
+**Basic Example:**
+
+```python
+from datetime import date
+from ccflow import Flow, GenericResult, DateRangeContext
+
+@Flow.model
+def load_data(context: DateRangeContext, source: str) -> GenericResult[dict]:
+    # Your data loading logic here
+    return GenericResult(value=query_db(source, context.start_date, context.end_date))
+
+# Create model instance
+loader = load_data(source="my_database")
+
+# Execute with context
+ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31))
+result = loader(ctx)
+```
+
+**Composing Dependencies with `Dep` and `DepOf`:**
+
+Use `Dep()` or `DepOf` to mark parameters that accept other `CallableModel`s as dependencies. The framework automatically resolves the dependency graph.
+
+> **Tip:** If your function doesn't use the context directly (only passes it to dependencies), use `_` as the parameter name to signal this: `def my_func(_: DateRangeContext, data: DepOf[..., ResultType])`. This is a Python convention for intentionally unused parameters.
+ +```python +from datetime import date, timedelta +from typing import Annotated +from ccflow import Flow, GenericResult, DateRangeContext, Dep, DepOf + +@Flow.model +def load_data(context: DateRangeContext, source: str) -> GenericResult[dict]: + return GenericResult(value={"records": [1, 2, 3]}) + +@Flow.model +def transform_data( + _: DateRangeContext, # Context passed to dependency, not used directly + raw_data: Annotated[GenericResult[dict], Dep( + # Transform context to fetch one extra day for lookback + transform=lambda ctx: ctx.model_copy(update={ + "start_date": ctx.start_date - timedelta(days=1) + }) + )] +) -> GenericResult[dict]: + # raw_data.value contains the resolved result from load_data + return GenericResult(value={"transformed": raw_data.value["records"]}) + +# Or use DepOf shorthand (no transform needed): +@Flow.model +def aggregate_data( + _: DateRangeContext, # Context passed to dependency, not used directly + transformed: DepOf[..., GenericResult[dict]] # Shorthand for Annotated[T, Dep()] +) -> GenericResult[dict]: + return GenericResult(value={"count": len(transformed.value["transformed"])}) + +# Build the pipeline +data = load_data(source="my_database") +transformed = transform_data(raw_data=data) +aggregated = aggregate_data(transformed=transformed) + +# Execute - dependencies are automatically resolved +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = aggregated(ctx) +``` + +**Hydra/YAML Configuration:** + +`Flow.model` decorated functions work seamlessly with Hydra configuration and the `ModelRegistry`: + +```yaml +# config.yaml +data: + _target_: mymodule.load_data + source: my_database + +transformed: + _target_: mymodule.transform_data + raw_data: data # Reference by registry name (same instance is shared) + +aggregated: + _target_: mymodule.aggregate_data + transformed: transformed # Reference by registry name +``` + +When loaded via `ModelRegistry.load_config()`, references by name ensure the same object instance is shared across all consumers. + +**Auto-Unpacked Context with `context_args`:** + +Instead of taking an explicit `context` parameter, you can use `context_args` to automatically unpack context fields as function parameters. This is useful when you want cleaner function signatures: + +```python +from datetime import date +from ccflow import Flow, GenericResult, DateRangeContext + +# Instead of: def load_data(context: DateRangeContext, source: str) +# Use context_args to unpack the context fields directly: +@Flow.model(context_args=["start_date", "end_date"]) +def load_data(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date} to {end_date}") + +# The decorator infers DateRangeContext from the parameter types +loader = load_data(source="my_database") +assert loader.context_type == DateRangeContext + +# Execute with context as usual +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = loader(ctx) # "my_database:2024-01-01 to 2024-01-31" +``` + +The `context_args` parameter specifies which function parameters should be extracted from the context. The framework automatically determines the context type based on the parameter type annotations. + ## Model Registry A `ModelRegistry` is a named collection of models. 
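+
+A minimal sketch of working with a registry, assuming a YAML file like the example above saved as `config.yaml` (the calls mirror this project's tests):
+
+```python
+from ccflow import ModelRegistry
+
+r = ModelRegistry.root()
+r.load_config_from_path("config.yaml")  # registers data/transformed/aggregated by name
+aggregated = r["aggregated"]            # same shared instances as declared in the YAML
+```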
diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py new file mode 100644 index 0000000..e93d452 --- /dev/null +++ b/examples/flow_model_example.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python +"""Example demonstrating Flow.model decorator and class-based CallableModel. + +This example shows: +- Flow.model for simple functions with minimal boilerplate +- Context transforms with Dep annotations +- Class-based CallableModel for complex cases needing instance field access +""" + +from datetime import date, timedelta +from typing import Annotated + +from ccflow import CallableModel, DateRangeContext, Dep, DepOf, Flow, GenericResult +from ccflow.callable import resolve + + +# ============================================================================= +# Example 1: Basic Flow.model - No more boilerplate classes! +# ============================================================================= + +@Flow.model +def load_records(context: DateRangeContext, source: str, limit: int = 100) -> GenericResult[list]: + """Load records from a data source for the given date range.""" + print(f" Loading from '{source}' for {context.start_date} to {context.end_date} (limit={limit})") + return GenericResult(value=[ + {"id": i, "date": str(context.start_date), "value": i * 10} + for i in range(min(limit, 5)) + ]) + + +# ============================================================================= +# Example 2: Dependencies with DepOf - Automatic dependency resolution +# ============================================================================= + +@Flow.model +def compute_totals( + _: DateRangeContext, # Context passed to dependency, not used directly here + records: DepOf[..., GenericResult[list]], +) -> GenericResult[dict]: + """Compute totals from loaded records.""" + total = sum(r["value"] for r in records.value) + count = len(records.value) + print(f" Computing totals: {count} records, total={total}") + return GenericResult(value={"total": total, "count": count}) + + +# ============================================================================= +# Example 3: Simple Transform with Flow.model +# When the transform is a fixed function, Flow.model works great +# ============================================================================= + +def lookback_7_days(ctx: DateRangeContext) -> DateRangeContext: + """Fixed transform that extends the date range back by 7 days.""" + return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=7)}) + + +@Flow.model +def compute_weekly_average( + _: DateRangeContext, + records: Annotated[GenericResult[list], Dep(transform=lookback_7_days)], +) -> GenericResult[float]: + """Compute average using fixed 7-day lookback.""" + values = [r["value"] for r in records.value] + avg = sum(values) / len(values) if values else 0 + print(f" Computing weekly average: {avg:.2f} (from {len(values)} records)") + return GenericResult(value=avg) + + +# ============================================================================= +# Example 4: Class-based CallableModel with Configurable Transform +# When the transform needs access to instance fields (like window size), +# use a class-based approach with auto-resolution +# ============================================================================= + +class ComputeMovingAverage(CallableModel): + """Compute moving average with configurable lookback window. 
+ + This demonstrates: + - Field uses DepOf annotation: accepts either result or CallableModel + - Instance field (window) accessible in __deps__ for custom transforms + - resolve() to access resolved dependency values during __call__ + """ + + records: DepOf[..., GenericResult[list]] + window: int = 7 # Configurable lookback window + + @Flow.call + def __call__(self, context: DateRangeContext) -> GenericResult[float]: + """Compute the moving average - use resolve() to get resolved value.""" + records = resolve(self.records) # Get the resolved GenericResult + values = [r["value"] for r in records.value] + avg = sum(values) / len(values) if values else 0 + print(f" Computing {self.window}-day moving average: {avg:.2f} (from {len(values)} records)") + return GenericResult(value=avg) + + @Flow.deps + def __deps__(self, context: DateRangeContext): + """Define dependencies with transform that uses self.window.""" + # This is where we can access instance fields! + lookback_ctx = context.model_copy( + update={"start_date": context.start_date - timedelta(days=self.window)} + ) + return [(self.records, [lookback_ctx])] + + +# ============================================================================= +# Example 5: Multi-stage pipeline - Composing models together +# ============================================================================= + +@Flow.model +def generate_report( + context: DateRangeContext, + totals: DepOf[..., GenericResult[dict]], + moving_avg: DepOf[..., GenericResult[float]], + report_name: str = "Daily Report", +) -> GenericResult[str]: + """Generate a report combining multiple data sources.""" + report = f""" +{report_name} +{'=' * len(report_name)} +Date Range: {context.start_date} to {context.end_date} +Total Value: {totals.value['total']} +Record Count: {totals.value['count']} +Moving Avg: {moving_avg.value:.2f} +""" + return GenericResult(value=report.strip()) + + +# ============================================================================= +# Example 6: Using context_args for cleaner signatures +# ============================================================================= + +@Flow.model(context_args=["start_date", "end_date"]) +def fetch_metadata(start_date: date, end_date: date, category: str) -> GenericResult[dict]: + """Fetch metadata - note how start_date/end_date are direct parameters.""" + print(f" Fetching metadata for '{category}' from {start_date} to {end_date}") + return GenericResult(value={ + "category": category, + "days": (end_date - start_date).days, + "generated_at": str(date.today()), + }) + + +# ============================================================================= +# Main: Build and execute the pipeline +# ============================================================================= + +def main(): + print("=" * 60) + print("Flow.model Example - Simplified CallableModel Creation") + print("=" * 60) + + ctx = DateRangeContext( + start_date=date(2024, 1, 15), + end_date=date(2024, 1, 31) + ) + + # --- Example 1: Basic model --- + print("\n[1] Basic Flow.model:") + loader = load_records(source="main_db", limit=5) + result = loader(ctx) + print(f" Result: {result.value}") + + # --- Example 2: Simple dependency chain --- + print("\n[2] Dependency chain (loader -> totals):") + loader = load_records(source="main_db") + totals = compute_totals(records=loader) + result = totals(ctx) + print(f" Result: {result.value}") + + # --- Example 3: Fixed transform with Flow.model --- + print("\n[3] Fixed transform (7-day lookback with Flow.model):") + loader 
= load_records(source="main_db") + weekly_avg = compute_weekly_average(records=loader) + result = weekly_avg(ctx) + print(f" Result: {result.value}") + + # --- Example 4: Configurable transform with class-based model --- + print("\n[4] Configurable transform (class-based with auto-resolution):") + loader = load_records(source="main_db") + + # 14-day window + moving_avg_14 = ComputeMovingAverage(records=loader, window=14) + result = moving_avg_14(ctx) + print(f" 14-day result: {result.value}") + + # 30-day window - same loader, different window + moving_avg_30 = ComputeMovingAverage(records=loader, window=30) + result = moving_avg_30(ctx) + print(f" 30-day result: {result.value}") + + # --- Example 5: Full pipeline --- + print("\n[5] Full pipeline (mixing Flow.model and class-based):") + loader = load_records(source="analytics_db") + totals = compute_totals(records=loader) + moving_avg = ComputeMovingAverage(records=loader, window=7) + report = generate_report( + totals=totals, + moving_avg=moving_avg, + report_name="Analytics Summary" + ) + result = report(ctx) + print(result.value) + + # --- Example 6: context_args --- + print("\n[6] Using context_args (auto-unpacked context):") + metadata = fetch_metadata(category="sales") + result = metadata(ctx) + print(f" Result: {result.value}") + + # --- Bonus: Inspecting models --- + print("\n[Bonus] Inspecting models:") + print(f" load_records.context_type = {loader.context_type.__name__}") + print(f" ComputeMovingAverage uses __deps__ for custom transforms") + deps = moving_avg.__deps__(ctx) + for dep_model, dep_contexts in deps: + print(f" - Dependency context start: {dep_contexts[0].start_date} (lookback applied)") + + +if __name__ == "__main__": + main()
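+
+
+# ---------------------------------------------------------------------------
+# A minimal sketch (assuming this module is importable as
+# examples.flow_model_example) of driving the same pipeline from config,
+# mirroring the ModelRegistry/OmegaConf usage in this PR's tests:
+#
+#   from omegaconf import OmegaConf
+#   from ccflow import ModelRegistry
+#
+#   cfg = OmegaConf.create({
+#       "loader": {"_target_": "examples.flow_model_example.load_records", "source": "main_db"},
+#       "totals": {"_target_": "examples.flow_model_example.compute_totals", "records": "loader"},
+#   })
+#   r = ModelRegistry.root()
+#   r.load_config(cfg)
+#   result = r["totals"](DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)))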