|  | 
|  | 1 | +use std::{borrow::Cow, sync::OnceLock}; | 
|  | 2 | + | 
|  | 3 | +use pyo3::{ | 
|  | 4 | +    intern, | 
|  | 5 | +    types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType}, | 
|  | 6 | +    Bound, Py, PyAny, PyObject, PyResult, Python, | 
|  | 7 | +}; | 
|  | 8 | + | 
|  | 9 | +use crate::{ | 
|  | 10 | +    definitions::DefinitionsBuilder, | 
|  | 11 | +    serializers::{ | 
|  | 12 | +        shared::{BuildSerializer, TypeSerializer}, | 
|  | 13 | +        CombinedSerializer, Extra, | 
|  | 14 | +    }, | 
|  | 15 | +    SchemaSerializer, | 
|  | 16 | +}; | 
|  | 17 | + | 
|  | 18 | +#[derive(Debug)] | 
|  | 19 | +pub struct NestedModelSerializer { | 
|  | 20 | +    model: Py<PyType>, | 
|  | 21 | +    name: String, | 
|  | 22 | +    get_serializer: Py<PyAny>, | 
|  | 23 | +    serializer: OnceLock<PyResult<Py<SchemaSerializer>>>, | 
|  | 24 | +} | 
|  | 25 | + | 
|  | 26 | +impl_py_gc_traverse!(NestedModelSerializer { | 
|  | 27 | +    model, | 
|  | 28 | +    get_serializer, | 
|  | 29 | +    serializer | 
|  | 30 | +}); | 
|  | 31 | + | 
|  | 32 | +impl BuildSerializer for NestedModelSerializer { | 
|  | 33 | +    const EXPECTED_TYPE: &'static str = "nested-model"; | 
|  | 34 | + | 
|  | 35 | +    fn build( | 
|  | 36 | +        schema: &Bound<'_, PyDict>, | 
|  | 37 | +        _config: Option<&Bound<'_, PyDict>>, | 
|  | 38 | +        _definitions: &mut DefinitionsBuilder<CombinedSerializer>, | 
|  | 39 | +    ) -> PyResult<CombinedSerializer> { | 
|  | 40 | +        let py = schema.py(); | 
|  | 41 | + | 
|  | 42 | +        let get_serializer = schema | 
|  | 43 | +            .get_item(intern!(py, "get_info"))? | 
|  | 44 | +            .expect("Invalid core schema for `nested-model` type, no `get_info`") | 
|  | 45 | +            .unbind(); | 
|  | 46 | + | 
|  | 47 | +        let model = schema | 
|  | 48 | +            .get_item(intern!(py, "model"))? | 
|  | 49 | +            .expect("Invalid core schema for `nested-model` type, no `model`") | 
|  | 50 | +            .downcast::<PyType>() | 
|  | 51 | +            .expect("Invalid core schema for `nested-model` type, not a `PyType`") | 
|  | 52 | +            .clone(); | 
|  | 53 | + | 
|  | 54 | +        let name = model.getattr(intern!(py, "__name__"))?.extract()?; | 
|  | 55 | + | 
|  | 56 | +        Ok(CombinedSerializer::NestedModel(NestedModelSerializer { | 
|  | 57 | +            model: model.clone().unbind(), | 
|  | 58 | +            name, | 
|  | 59 | +            get_serializer, | 
|  | 60 | +            serializer: OnceLock::new(), | 
|  | 61 | +        })) | 
|  | 62 | +    } | 
|  | 63 | +} | 
|  | 64 | + | 
|  | 65 | +impl NestedModelSerializer { | 
|  | 66 | +    fn nested_serializer<'py>(&self, py: Python<'py>) -> PyResult<&Py<SchemaSerializer>> { | 
|  | 67 | +        self.serializer | 
|  | 68 | +            .get_or_init(|| { | 
|  | 69 | +                Ok(self | 
|  | 70 | +                    .get_serializer | 
|  | 71 | +                    .bind(py) | 
|  | 72 | +                    .call((), None)? | 
|  | 73 | +                    .downcast::<PyTuple>()? | 
|  | 74 | +                    .get_item(2)? | 
|  | 75 | +                    .downcast::<SchemaSerializer>()? | 
|  | 76 | +                    .clone() | 
|  | 77 | +                    .unbind()) | 
|  | 78 | +            }) | 
|  | 79 | +            .as_ref() | 
|  | 80 | +            .map_err(|e| e.clone_ref(py)) | 
|  | 81 | +    } | 
|  | 82 | +} | 
|  | 83 | + | 
|  | 84 | +impl TypeSerializer for NestedModelSerializer { | 
|  | 85 | +    fn to_python( | 
|  | 86 | +        &self, | 
|  | 87 | +        value: &Bound<'_, PyAny>, | 
|  | 88 | +        include: Option<&Bound<'_, PyAny>>, | 
|  | 89 | +        exclude: Option<&Bound<'_, PyAny>>, | 
|  | 90 | +        mut extra: &Extra, | 
|  | 91 | +    ) -> PyResult<PyObject> { | 
|  | 92 | +        let mut guard = extra.recursion_guard(value, self.model.as_ptr() as usize)?; | 
|  | 93 | + | 
|  | 94 | +        self.nested_serializer(value.py())? | 
|  | 95 | +            .bind(value.py()) | 
|  | 96 | +            .get() | 
|  | 97 | +            .serializer | 
|  | 98 | +            .to_python(value, include, exclude, guard.state()) | 
|  | 99 | +    } | 
|  | 100 | + | 
|  | 101 | +    fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> { | 
|  | 102 | +        self.nested_serializer(key.py())? | 
|  | 103 | +            .bind(key.py()) | 
|  | 104 | +            .get() | 
|  | 105 | +            .serializer | 
|  | 106 | +            .json_key(key, extra) | 
|  | 107 | +    } | 
|  | 108 | + | 
|  | 109 | +    fn serde_serialize<S: serde::ser::Serializer>( | 
|  | 110 | +        &self, | 
|  | 111 | +        value: &Bound<'_, PyAny>, | 
|  | 112 | +        serializer: S, | 
|  | 113 | +        include: Option<&Bound<'_, PyAny>>, | 
|  | 114 | +        exclude: Option<&Bound<'_, PyAny>>, | 
|  | 115 | +        mut extra: &Extra, | 
|  | 116 | +    ) -> Result<S::Ok, S::Error> { | 
|  | 117 | +        use super::py_err_se_err; | 
|  | 118 | + | 
|  | 119 | +        let mut guard = extra | 
|  | 120 | +            .recursion_guard(value, self.model.as_ptr() as usize) | 
|  | 121 | +            .map_err(py_err_se_err)?; | 
|  | 122 | + | 
|  | 123 | +        self.nested_serializer(value.py()) | 
|  | 124 | +            // FIXME(BoxyUwU): Don't unwrap this | 
|  | 125 | +            .unwrap() | 
|  | 126 | +            .bind(value.py()) | 
|  | 127 | +            .get() | 
|  | 128 | +            .serializer | 
|  | 129 | +            .serde_serialize(value, serializer, include, exclude, guard.state()) | 
|  | 130 | +    } | 
|  | 131 | + | 
|  | 132 | +    fn get_name(&self) -> &str { | 
|  | 133 | +        &self.name | 
|  | 134 | +    } | 
|  | 135 | +} | 
0 commit comments