diff --git a/ccflow/base.py b/ccflow/base.py index fe1fd82..30df80c 100644 --- a/ccflow/base.py +++ b/ccflow/base.py @@ -2,17 +2,14 @@ import collections.abc import copy -import inspect import logging import pathlib -import platform import sys import warnings -from types import GenericAlias, MappingProxyType -from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin +from types import MappingProxyType +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar from omegaconf import DictConfig -from packaging import version from pydantic import ( BaseModel as PydanticBaseModel, ConfigDict, @@ -88,66 +85,7 @@ def get_registry_dependencies(self, types: Optional[Tuple["ModelType"]] = None) return deps -# Pydantic 2 has different handling of serialization. -# This requires some workarounds at the moment until the feature is added to easily get a mode that -# is compatible with Pydantic 1 -# This is done by adjusting annotations via a MetaClass for any annotation that includes a BaseModel, -# such that the new annotation contains SerializeAsAny -# https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing -# https://github.com/pydantic/pydantic/issues/6423 -# https://github.com/pydantic/pydantic-core/pull/740 -# See https://github.com/pydantic/pydantic/issues/6381 for inspiration on implementation -# NOTE: For this logic to be removed, require https://github.com/pydantic/pydantic-core/pull/1478 -from pydantic._internal._model_construction import ModelMetaclass # noqa: E402 - -_IS_PY39 = version.parse(platform.python_version()) < version.parse("3.10") - - -def _adjust_annotations(annotation): - origin = get_origin(annotation) - args = get_args(annotation) - if not _IS_PY39: - from types import UnionType - - if origin is UnionType: - origin = Union - - if isinstance(annotation, GenericAlias) or (inspect.isclass(annotation) and issubclass(annotation, 
PydanticBaseModel)): - return SerializeAsAny[annotation] - elif origin and args: - # Filter out typing.Type and generic types - if origin is type or (inspect.isclass(origin) and issubclass(origin, Generic)): - return annotation - elif origin is ClassVar: # ClassVar doesn't accept a tuple of length 1 in py39 - return ClassVar[_adjust_annotations(args[0])] - else: - try: - return origin[tuple(_adjust_annotations(arg) for arg in args)] - except TypeError: - raise TypeError(f"Could not adjust annotations for {origin}") - else: - return annotation - - - class _SerializeAsAnyMeta(ModelMetaclass): - def __new__(self, name: str, bases: Tuple[type], namespaces: Dict[str, Any], **kwargs): - annotations: dict = namespaces.get("__annotations__", {}) - - for base in bases: - for base_ in base.__mro__: - if base_ is PydanticBaseModel: - annotations.update(base_.__annotations__) - - for field, annotation in annotations.items(): - if not field.startswith("__"): - annotations[field] = _adjust_annotations(annotation) - - namespaces["__annotations__"] = annotations - - return super().__new__(self, name, bases, namespaces, **kwargs) - - -class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta): +class BaseModel(PydanticBaseModel, _RegistryMixin): """BaseModel is a base class for all pydantic models within the cubist flow framework. This gives us a way to add functionality to the framework, including @@ -179,6 +117,8 @@ def type_(self) -> PyObjectPath: # where the default behavior is just to drop the mis-named value. This prevents that extra="forbid", ser_json_timedelta="float", + # Polymorphic serialization is the behavior of allowing a subclass of a model (or Pydantic dataclass) to override serialization so that the subclass' serialization is used, rather than the original model type's serialization. This will expose all the data defined on the subclass in the serialized payload. 
+ polymorphic_serialization=True, ) def __str__(self): diff --git a/ccflow/callable.py b/ccflow/callable.py index b09eaea..6183148 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -17,7 +17,16 @@ from inspect import Signature, isclass, signature from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin -from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator +from pydantic import ( + BaseModel as PydanticBaseModel, + ConfigDict, + Field, + InstanceOf, + PrivateAttr, + TypeAdapter, + field_validator, + model_validator, +) from typing_extensions import override from .base import ( diff --git a/ccflow/tests/test_base_serialize.py b/ccflow/tests/test_base_serialize.py index fbec5c2..69cc0c8 100644 --- a/ccflow/tests/test_base_serialize.py +++ b/ccflow/tests/test_base_serialize.py @@ -1,10 +1,8 @@ import pickle -import platform import unittest -from typing import Annotated, ClassVar, Dict, List, Optional, Type, Union +from typing import Annotated, Optional import numpy as np -from packaging import version from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, ValidationError from ccflow import BaseModel, NDArray @@ -213,45 +211,6 @@ class C(PydanticBaseModel): # C implements the normal pydantic BaseModel whichhould allow extra fields. 
_ = C(extra_field1=1) - def test_serialize_as_any(self): - # https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing - # https://github.com/pydantic/pydantic/issues/6423 - # This test could be removed once there is a different solution to the issue above - from pydantic import SerializeAsAny - from pydantic.types import constr - - if version.parse(platform.python_version()) >= version.parse("3.10"): - pipe_union = A | int - else: - pipe_union = Union[A, int] - - class MyNestedModel(BaseModel): - a1: A - a2: Optional[Union[A, int]] - a3: Dict[str, Optional[List[A]]] - a4: ClassVar[A] - a5: Type[A] - a6: constr(min_length=1) - a7: pipe_union - - target = { - "a1": SerializeAsAny[A], - "a2": Optional[Union[SerializeAsAny[A], int]], - "a4": ClassVar[SerializeAsAny[A]], - "a5": Type[A], - "a6": constr(min_length=1), # Uses Annotation - "a7": Union[SerializeAsAny[A], int], - } - target["a3"] = dict[str, Optional[list[SerializeAsAny[A]]]] - annotations = MyNestedModel.__annotations__ - self.assertEqual(str(annotations["a1"]), str(target["a1"])) - self.assertEqual(str(annotations["a2"]), str(target["a2"])) - self.assertEqual(str(annotations["a3"]), str(target["a3"])) - self.assertEqual(str(annotations["a4"]), str(target["a4"])) - self.assertEqual(str(annotations["a5"]), str(target["a5"])) - self.assertEqual(str(annotations["a6"]), str(target["a6"])) - self.assertEqual(str(annotations["a7"]), str(target["a7"])) - def test_pickle_consistency(self): model = MultiAttributeModel(z=1, y="test", x=3.14, w=True) serialized = pickle.dumps(model) diff --git a/ccflow/tests/test_evaluation_context_serialization.py b/ccflow/tests/test_evaluation_context_serialization.py new file mode 100644 index 0000000..cc6abb6 --- /dev/null +++ b/ccflow/tests/test_evaluation_context_serialization.py @@ -0,0 +1,79 @@ +import json +from datetime import date + +from ccflow import DateContext +from ccflow.callable import ModelEvaluationContext +from ccflow.evaluators import 
GraphEvaluator, LoggingEvaluator, MultiEvaluator +from ccflow.tests.evaluators.util import NodeModel + + +def _make_nested_mec(model): + ctx = DateContext(date=date(2022, 1, 1)) + mec = model.__call__.get_evaluation_context(model, ctx) + assert isinstance(mec, ModelEvaluationContext) + # ensure nested: outer model is an evaluator, inner is a ModelEvaluationContext + assert isinstance(mec.context, ModelEvaluationContext) + return mec + + +def test_mec_model_dump_basic(): + m = NodeModel() + mec = _make_nested_mec(m) + + d = mec.model_dump() + assert isinstance(d, dict) + assert "fn" in d and "model" in d and "context" in d and "options" in d + + s = mec.model_dump_json() + parsed = json.loads(s) + assert parsed["fn"] == d["fn"] + # Also verify mode-specific dumps + d_py = mec.model_dump(mode="python") + assert isinstance(d_py, dict) + d_json = mec.model_dump(mode="json") + assert isinstance(d_json, dict) + json.dumps(d_json) + + +def test_mec_model_dump_diamond_graph(): + n0 = NodeModel() + n1 = NodeModel(deps_model=[n0]) + n2 = NodeModel(deps_model=[n0]) + root = NodeModel(deps_model=[n1, n2]) + + mec = _make_nested_mec(root) + + d = mec.model_dump() + assert isinstance(d, dict) + assert set(["fn", "model", "context", "options"]).issubset(d.keys()) + + s = mec.model_dump_json() + json.loads(s) + # verify mode dumps + d_py = mec.model_dump(mode="python") + assert isinstance(d_py, dict) + d_json = mec.model_dump(mode="json") + assert isinstance(d_json, dict) + json.dumps(d_json) + + +def test_mec_model_dump_with_multi_evaluator(): + m = NodeModel() + _ = LoggingEvaluator() # ensure import/validation + evaluator = MultiEvaluator(evaluators=[LoggingEvaluator(), GraphEvaluator()]) + + # Simulate how Flow builds evaluation context with a custom evaluator + ctx = DateContext(date=date(2022, 1, 1)) + mec = ModelEvaluationContext(model=evaluator, context=m.__call__.get_evaluation_context(m, ctx)) + + d = mec.model_dump() + assert isinstance(d, dict) + assert "fn" in d and 
"model" in d and "context" in d + s = mec.model_dump_json() + json.loads(s) + # verify mode dumps + d_py = mec.model_dump(mode="python") + assert isinstance(d_py, dict) + d_json = mec.model_dump(mode="json") + assert isinstance(d_json, dict) + json.dumps(d_json) diff --git a/pyproject.toml b/pyproject.toml index 6b7f712..5233ff8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ "orjson", "pandas", "pyarrow", - "pydantic>=2.6,<3", + "pydantic>=2.13,<3", "smart_open", "tenacity", ]