Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ccflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from .compose import *
from .callable import *
from .context import *
from .dep import *
from .enums import Enum
from .flow_model import FlowAPI, BoundModel, Lazy
from .global_state import *
from .local_persistence import *
from .models import *
Expand Down
281 changes: 279 additions & 2 deletions ccflow/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import abc
import inspect
import logging
from contextvars import ContextVar
from functools import lru_cache, wraps
from inspect import Signature, isclass, signature
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin
Expand All @@ -28,6 +29,7 @@
ResultBase,
ResultType,
)
from .dep import Dep, extract_dep
from .local_persistence import create_ccflow_model
from .validators import str_to_log_level

Expand All @@ -46,6 +48,8 @@
"EvaluatorBase",
"Evaluator",
"WrapperModel",
# Note: resolve() is intentionally not in __all__ to avoid namespace pollution.
# Users who need it can import explicitly: from ccflow.callable import resolve
)

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -128,7 +132,7 @@ def _check_result_type(cls, result_type):
@model_validator(mode="after")
def _check_signature(self):
sig_call = _cached_signature(self.__class__.__call__)
if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: # ("self", "context")
if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters:
raise ValueError("__call__ method must take a single argument, named 'context'")

sig_deps = _cached_signature(self.__class__.__deps__)
Expand Down Expand Up @@ -195,6 +199,176 @@ def _get_logging_evaluator(log_level):
return LoggingEvaluator(log_level=log_level)


def _get_dep_fields(model_class) -> Dict[str, Dep]:
"""Analyze class fields to find Dep-annotated fields.

Returns a dict mapping field name to Dep instance for fields that need resolution.
"""
dep_fields = {}

# Get type hints from the class
hints = {}
for cls in model_class.__mro__:
if hasattr(cls, "__annotations__"):
for name, annotation in cls.__annotations__.items():
if name not in hints: # Don't override child class annotations
hints[name] = annotation

for name, annotation in hints.items():
base_type, dep = extract_dep(annotation)
if dep is not None:
dep_fields[name] = dep

return dep_fields


def _wrap_with_dep_resolution(fn):
"""Wrap a function to auto-resolve DepOf fields before calling.

For each Dep-annotated field on the model that contains a CallableModel,
resolves it using __deps__ and temporarily sets the resolved value on self.

Note: This wrapper is only applied at runtime when the function is called,
not during decoration. This avoids issues with functools.wraps flattening
the __wrapped__ chain.

Args:
fn: The original function

Returns:
The original function unchanged - dep resolution happens at the call site
"""
# Don't modify the function - dep resolution is handled in ModelEvaluationContext
return fn


# Context variable for storing resolved dependency values during __call__
# Maps id(callable_model) -> resolved_value
_resolved_deps: ContextVar[Dict[int, Any]] = ContextVar("resolved_deps", default={})

# TypeVar for resolve() function to enable proper type inference
_T = TypeVar("_T")


def resolve(dep: Union[_T, "_CallableModel"]) -> _T:
"""Access the resolved value of a DepOf dependency during __call__.

This function is used inside a CallableModel's __call__ method to get
the resolved value of a dependency field. It provides proper type inference -
if the field is `DepOf[..., GenericResult[int]]`, this returns `GenericResult[int]`.

Args:
dep: The dependency field value (either a CallableModel or already-resolved value)

Returns:
The resolved value. If dep is already a resolved value (not a CallableModel),
returns it unchanged.

Raises:
RuntimeError: If called outside of __call__ or if the dependency wasn't resolved.

Example:
class MyModel(CallableModel):
data: DepOf[..., GenericResult[int]]

@Flow.call
def __call__(self, context: MyContext) -> GenericResult[int]:
# resolve() provides proper type inference
data = resolve(self.data) # type: GenericResult[int]
return GenericResult(value=data.value + 1)
"""
# If it's not a CallableModel, it's already a resolved value - pass through
if not isinstance(dep, _CallableModel):
return dep # type: ignore[return-value]

# Look up in context var
store = _resolved_deps.get()
dep_id = id(dep)
if dep_id not in store:
raise RuntimeError(
"resolve() can only be used inside __call__ for DepOf fields. Make sure the field is annotated with DepOf and contains a CallableModel."
)
return store[dep_id]


def _resolve_deps_and_call(model, context, fn):
"""Resolve DepOf fields and call the function.

This is called from ModelEvaluationContext.__call__ to handle dep resolution.
Resolved values are stored in a context variable and accessed via resolve().

Args:
model: The CallableModel instance
context: The context to pass to the function
fn: The function to call

Returns:
The result of calling fn(model, context)
"""
# Don't resolve deps for __deps__ method
if fn.__name__ == "__deps__":
return fn(model, context)

# Get Dep-annotated fields for this model class
dep_fields = _get_dep_fields(model.__class__)

# Check if model has custom deps (from @func.deps decorator)
has_custom_deps = getattr(model.__class__, "__has_custom_deps__", False)

if not dep_fields and not has_custom_deps:
return fn(model, context)

# Get dependencies from __deps__
deps_result = model.__deps__(context)
# Build a map from model instance id to (model, contexts) for lookup
dep_map = {}
for dep_model, contexts in deps_result:
dep_map[id(dep_model)] = (dep_model, contexts)

# Resolve dependencies and store in context var
resolved_values = {}

# If custom deps, resolve ALL CallableModel fields from dep_map
if has_custom_deps:
for dep_model, contexts in deps_result:
resolved = dep_model(contexts[0]) if contexts else dep_model(context)
# Unwrap GenericResult if present (consistent with auto-detected deps)
if hasattr(resolved, 'value'):
resolved = resolved.value
resolved_values[id(dep_model)] = resolved
else:
# Standard path: iterate over Dep-annotated fields
for field_name, dep in dep_fields.items():
field_value = getattr(model, field_name, None)
if field_value is None:
continue

