diff --git a/ccflow/base.py b/ccflow/base.py
index 8724468..b7be600 100644
--- a/ccflow/base.py
+++ b/ccflow/base.py
@@ -2,17 +2,15 @@
 
 import collections.abc
 import copy
-import inspect
 import logging
 import pathlib
 import platform
 import sys
 import warnings
-from types import GenericAlias, MappingProxyType
-from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin
+from types import MappingProxyType
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar
 
 from omegaconf import DictConfig
-from packaging import version
 from pydantic import (
     BaseModel as PydanticBaseModel,
     ConfigDict,
@@ -89,66 +87,7 @@ def get_registry_dependencies(self, types: Optional[Tuple["ModelType"]] = None)
         return deps
 
 
-# Pydantic 2 has different handling of serialization.
-# This requires some workarounds at the moment until the feature is added to easily get a mode that
-# is compatible with Pydantic 1
-# This is done by adjusting annotations via a MetaClass for any annotation that includes a BaseModel,
-# such that the new annotation contains SerializeAsAny
-# https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
-# https://github.com/pydantic/pydantic/issues/6423
-# https://github.com/pydantic/pydantic-core/pull/740
-# See https://github.com/pydantic/pydantic/issues/6381 for inspiration on implementation
-# NOTE: For this logic to be removed, require https://github.com/pydantic/pydantic-core/pull/1478
-from pydantic._internal._model_construction import ModelMetaclass  # noqa: E402
-
-_IS_PY39 = version.parse(platform.python_version()) < version.parse("3.10")
-
-
-def _adjust_annotations(annotation):
-    origin = get_origin(annotation)
-    args = get_args(annotation)
-    if not _IS_PY39:
-        from types import UnionType
-
-        if origin is UnionType:
-            origin = Union
-
-    if isinstance(annotation, GenericAlias) or (inspect.isclass(annotation) and issubclass(annotation, PydanticBaseModel)):
-        return SerializeAsAny[annotation]
-    elif origin and args:
-        # Filter out typing.Type and generic types
-        if origin is type or (inspect.isclass(origin) and issubclass(origin, Generic)):
-            return annotation
-        elif origin is ClassVar:  # ClassVar doesn't accept a tuple of length 1 in py39
-            return ClassVar[_adjust_annotations(args[0])]
-        else:
-            try:
-                return origin[tuple(_adjust_annotations(arg) for arg in args)]
-            except TypeError:
-                raise TypeError(f"Could not adjust annotations for {origin}")
-    else:
-        return annotation
-
-
-class _SerializeAsAnyMeta(ModelMetaclass):
-    def __new__(self, name: str, bases: Tuple[type], namespaces: Dict[str, Any], **kwargs):
-        annotations: dict = namespaces.get("__annotations__", {})
-
-        for base in bases:
-            for base_ in base.__mro__:
-                if base_ is PydanticBaseModel:
-                    annotations.update(base_.__annotations__)
-
-        for field, annotation in annotations.items():
-            if not field.startswith("__"):
-                annotations[field] = _adjust_annotations(annotation)
-
-        namespaces["__annotations__"] = annotations
-
-        return super().__new__(self, name, bases, namespaces, **kwargs)
-
-
-class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta):
+class BaseModel(PydanticBaseModel, _RegistryMixin):
     """BaseModel is a base class for all pydantic models within the cubist flow framework.
 
     This gives us a way to add functionality to the framework, including
@@ -182,6 +121,15 @@ def type_(self) -> PyObjectPath:
         ser_json_timedelta="float",
     )
 
+    # https://docs.pydantic.dev/latest/concepts/serialization/#overriding-the-serialize_as_any-default-false
+    # serialize_as_any defaults to True but is merged into kwargs (not passed as a second keyword), so callers
+    # that set it explicitly -- e.g. _base_model_validator below -- do not hit a duplicate-keyword TypeError.
+    def model_dump(self, **kwargs) -> dict[str, Any]:
+        return super().model_dump(**{"serialize_as_any": True, **kwargs})
+
+    def model_dump_json(self, **kwargs) -> str:
+        return super().model_dump_json(**{"serialize_as_any": True, **kwargs})
+
     def __str__(self):
         # Because the standard string representation does not include class name
         return repr(self)
@@ -251,7 +199,7 @@ def _base_model_validator(cls, v, handler, info):
 
         if isinstance(v, PydanticBaseModel):
             # Coerce from one BaseModel type to another (because it worked automatically in v1)
-            v = v.model_dump(exclude={"type_"})
+            v = v.model_dump(serialize_as_any=True, exclude={"type_"})
 
         return handler(v)
 
diff --git a/ccflow/tests/test_base_serialize.py b/ccflow/tests/test_base_serialize.py
index fbec5c2..69cc0c8 100644
--- a/ccflow/tests/test_base_serialize.py
+++ b/ccflow/tests/test_base_serialize.py
@@ -1,10 +1,8 @@
 import pickle
-import platform
 import unittest
-from typing import Annotated, ClassVar, Dict, List, Optional, Type, Union
+from typing import Annotated, Optional
 
 import numpy as np
-from packaging import version
 from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, ValidationError
 
 from ccflow import BaseModel, NDArray
@@ -213,45 +211,6 @@ class C(PydanticBaseModel):
         # C implements the normal pydantic BaseModel whichhould allow extra fields.
         _ = C(extra_field1=1)
 
-    def test_serialize_as_any(self):
-        # https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
-        # https://github.com/pydantic/pydantic/issues/6423
-        # This test could be removed once there is a different solution to the issue above
-        from pydantic import SerializeAsAny
-        from pydantic.types import constr
-
-        if version.parse(platform.python_version()) >= version.parse("3.10"):
-            pipe_union = A | int
-        else:
-            pipe_union = Union[A, int]
-
-        class MyNestedModel(BaseModel):
-            a1: A
-            a2: Optional[Union[A, int]]
-            a3: Dict[str, Optional[List[A]]]
-            a4: ClassVar[A]
-            a5: Type[A]
-            a6: constr(min_length=1)
-            a7: pipe_union
-
-        target = {
-            "a1": SerializeAsAny[A],
-            "a2": Optional[Union[SerializeAsAny[A], int]],
-            "a4": ClassVar[SerializeAsAny[A]],
-            "a5": Type[A],
-            "a6": constr(min_length=1),  # Uses Annotation
-            "a7": Union[SerializeAsAny[A], int],
-        }
-        target["a3"] = dict[str, Optional[list[SerializeAsAny[A]]]]
-        annotations = MyNestedModel.__annotations__
-        self.assertEqual(str(annotations["a1"]), str(target["a1"]))
-        self.assertEqual(str(annotations["a2"]), str(target["a2"]))
-        self.assertEqual(str(annotations["a3"]), str(target["a3"]))
-        self.assertEqual(str(annotations["a4"]), str(target["a4"]))
-        self.assertEqual(str(annotations["a5"]), str(target["a5"]))
-        self.assertEqual(str(annotations["a6"]), str(target["a6"]))
-        self.assertEqual(str(annotations["a7"]), str(target["a7"]))
-
     def test_pickle_consistency(self):
         model = MultiAttributeModel(z=1, y="test", x=3.14, w=True)
         serialized = pickle.dumps(model)
diff --git a/pyproject.toml b/pyproject.toml
index e3e8503..9310dd2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,7 +44,7 @@ dependencies = [
     "orjson",
     "pandas",
     "pyarrow",
-    "pydantic>=2.6,<3",
+    "pydantic>=2.35,<3",
     "smart_open",
     "tenacity",
 ]