Skip to content

Commit 089d6dc

Browse files
authored
Embed __tname__ in models, support using it in model_validate (#918)
Gel supports inheritance much more than pydantic does, and so when we have a link to a parent type, we need to be able to determine which of its potential children type to create. We do this by adding a `tname__` field to every generated class, renamed to `__tname__` via the alias system, and update `__gel_validate__` (which currently just directly calls `model_validate`) to discriminate on it. The hard part here is handling the case where there are linkprops, which are represented by per-link subclasses of ProxyModel. To deal with this, we dynamically generate subclasses of the link class. This mechanism also allows us to return correct types from the client for links with link properties that point to subtypes, which we previously weren't. The other potential option was to use pydantic's discriminated unions, either by directly generating them in the source or dynamically mocking them up when creating models. Getting this to play nicely with everything we do seemed pretty nasty. (The linkprop situation in particular scares me more.) The big potential advantage of doing it that way, instead of having our own custom callbacks, is that it might allow us to not lose track of the parameters of validation (like strict mode and whether we are in JSON mode). But it turns out we were already losing track of that due to various ways we were interposing ourselves in validation, so maybe it's fine. We can maybe fix this with contextvars later. (Though the json situation is rough.) Fixes #755.
1 parent 4b52b71 commit 089d6dc

File tree

14 files changed

+334
-45
lines changed

14 files changed

+334
-45
lines changed

gel/_internal/_codegen/_models/_pydantic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4716,6 +4716,13 @@ def write_object_type(
47164716
):
47174717
self.write(f'"""type {objtype.name}"""')
47184718
self.write()
4719+
literal = self.import_name("typing", "Literal")
4720+
field = self.import_name("pydantic", "Field")
4721+
self._write_model_attribute(
4722+
"tname__",
4723+
f'{literal}["{type_name}"] = {field}('
4724+
f'"{type_name}", alias="__tname__")',
4725+
)
47194726
pointers = _get_object_type_body(objtype)
47204727
if pointers:
47214728
localns = frozenset(ptr.name for ptr in pointers)

gel/_internal/_qbmodel/_abstract/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
DefaultValue,
1616
GelType,
1717
GelTypeMeta,
18+
LITERAL_TAG_FIELDS,
1819
)
1920