# Check if field is a CallableModel that needs resolution
if not isinstance(field_value, _CallableModel):
continue # Already a resolved value, skip

# Check if this field is in __deps__ (for custom transforms)
if id(field_value) in dep_map:
dep_model, contexts = dep_map[id(field_value)]
# Call dependency with the (transformed) context
resolved = dep_model(contexts[0]) if contexts else dep_model(context)
else:
# Not in __deps__, use Dep annotation transform directly
transformed_ctx = dep.apply(context)
resolved = field_value(transformed_ctx)

resolved_values[id(field_value)] = resolved

# Store in context var and call function
current_store = _resolved_deps.get()
new_store = {**current_store, **resolved_values}
token = _resolved_deps.set(new_store)
try:
return fn(model, context)
finally:
_resolved_deps.reset(token)


class FlowOptions(BaseModel):
"""Options for Flow evaluation.

Expand Down Expand Up @@ -246,6 +420,9 @@ def get_evaluator(self, model: CallableModelType) -> "EvaluatorBase":
return self._get_evaluator_from_options(options)

def __call__(self, fn):
# Wrap function with dependency resolution for DepOf fields
fn = _wrap_with_dep_resolution(fn)

# Used for building a graph of model evaluation contexts without evaluating
def get_evaluation_context(model: CallableModelType, context: ContextType, as_dict: bool = False, *, _options: Optional[FlowOptions] = None):
# Create the evaluation context.
Expand Down Expand Up @@ -451,6 +628,33 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult:

# The generated context inherits from DateContext, so it's compatible
# with infrastructure expecting DateContext instances.

Auto-Resolve Dependencies Example:
When __call__ has parameters beyond 'self' and 'context' that match field
names annotated with DepOf/Dep, those dependencies are automatically resolved
using __deps__ (if defined) or auto-generated from Dep annotations.

class MyModel(CallableModel):
data: Annotated[GenericResult[dict], Dep(transform=my_transform)]

@Flow.call
def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]:
# data is automatically resolved - no manual calling needed
return GenericResult(value=process(data.value))

For transforms that need access to instance fields, define __deps__ manually:

class MyModel(CallableModel):
data: DepOf[..., GenericResult[dict]]
window: int = 7

def __deps__(self, context):
# Can access self.window here
return [(self.data, [context.with_lookback(self.window)])]

@Flow.call
def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]:
return GenericResult(value=process(data.value))
"""
# Extract auto_context option (not part of FlowOptions)
# Can be: False, True, or a ContextBase subclass
Expand Down Expand Up @@ -502,6 +706,78 @@ def deps(*args, **kwargs):
# Note that the code below is executed only once
return FlowOptionsDeps(**kwargs)

@staticmethod
def model(*args, **kwargs):
"""Decorator that generates a CallableModel class from a plain Python function.

This is syntactic sugar over CallableModel. The decorator generates a real
CallableModel class with proper __call__ and __deps__ methods, so all existing
features (caching, evaluation, registry, serialization) work unchanged.

Args:
context_args: List of parameter names that come from context (for unpacked mode)
cacheable: Enable caching of results (default: False)
volatile: Mark as volatile (default: False)
log_level: Logging verbosity (default: logging.DEBUG)
validate_result: Validate return type (default: True)
verbose: Verbose logging output (default: True)
evaluator: Custom evaluator (default: None)

Two Context Modes:

Mode 1 - Explicit context parameter:
Function has a 'context' parameter annotated with a ContextBase subclass.

@Flow.model
def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]:
return GenericResult(value=query_db(source, context.start_date, context.end_date))

Mode 2 - Unpacked context_args:
Context fields are unpacked into function parameters.

@Flow.model(context_args=["start_date", "end_date"])
def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]:
return GenericResult(value=query_db(source, start_date, end_date))

Dependencies:
Use Dep() or DepOf to mark parameters that can accept CallableModel dependencies:

from ccflow import Dep, DepOf
from typing import Annotated

@Flow.model
def compute_returns(
context: DateRangeContext,
prices: Annotated[GenericResult[pl.DataFrame], Dep(
transform=lambda ctx: ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)})
)]
) -> GenericResult[pl.DataFrame]:
return GenericResult(value=prices.value.pct_change())

# Or use DepOf shorthand for no transform:
@Flow.model
def compute_stats(
context: DateRangeContext,
data: DepOf[..., GenericResult[pl.DataFrame]]
) -> GenericResult[pl.DataFrame]:
return GenericResult(value=data.value.describe())

Usage:
# Create model instances
loader = load_prices(source="prod_db")
returns = compute_returns(prices=loader)

# Execute
ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31))
result = returns(ctx)

Returns:
A factory function that creates CallableModel instances
"""
from .flow_model import flow_model

return flow_model(*args, **kwargs)


# *****************************************************************************
# Define "Evaluators" and associated types
Expand Down Expand Up @@ -555,7 +831,8 @@ def _context_validator(cls, values, handler, info):
def __call__(self) -> ResultType:
fn = getattr(self.model, self.fn)
if hasattr(fn, "__wrapped__"):
result = fn.__wrapped__(self.model, self.context)
# Call through _resolve_deps_and_call to handle DepOf field resolution
result = _resolve_deps_and_call(self.model, self.context, fn.__wrapped__)
# If it's a callable model, then we can validate the result
if self.options.get("validate_result", True):
if fn.__name__ == "__deps__":
Expand Down
Loading
Loading