Skip to content

Commit c17abdd

Browse files
Adding definitions support (#406)
* adding definitions support * fix rust benchmarks * rename "recursive" -> "definition" * in-line some validators * error positions for definitions errors * inlining, take 3 * cleanup * correctly inline serializers * switch to definitions as a core schema type * errors on repeated refs, more tests * Fix benches * Update src/build_context.rs * Add metadata and serialization to definitions schema --------- Co-authored-by: David Montague <[email protected]>
1 parent 8d0678c commit c17abdd

24 files changed

+723
-220
lines changed

generate_self_schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
core_schema_spec.loader.exec_module(core_schema)
4141

4242
# the validator for referencing schema (Schema is used recursively, so has to use a reference)
43-
schema_ref_validator = {'type': 'recursive-ref', 'schema_ref': 'root-schema'}
43+
schema_ref_validator = {'type': 'definition-ref', 'schema_ref': 'root-schema'}
4444

4545

4646
def get_schema(obj) -> core_schema.CoreSchema:
@@ -151,7 +151,7 @@ def type_dict_schema(typed_dict) -> dict[str, Any]: # noqa: C901
151151
schema = get_schema(field_type)
152152
if fr_arg == 'SerSchema':
153153
if defined_ser_schema:
154-
schema = {'type': 'recursive-ref', 'schema_ref': 'ser-schema'}
154+
schema = {'type': 'definition-ref', 'schema_ref': 'ser-schema'}
155155
else:
156156
defined_ser_schema = True
157157
schema = tagged_union(schema, 'type', 'ser-schema')

pydantic_core/core_schema.py

Lines changed: 64 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,35 +2318,6 @@ def call_schema(
23182318
)
23192319

23202320

2321-
class RecursiveReferenceSchema(TypedDict, total=False):
2322-
type: Required[Literal['recursive-ref']]
2323-
schema_ref: Required[str]
2324-
metadata: Any
2325-
serialization: SerSchema
2326-
2327-
2328-
def recursive_reference_schema(
2329-
schema_ref: str, metadata: Any = None, serialization: SerSchema | None = None
2330-
) -> RecursiveReferenceSchema:
2331-
"""
2332-
Returns a schema that matches a recursive reference value, e.g.:
2333-
2334-
```py
2335-
from pydantic_core import SchemaValidator, core_schema
2336-
schema_recursive = core_schema.recursive_reference_schema('list-schema')
2337-
schema = core_schema.list_schema(items_schema=schema_recursive, ref='list-schema')
2338-
v = SchemaValidator(schema)
2339-
assert v.validate_python([[]]) == [[]]
2340-
```
2341-
2342-
Args:
2343-
schema_ref: The schema ref to use for the recursive reference schema
2344-
metadata: See [TODO] for details
2345-
serialization: Custom serialization schema
2346-
"""
2347-
return dict_not_none(type='recursive-ref', schema_ref=schema_ref, metadata=metadata, serialization=serialization)
2348-
2349-
23502321
class CustomErrorSchema(TypedDict, total=False):
23512322
type: Required[Literal['custom-error']]
23522323
schema: Required[CoreSchema]
@@ -2579,6 +2550,66 @@ def multi_host_url_schema(
25792550
)
25802551

25812552

2553+
class DefinitionsSchema(TypedDict, total=False):
2554+
type: Required[Literal['definitions']]
2555+
schema: Required[CoreSchema]
2556+
definitions: Required[List[CoreSchema]]
2557+
metadata: Any
2558+
serialization: SerSchema
2559+
2560+
2561+
def definitions_schema(schema: CoreSchema, definitions: list[CoreSchema]) -> DefinitionsSchema:
2562+
"""
2563+
Build a schema that contains both an inner schema and a list of definitions which can be used
2564+
within the inner schema.
2565+
2566+
```py
2567+
from pydantic_core import SchemaValidator, core_schema
2568+
schema = core_schema.definitions_schema(
2569+
core_schema.list_schema(core_schema.definition_reference_schema('foobar')),
2570+
[core_schema.int_schema(ref='foobar')],
2571+
)
2572+
v = SchemaValidator(schema)
2573+
assert v.validate_python([1, 2, '3']) == [1, 2, 3]
2574+
```
2575+
2576+
Args:
2577+
schema: The inner schema
2578+
definitions: List of definitions which can be referenced within inner schema
2579+
"""
2580+
return DefinitionsSchema(type='definitions', schema=schema, definitions=definitions)
2581+
2582+
2583+
class DefinitionReferenceSchema(TypedDict, total=False):
2584+
type: Required[Literal['definition-ref']]
2585+
schema_ref: Required[str]
2586+
metadata: Any
2587+
serialization: SerSchema
2588+
2589+
2590+
def definition_reference_schema(
2591+
schema_ref: str, metadata: Any = None, serialization: SerSchema | None = None
2592+
) -> DefinitionReferenceSchema:
2593+
"""
2594+
Returns a schema that points to a schema stored in "definitions", this is useful for nested recursive
2595+
models and also when you want to define validators separately from the main schema, e.g.:
2596+
2597+
```py
2598+
from pydantic_core import SchemaValidator, core_schema
2599+
schema_definition = core_schema.definition_reference_schema('list-schema')
2600+
schema = core_schema.list_schema(items_schema=schema_definition, ref='list-schema')
2601+
v = SchemaValidator(schema)
2602+
assert v.validate_python([[]]) == [[]]
2603+
```
2604+
2605+
Args:
2606+
schema_ref: The schema ref to use for the definition reference schema
2607+
metadata: See [TODO] for details
2608+
serialization: Custom serialization schema
2609+
"""
2610+
return dict_not_none(type='definition-ref', schema_ref=schema_ref, metadata=metadata, serialization=serialization)
2611+
2612+
25822613
CoreSchema = Union[
25832614
AnySchema,
25842615
NoneSchema,
@@ -2615,11 +2646,12 @@ def multi_host_url_schema(
26152646
ModelSchema,
26162647
ArgumentsSchema,
26172648
CallSchema,
2618-
RecursiveReferenceSchema,
26192649
CustomErrorSchema,
26202650
JsonSchema,
26212651
UrlSchema,
26222652
MultiHostUrlSchema,
2653+
DefinitionsSchema,
2654+
DefinitionReferenceSchema,
26232655
]
26242656

26252657
# to update this, call `pytest -k test_core_schema_type_literal` and copy the output
@@ -2656,11 +2688,12 @@ def multi_host_url_schema(
26562688
'model',
26572689
'arguments',
26582690
'call',
2659-
'recursive-ref',
26602691
'custom-error',
26612692
'json',
26622693
'url',
26632694
'multi-host-url',
2695+
'definitions',
2696+
'definition-ref',
26642697
]
26652698

26662699

src/build_context.rs

Lines changed: 86 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,64 +2,84 @@ use pyo3::intern;
22
use pyo3::prelude::*;
33
use pyo3::types::{PyDict, PyList};
44

5-
use ahash::AHashSet;
5+
use ahash::{AHashMap, AHashSet};
66

77
use crate::build_tools::{py_err, py_error_type, SchemaDict};
88
use crate::questions::Answers;
99
use crate::serializers::CombinedSerializer;
1010
use crate::validators::{CombinedValidator, Validator};
1111

12-
#[derive(Clone)]
12+
#[derive(Clone, Debug)]
1313
struct Slot<T> {
1414
slot_ref: String,
1515
op_val_ser: Option<T>,
1616
answers: Option<Answers>,
1717
}
1818

19-
/// `BuildContext` is used to store extra information while building validators and type_serializers,
20-
/// currently it just holds a vec "slots" which holds validators/type_serializers which need to be accessed from
21-
/// multiple other validators/type_serializers and therefore can't be owned by them directly.
22-
#[derive(Clone)]
19+
pub enum ThingOrId<T> {
20+
Thing(T),
21+
Id(usize),
22+
}
23+
24+
/// `BuildContext` is used to store extra information while building validators and type_serializers
25+
#[derive(Clone, Debug)]
2326
pub struct BuildContext<T> {
27+
/// set of used refs, useful to see if a `ref` is actually used elsewhere in the schema
2428
used_refs: AHashSet<String>,
29+
/// holds validators/type_serializers which reference themselves and therefore can't be cloned and owned
30+
/// in one or multiple places.
2531
slots: Vec<Slot<T>>,
32+
/// holds validators/type_serializers which need to be accessed from multiple other validators/type_serializers
33+
/// and therefore can't be owned by them directly.
34+
reusable: AHashMap<String, T>,
2635
}
2736

28-
impl<T: Clone> BuildContext<T> {
29-
pub fn new(used_refs: AHashSet<String>) -> Self {
30-
Self {
31-
used_refs,
32-
slots: Vec::new(),
33-
}
34-
}
35-
36-
pub fn for_schema(schema: &PyAny) -> PyResult<Self> {
37+
impl<T: Clone + std::fmt::Debug> BuildContext<T> {
38+
pub fn new(schema: &PyAny) -> PyResult<Self> {
3739
let mut used_refs = AHashSet::new();
3840
extract_used_refs(schema, &mut used_refs)?;
3941
Ok(Self {
4042
used_refs,
4143
slots: Vec::new(),
44+
reusable: AHashMap::new(),
4245
})
4346
}
4447

4548
pub fn for_self_schema() -> Self {
46-
let mut used_refs = AHashSet::new();
49+
let mut used_refs = AHashSet::with_capacity(3);
4750
// NOTE: we don't call `extract_used_refs` for performance reasons, if more recursive references
4851
// are used, they would need to be manually added here.
52+
// we use `2` as count to avoid `find_slot` pulling the validator out of slots and returning it directly
4953
used_refs.insert("root-schema".to_string());
5054
used_refs.insert("ser-schema".to_string());
5155
used_refs.insert("inc-ex-type".to_string());
5256
Self {
5357
used_refs,
5458
slots: Vec::new(),
59+
reusable: AHashMap::new(),
5560
}
5661
}
5762

63+
/// Check whether a ref is already in `reusable` or `slots`, we shouldn't allow repeated refs
64+
pub fn ref_already_used(&self, ref_: &str) -> bool {
65+
self.reusable.contains_key(ref_) || self.slots.iter().any(|slot| slot.slot_ref == ref_)
66+
}
67+
5868
/// check if a ref is used elsewhere in the schema
5969
pub fn ref_used(&self, ref_: &str) -> bool {
6070
self.used_refs.contains(ref_)
6171
}
6272

73+
/// check if a ref is used within a given schema
74+
pub fn ref_used_within(&self, schema_dict: &PyAny, ref_: &str) -> PyResult<bool> {
75+
check_ref_used(schema_dict, ref_)
76+
}
77+
78+
/// add a validator/serializer to `reusable` so it can be cloned and used again elsewhere
79+
pub fn store_reusable(&mut self, ref_: String, val_ser: T) {
80+
self.reusable.insert(ref_, val_ser);
81+
}
82+
6383
/// First of two part process to add a new validator/serializer slot, we add the `slot_ref` to the array,
6484
/// but not the actual `validator`/`serializer`, we can't add that until it's built.
6585
/// But we need the `id` to build it, hence this two-step process.
@@ -89,15 +109,25 @@ impl<T: Clone> BuildContext<T> {
89109
}
90110
}
91111

92-
/// find a slot by `slot_ref` - iterate over the slots until we find a matching reference - return the index
93-
pub fn find_slot_id_answer(&self, slot_ref: &str) -> PyResult<(usize, Option<Answers>)> {
94-
let is_match = |slot: &Slot<T>| slot.slot_ref == slot_ref;
95-
match self.slots.iter().position(is_match) {
96-
Some(id) => {
97-
let slot = self.slots.get(id).unwrap();
98-
Ok((id, slot.answers.clone()))
99-
}
100-
None => py_err!("Slots Error: ref '{}' not found", slot_ref),
112+
/// find validator/serializer by `ref`, if the `ref` is in `reusable` return a clone of the validator/serializer,
113+
/// otherwise return the id of the slot.
114+
pub fn find(&mut self, ref_: &str) -> PyResult<ThingOrId<T>> {
115+
if let Some(val_ser) = self.reusable.get(ref_) {
116+
Ok(ThingOrId::Thing(val_ser.clone()))
117+
} else {
118+
let id = match self.slots.iter().position(|slot| slot.slot_ref == ref_) {
119+
Some(id) => id,
120+
None => return py_err!("Slots Error: ref '{}' not found", ref_),
121+
};
122+
Ok(ThingOrId::Id(id))
123+
}
124+
}
125+
126+
/// get a slot answer by `id`
127+
pub fn get_slot_answer(&self, slot_id: usize) -> PyResult<Option<Answers>> {
128+
match self.slots.get(slot_id) {
129+
Some(slot) => Ok(slot.answers.clone()),
130+
None => py_err!("Slots Error: slot {} not found", slot_id),
101131
}
102132
}
103133

@@ -147,9 +177,8 @@ impl BuildContext<CombinedSerializer> {
147177

148178
fn extract_used_refs(schema: &PyAny, refs: &mut AHashSet<String>) -> PyResult<()> {
149179
if let Ok(dict) = schema.downcast::<PyDict>() {
150-
let py = schema.py();
151-
if matches!(dict.get_as(intern!(py, "type")), Ok(Some("recursive-ref"))) {
152-
refs.insert(dict.get_as_req(intern!(py, "schema_ref"))?);
180+
if is_definition_ref(dict)? {
181+
refs.insert(dict.get_as_req(intern!(schema.py(), "schema_ref"))?);
153182
} else {
154183
for (_, value) in dict.iter() {
155184
extract_used_refs(value, refs)?;
@@ -162,3 +191,32 @@ fn extract_used_refs(schema: &PyAny, refs: &mut AHashSet<String>) -> PyResult<()
162191
}
163192
Ok(())
164193
}
194+
195+
fn check_ref_used(schema: &PyAny, ref_: &str) -> PyResult<bool> {
196+
if let Ok(dict) = schema.downcast::<PyDict>() {
197+
if is_definition_ref(dict)? {
198+
let key: &str = dict.get_as_req(intern!(schema.py(), "schema_ref"))?;
199+
return Ok(key == ref_);
200+
} else {
201+
for (_, value) in dict.iter() {
202+
if check_ref_used(value, ref_)? {
203+
return Ok(true);
204+
}
205+
}
206+
}
207+
} else if let Ok(list) = schema.downcast::<PyList>() {
208+
for item in list.iter() {
209+
if check_ref_used(item, ref_)? {
210+
return Ok(true);
211+
}
212+
}
213+
}
214+
Ok(false)
215+
}
216+
217+
fn is_definition_ref(dict: &PyDict) -> PyResult<bool> {
218+
match dict.get_item(intern!(dict.py(), "type")) {
219+
Some(type_value) => type_value.eq(intern!(dict.py(), "definition-ref")),
220+
None => Ok(false),
221+
}
222+
}

src/recursion_guard.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
use ahash::AHashSet;
22

33
/// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault.
4-
/// It's used in `validators/recursive.rs` to detect when a reference is reused within itself.
4+
/// It's used in `validators/definition` to detect when a reference is reused within itself.
55
#[derive(Debug, Clone, Default)]
66
pub struct RecursionGuard {
77
ids: Option<AHashSet<usize>>,
8-
// see validators/recursive.rs::BACKUP_GUARD_LIMIT for details
8+
// see validators/definition::BACKUP_GUARD_LIMIT for details
99
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
1010
// use one number for all validators
1111
depth: u16,

src/serializers/mod.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use pyo3::prelude::*;
44
use pyo3::types::{PyBytes, PyDict};
55

66
use crate::build_context::BuildContext;
7-
use crate::SchemaValidator;
7+
use crate::validators::SelfValidator;
88

99
use config::SerializationConfig;
1010
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
@@ -34,8 +34,10 @@ pub struct SchemaSerializer {
3434
impl SchemaSerializer {
3535
#[new]
3636
pub fn py_new(py: Python, schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
37-
let schema = SchemaValidator::validate_schema(py, schema)?;
38-
let mut build_context = BuildContext::for_schema(schema)?;
37+
let self_validator = SelfValidator::new(py)?;
38+
let schema = self_validator.validate_schema(py, schema)?;
39+
let mut build_context = BuildContext::new(schema)?;
40+
3941
let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut build_context)?;
4042
Ok(Self {
4143
serializer,

0 commit comments

Comments
 (0)