Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 11 additions & 65 deletions ccflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@

import collections.abc
import copy
import inspect
import logging
import pathlib
import platform
import sys
import warnings
from types import GenericAlias, MappingProxyType
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar

from omegaconf import DictConfig
from packaging import version
from pydantic import (
BaseModel as PydanticBaseModel,
ConfigDict,
Expand Down Expand Up @@ -89,66 +87,7 @@ def get_registry_dependencies(self, types: Optional[Tuple["ModelType"]] = None)
return deps


# Pydantic 2 has different handling of serialization.
# This requires some workarounds at the moment until the feature is added to easily get a mode that
# is compatible with Pydantic 1
# This is done by adjusting annotations via a MetaClass for any annotation that includes a BaseModel,
# such that the new annotation contains SerializeAsAny
# https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
# https://github.com/pydantic/pydantic/issues/6423
# https://github.com/pydantic/pydantic-core/pull/740
# See https://github.com/pydantic/pydantic/issues/6381 for inspiration on implementation
# NOTE: For this logic to be removed, require https://github.com/pydantic/pydantic-core/pull/1478
from pydantic._internal._model_construction import ModelMetaclass # noqa: E402

_IS_PY39 = version.parse(platform.python_version()) < version.parse("3.10")


def _adjust_annotations(annotation):
    """Return ``annotation`` with pydantic model types wrapped in ``SerializeAsAny``.

    The annotation is walked recursively:

    * a ``GenericAlias`` instance or a ``PydanticBaseModel`` subclass is wrapped
      as ``SerializeAsAny[annotation]``;
    * ``type[...]`` annotations and subscripted ``Generic`` subclasses are
      returned untouched;
    * ``ClassVar[X]`` is rebuilt around the adjusted ``X``;
    * any other subscripted annotation is rebuilt from its origin with every
      argument adjusted;
    * everything else is returned unchanged.

    Raises:
        TypeError: if the origin cannot be re-subscripted with the adjusted
            arguments. The underlying error is chained as ``__cause__``.
    """
    origin = get_origin(annotation)
    args = get_args(annotation)
    if not _IS_PY39:
        # Imported lazily because this branch is only taken on Python 3.10+.
        from types import UnionType

        # Treat ``X | Y`` like ``typing.Union[X, Y]`` so it is rebuilt through Union below.
        if origin is UnionType:
            origin = Union

    if isinstance(annotation, GenericAlias) or (inspect.isclass(annotation) and issubclass(annotation, PydanticBaseModel)):
        return SerializeAsAny[annotation]

    # Nothing to recurse into: plain (unsubscripted) annotation.
    if not (origin and args):
        return annotation

    # Filter out typing.Type and generic types
    if origin is type or (inspect.isclass(origin) and issubclass(origin, Generic)):
        return annotation

    # ClassVar doesn't accept a tuple of length 1 in py39
    if origin is ClassVar:
        return ClassVar[_adjust_annotations(args[0])]

    try:
        return origin[tuple(_adjust_annotations(arg) for arg in args)]
    except TypeError as err:
        # Chain explicitly so the root cause stays attached to the new error.
        raise TypeError(f"Could not adjust annotations for {origin}") from err


class _SerializeAsAnyMeta(ModelMetaclass):
    """Metaclass that rewrites class annotations through ``_adjust_annotations``.

    The rewrite happens before the class is handed to ``ModelMetaclass``, so the
    model is built from the adjusted annotations.
    """

    def __new__(self, name: str, bases: Tuple[type], namespaces: Dict[str, Any], **kwargs):
        """Adjust the annotations in ``namespaces`` and create the class."""
        # Reuse the namespace's own dict when present so it is mutated in place.
        hints: dict = namespaces.get("__annotations__", {})

        # Merge in PydanticBaseModel's own annotations for every base that has it in its MRO.
        for parent in bases:
            if any(klass is PydanticBaseModel for klass in parent.__mro__):
                hints.update(PydanticBaseModel.__annotations__)

        # Only existing keys are reassigned, so the dict never changes size while iterating.
        for field_name, hint in hints.items():
            if field_name.startswith("__"):
                continue
            hints[field_name] = _adjust_annotations(hint)

        namespaces["__annotations__"] = hints
        return super().__new__(self, name, bases, namespaces, **kwargs)


class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta):
class BaseModel(PydanticBaseModel, _RegistryMixin):
"""BaseModel is a base class for all pydantic models within the cubist flow framework.

This gives us a way to add functionality to the framework, including
Expand Down Expand Up @@ -182,6 +121,13 @@ def type_(self) -> PyObjectPath:
ser_json_timedelta="float",
)

# https://docs.pydantic.dev/latest/concepts/serialization/#overriding-the-serialize_as_any-default-false
def model_dump(self, **kwargs) -> dict[str, Any]:
    """Dump the model to a dict with ``serialize_as_any`` enabled by default.

    Args:
        **kwargs: Forwarded unchanged to pydantic's ``model_dump``. An explicit
            ``serialize_as_any`` value supplied by the caller is respected.

    Returns:
        The dictionary produced by pydantic's ``model_dump``.
    """
    # Use setdefault instead of a hard-coded keyword: passing
    # ``serialize_as_any=True`` alongside ``**kwargs`` raised
    # "got multiple values for keyword argument" whenever a caller set it too.
    kwargs.setdefault("serialize_as_any", True)
    return super().model_dump(**kwargs)

def model_dump_json(self, **kwargs) -> str:
    """Dump the model to a JSON string with ``serialize_as_any`` enabled by default.

    Args:
        **kwargs: Forwarded unchanged to pydantic's ``model_dump_json``. An
            explicit ``serialize_as_any`` value supplied by the caller is respected.

    Returns:
        The JSON string produced by pydantic's ``model_dump_json``.
    """
    # Use setdefault instead of a hard-coded keyword: passing
    # ``serialize_as_any=True`` alongside ``**kwargs`` raised
    # "got multiple values for keyword argument" whenever a caller set it too.
    kwargs.setdefault("serialize_as_any", True)
    return super().model_dump_json(**kwargs)

def __str__(self):
    """Return the ``repr`` of the instance as its string form.

    The standard string representation does not include the class name, so
    the repr is reused instead.
    """
    rendered = repr(self)
    return rendered
Expand Down Expand Up @@ -251,7 +197,7 @@ def _base_model_validator(cls, v, handler, info):

if isinstance(v, PydanticBaseModel):
# Coerce from one BaseModel type to another (because it worked automatically in v1)
v = v.model_dump(exclude={"type_"})
v = v.model_dump(serialize_as_any=True, exclude={"type_"})

return handler(v)

Expand Down
43 changes: 1 addition & 42 deletions ccflow/tests/test_base_serialize.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import pickle
import platform
import unittest
from typing import Annotated, ClassVar, Dict, List, Optional, Type, Union
from typing import Annotated, Optional

import numpy as np
from packaging import version
from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, ValidationError

from ccflow import BaseModel, NDArray
Expand Down Expand Up @@ -213,45 +211,6 @@ class C(PydanticBaseModel):
# C implements the normal pydantic BaseModel which should allow extra fields.
_ = C(extra_field1=1)

def test_serialize_as_any(self):
    """Check that model annotations are rewritten to use ``SerializeAsAny``.

    Compares the string form of each annotation on a nested model with the
    expected, manually wrapped annotation.
    """
    # https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
    # https://github.com/pydantic/pydantic/issues/6423
    # This test could be removed once there is a different solution to the issue above
    from pydantic import SerializeAsAny
    from pydantic.types import constr

    # The ``A | int`` form is only evaluated on Python 3.10+.
    supports_pipe_union = version.parse(platform.python_version()) >= version.parse("3.10")
    pipe_union = (A | int) if supports_pipe_union else Union[A, int]

    class MyNestedModel(BaseModel):
        a1: A
        a2: Optional[Union[A, int]]
        a3: Dict[str, Optional[List[A]]]
        a4: ClassVar[A]
        a5: Type[A]
        a6: constr(min_length=1)
        a7: pipe_union

    expected = {
        "a1": SerializeAsAny[A],
        "a2": Optional[Union[SerializeAsAny[A], int]],
        "a3": dict[str, Optional[list[SerializeAsAny[A]]]],
        "a4": ClassVar[SerializeAsAny[A]],
        "a5": Type[A],
        "a6": constr(min_length=1),  # Uses Annotation
        "a7": Union[SerializeAsAny[A], int],
    }

    actual = MyNestedModel.__annotations__
    # Compare in field order, one assertion per field.
    for field_name in ("a1", "a2", "a3", "a4", "a5", "a6", "a7"):
        self.assertEqual(str(actual[field_name]), str(expected[field_name]))

def test_pickle_consistency(self):
model = MultiAttributeModel(z=1, y="test", x=3.14, w=True)
serialized = pickle.dumps(model)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies = [
"orjson",
"pandas",
"pyarrow",
"pydantic>=2.6,<3",
"pydantic>=2.35,<3",
"smart_open",
"tenacity",
]
Expand Down
Loading