2021
from ._descriptors import (
@@ -110,6 +111,7 @@
110111

111112
__all__ = (
112113
"DEFAULT_VALUE",
114+
"LITERAL_TAG_FIELDS",
113115
"MODEL_SUBSTRATE_MODULE",
114116
"AbstractGelLinkModel",
115117
"AbstractGelModel",

gel/_internal/_qbmodel/_abstract/_base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
T_co = TypeVar("T_co", covariant=True)
3636

3737

38+
LITERAL_TAG_FIELDS = ('tname__',)
39+
40+
3841
if TYPE_CHECKING:
3942

4043
class GelTypeMeta(abc.ABCMeta):
@@ -162,8 +165,13 @@ def __new__(
162165
super().__new__(mcls, name, bases, namespace, **kwargs),
163166
)
164167
reflection = cls.__gel_reflection__
165-
if (tname := getattr(reflection, "name", None)) is not None:
166-
mcls.__gel_class_registry__[tname] = cls
168+
if (
169+
# The class registry only tracks the canonical base instances,
170+
# which are the ones that directly declare 'tname__'
171+
(tname := getattr(reflection, "name", None)) is not None
172+
and 'tname__' in namespace
173+
):
174+
mcls.__gel_class_registry__[str(tname)] = cls
167175
cls.__gel_shape__ = __gel_shape__
168176
return cls
169177

gel/_internal/_qbmodel/_abstract/_descriptors.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
AbstractGelLinkModel,
3737
is_gel_type,
3838
maybe_collapse_object_type_variant_union,
39+
LITERAL_TAG_FIELDS,
3940
)
4041

4142

@@ -116,7 +117,9 @@ def _try_resolve_type(self) -> Any:
116117
if (
117118
t is not None
118119
and _typing_inspect.is_generic_alias(t)
119-
and issubclass(typing.get_origin(t), PointerDescriptor)
120+
and (origin := typing.get_origin(t))
121+
and isinstance(origin, type)
122+
and issubclass(origin, PointerDescriptor)
120123
):
121124
self.__gel_resolved_descriptor__ = t
122125
t = typing.get_args(t)[0]
@@ -127,12 +130,15 @@ def _try_resolve_type(self) -> Any:
127130
if collapsed is not None:
128131
t = collapsed
129132

130-
if not is_gel_type(t):
133+
if (
134+
not is_gel_type(t)
135+
and self.__gel_name__ not in LITERAL_TAG_FIELDS
136+
):
131137
raise AssertionError(
132138
f"{self._fqname} type argument is not a GelType: {t}"
133139
)
134140

135-
self.__gel_resolved_type__ = t
141+
self.__gel_resolved_type__ = cast('type[GelType]', t)
136142

137143
return t
138144

@@ -606,6 +612,9 @@ def get(
606612

607613
class AbstractGelProxyModel(AbstractGelModel, Generic[_MT_co, _LM_co]):
608614
__linkprops__: GelLinkModelDescriptor[_LM_co]
615+
__gel_dynamic_proxy_base__: ClassVar[
616+
type[AbstractGelProxyModel[Any, Any]] | None
617+
] = None
609618

610619
if TYPE_CHECKING:
611620
_p__obj__: _MT_co
@@ -828,7 +837,7 @@ def proxy_link(
828837
) -> AbstractGelProxyModel[_MT_co, _LM_co]:
829838
tp_new = type(new)
830839

831-
if tp_new is proxy_type:
840+
if tp_new is proxy_type or issubclass(tp_new, proxy_type):
832841
# Fast path for the same proxy type.
833842

834843
new_proxy = cast("AbstractGelProxyModel[_MT_co, _LM_co]", new)

gel/_internal/_qbmodel/_pydantic/_fields.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def _validate(
436436
)
437437

438438
# defer to Pydantic
439-
return mt.model_validate(value)
439+
return mt.__gel_validate__(value) # type: ignore[return-value]
440440

441441

442442
class _AnyLinkWithProps(Generic[_PT_co, _BMT_co]):

gel/_internal/_qbmodel/_pydantic/_models.py

Lines changed: 119 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Self,
1919
)
2020

21+
import copyreg
2122
import inspect
2223
import itertools
2324
import types
@@ -395,7 +396,11 @@ def _process_pydantic_fields(
395396

396397
for fn, field in fields.items():
397398
ptr = cls.__gel_reflection__.pointers.get(fn)
398-
if (ptr is None or ptr.computed) and fn != "__linkprops__":
399+
if (
400+
(ptr is None or ptr.computed)
401+
and fn != "__linkprops__"
402+
and fn not in _abstract.LITERAL_TAG_FIELDS
403+
):
399404
# Regarding `fn != "__linkprops__"`: see MergedModelMeta --
400405
# it renames `linkprops____` to `__linkprops__` to circumvent
401406
# Pydantic's restriction on fields starting with `_`.
@@ -588,6 +593,8 @@ def _process_pydantic_fields(
588593
validate_assignment=True,
589594
defer_build=True,
590595
extra="forbid",
596+
serialize_by_alias=True,
597+
validate_by_name=False,
591598
)
592599

593600

@@ -659,13 +666,19 @@ def __gel_model_construct__(cls, __dict__: dict[str, Any] | None) -> Self:
659666
ll_setattr(self, "__pydantic_private__", None)
660667
ll_setattr(self, "__gel_changed_fields__", None)
661668

669+
ll_setattr(self, "tname__", str(cls.__gel_reflection__.name))
670+
662671
if cls.__gel_has_id_field__:
663672
mid = self.__dict__.get("id", _unset)
664673
assert mid is not UNSET_UUID
665674
ll_setattr(self, "__gel_new__", mid is _unset)
666675

667676
return self
668677

678+
@property
679+
def __tname__(self) -> str:
680+
return self.tname__ # type: ignore[no-any-return, attr-defined]
681+
669682
@classmethod
670683
def model_construct(
671684
cls,
@@ -890,7 +903,24 @@ def __delattr__(self, name: str) -> None:
890903

891904
@classmethod
892905
def __gel_validate__(cls, value: Any) -> GelSourceModel:
893-
return cls.model_validate(value)
906+
# __gel_validate__ is called when validating a gel type
907+
# reached by a link, and Gel links may point to subtypes of
908+
# the declared type. We dispatch on that ourselves, using a
909+
# __tname__ field in the dict.
910+
ccls = cls
911+
key = '__tname__'
912+
if isinstance(value, dict) and key in value:
913+
ncls: Any = cls.get_class_by_name(value[key])
914+
if ccls is not ncls:
915+
# For proxy models, we need an appropriate proxy.
916+
if issubclass(ccls, ProxyModel):
917+
ccls = ccls._get_subtype_proxy(ncls)
918+
else:
919+
ccls = ncls
920+
921+
res = ccls.model_validate(value)
922+
923+
return res
894924

895925

896926
class GelModel(
@@ -950,6 +980,7 @@ def __new__(
950980

951981
validator = cls.__pydantic_validator__
952982
for arg, value in kwargs.items():
983+
arg = _pydantic_utils.GEL_ALIASES.get(arg, arg) # noqa: PLW2901 # ignoring it in ruff.toml failed?
953984
# Faster than setattr()
954985
validator.validate_assignment(self, arg, value)
955986

@@ -1216,6 +1247,7 @@ def __build_custom_serializer(
12161247
"__get_pydantic_core_schema__": classmethod(
12171248
lambda cls, source_type, handler: handler(source_type)
12181249
),
1250+
"__gel_is_custom_serializer__": True,
12191251
},
12201252
)
12211253

@@ -1363,6 +1395,18 @@ def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
13631395
return pydantic.BaseModel.model_dump(self, *args, **kwargs)
13641396

13651397

1398+
def _pickle_dynamic_proxy_model(cls: Any) -> Any:
1399+
# See discussion in _typing_parametric. We do the same tricks as
1400+
# PickleableClassParametricType, basically.
1401+
if "__gel_dynamic_proxy_base__" in cls.__dict__:
1402+
base = cls.__dict__["__gel_dynamic_proxy_base__"]
1403+
return base._get_subtype_proxy, (cls.__proxy_of__,)
1404+
else:
1405+
# If it is not one of our dynamic things, we return the
1406+
# classname and pickle goes and does something useful.
1407+
return cls.__qualname__
1408+
1409+
13661410
class ProxyModel(
13671411
GelModel,
13681412
_abstract.AbstractGelProxyModel[_MT_co, GelLinkModel],
@@ -1373,6 +1417,12 @@ class ProxyModel(
13731417
if TYPE_CHECKING:
13741418
__gel_proxy_merged_model_cache__: ClassVar[type[_MergedModelBase]]
13751419

1420+
# A cache for dynamically generated proxies where the __proxy_of__
1421+
# class is a *subtype*. Generated by _get_subtype_proxy.
1422+
__gel_subtype_proxy_cache__: ClassVar[
1423+
dict[type[GelModel], type[ProxyModel[Any]]] | None
1424+
] = None
1425+
13761426
# NB: __linkprops__ is not in slots because it is managed by
13771427
# GelLinkModelDescriptor.
13781428

@@ -1406,9 +1456,56 @@ def __init__(
14061456
# We want ProxyModel to be a trasparent wrapper, so we
14071457
# forward the constructor arguments to the wrapped object.
14081458
wrapped = self.__proxy_of__(**kwargs)
1459+
assert not isinstance(wrapped, ProxyModel)
14091460
ll_setattr(self, "_p__obj__", wrapped)
14101461
# __linkprops__ is written into __dict__ by GelLinkModelDescriptor
14111462

1463+
def __init_subclass__(cls) -> None:
1464+
super().__init_subclass__()
1465+
1466+
# Register a custom reducer for every *metaclass*
1467+
# of ProxyModels. We need this to deal with our dynamic classes.
1468+
# Shouldn't be many.
1469+
copyreg.pickle(
1470+
type(cls),
1471+
_pickle_dynamic_proxy_model, # pyright: ignore [reportArgumentType]
1472+
)
1473+
1474+
@classmethod
1475+
def _get_subtype_proxy(cls, ncls: type[GelModel]) -> type[Self]:
1476+
"""Generate a subtype that has ncls as its __proxy_of__"""
1477+
1478+
if ncls is cls.__proxy_of__:
1479+
return cls
1480+
1481+
if cls.__gel_subtype_proxy_cache__ is None:
1482+
cls.__gel_subtype_proxy_cache__ = {}
1483+
1484+
if cached := cls.__gel_subtype_proxy_cache__.get(ncls):
1485+
return cast('type[Self]', cached)
1486+
1487+
core_schema = ProxyModel.__dict__['__get_pydantic_core_schema__']
1488+
new_proxy = type(cls)( # type: ignore[misc]
1489+
f'{cls.__name__}[{ncls.__name__}]',
1490+
(ncls, ProxyModel[ncls], cls), # type: ignore[valid-type]
1491+
{
1492+
k: cls.__dict__[k]
1493+
for k in (
1494+
'__module__',
1495+
'__qualname__',
1496+
)
1497+
if k in cls.__dict__
1498+
}
1499+
| {
1500+
# This gets overridden in some cases, and we need to restore it
1501+
'__get_pydantic_core_schema__': core_schema,
1502+
'__gel_dynamic_proxy_base__': cls,
1503+
},
1504+
)
1505+
new_proxy.model_rebuild()
1506+
cls.__gel_subtype_proxy_cache__[ncls] = new_proxy
1507+
return cast('type[Self]', new_proxy)
1508+
14121509
@classmethod
14131510
def link(cls, obj: _MT_co, /, **link_props: Any) -> Self: # type: ignore [misc]
14141511
proxy_of = ll_type_getattr(cls, "__proxy_of__")
@@ -1439,9 +1536,12 @@ def link(cls, obj: _MT_co, /, **link_props: Any) -> Self: # type: ignore [misc]
14391536
f"are allowed, got {type(obj).__name__}",
14401537
)
14411538

1442-
self = cls.__new__(cls)
1443-
lprops = cls.__linkprops__(**link_props)
1539+
ncls = cls._get_subtype_proxy(type(obj))
1540+
1541+
self = ncls.__new__(ncls)
1542+
lprops = ncls.__linkprops__(**link_props)
14441543
ll_setattr(self, "__linkprops__", lprops)
1544+
assert not isinstance(obj, ProxyModel)
14451545
ll_setattr(self, "_p__obj__", obj)
14461546

14471547
# Treat newly created link props as if they had all their values
@@ -1521,6 +1621,7 @@ def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
15211621
generic_meta = cls.__pydantic_generic_metadata__
15221622
if generic_meta["origin"] is ProxyModel and generic_meta["args"]:
15231623
cls.__proxy_of__ = generic_meta["args"][0]
1624+
assert issubclass(cls.__proxy_of__, _abstract.AbstractGelModel)
15241625

15251626
@classmethod
15261627
def __make_merged_model(cls) -> type[_MergedModelBase]:
@@ -1543,18 +1644,22 @@ def __make_merged_model(cls) -> type[_MergedModelBase]:
15431644
# _MergedModelBase has a custom metaclass, so we must
15441645
# create a common subclass of that and whatever the
15451646
# metaclass of the real type is.
1546-
metaclass = types.new_class(
1547-
f"{cls.__name__}Meta",
1548-
(_MergedModelMeta, type(cls.__proxy_of__)),
1549-
)
1647+
pmcls = type(cls.__proxy_of__)
1648+
if issubclass(pmcls, _MergedModelMeta):
1649+
metaclass = pmcls
1650+
else:
1651+
metaclass = types.new_class(
1652+
f"{cls.__name__}Meta",
1653+
(_MergedModelMeta, pmcls),
1654+
)
15501655

15511656
merged = cast(
15521657
"type[_MergedModelBase]",
15531658
pydantic.create_model(
15541659
cls.__name__,
15551660
__base__=(
1556-
_MergedModelBase,
15571661
cls.__proxy_of__,
1662+
_MergedModelBase,
15581663
), # inherit all wrapped fields
15591664
__config__=DEFAULT_MODEL_CONFIG,
15601665
__cls_kwargs__={"metaclass": metaclass},
@@ -1632,11 +1737,14 @@ def __gel_proxy_construct__(
16321737
*,
16331738
linked: bool = False,
16341739
) -> Self:
1635-
pnv = cls.__gel_model_construct__(None)
1740+
ncls = cls._get_subtype_proxy(type(obj))
1741+
1742+
pnv = ncls.__gel_model_construct__(None)
1743+
assert not isinstance(obj, ProxyModel)
16361744
ll_setattr(pnv, "_p__obj__", obj)
16371745

16381746
if type(lprops) is dict:
1639-
lp_obj = cls.__linkprops__.__gel_model_construct__(lprops)
1747+
lp_obj = ncls.__linkprops__.__gel_model_construct__(lprops)
16401748
else:
16411749
lp_obj = lprops # type: ignore [assignment]
16421750

gel/_internal/_qbmodel/_pydantic/_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@
2929
from ._models import GelModel
3030

3131

32+
# The real field name is tname__, and we need to manually apply the
33+
# alias in some places.
34+
GEL_ALIASES = {
35+
"__tname__": "tname__",
36+
}
37+
38+
3239
# TypeAliasType does not support recursive types. mypy errors out with:
3340
# error: Cannot resolve name "IncEx" (possible cyclic definition) [misc]
3441
#
@@ -201,6 +208,14 @@ def massage_model_dump_kwargs(
201208
except KeyError:
202209
exclude = kwargs["exclude"] = set()
203210

211+
# HACK: Handle exclude for internal gel aliases, since pydantic won't.b
212+
for from_, to_ in GEL_ALIASES.items():
213+
if from_ in exclude:
214+
if isinstance(exclude, dict):
215+
exclude[to_] = exclude[from_]
216+
elif isinstance(exclude, set):
217+
exclude.add(to_)
218+
204219
if model.__gel_has_id_field__ and model.__gel_new__:
205220
# You're generally not supposed to dump an unsaved model,
206221
# so this branch might be a bit slow -- we can potentially

0 commit comments

Comments
 (0)