@@ -2,64 +2,84 @@ use pyo3::intern;
22use pyo3:: prelude:: * ;
33use pyo3:: types:: { PyDict , PyList } ;
44
5- use ahash:: AHashSet ;
5+ use ahash:: { AHashMap , AHashSet } ;
66
77use crate :: build_tools:: { py_err, py_error_type, SchemaDict } ;
88use crate :: questions:: Answers ;
99use crate :: serializers:: CombinedSerializer ;
1010use crate :: validators:: { CombinedValidator , Validator } ;
1111
12- #[ derive( Clone ) ]
12+ #[ derive( Clone , Debug ) ]
1313struct Slot < T > {
1414 slot_ref : String ,
1515 op_val_ser : Option < T > ,
1616 answers : Option < Answers > ,
1717}
1818
19- /// `BuildContext` is used to store extra information while building validators and type_serializers,
20- /// currently it just holds a vec "slots" which holds validators/type_serializers which need to be accessed from
21- /// multiple other validators/type_serializers and therefore can't be owned by them directly.
22- #[ derive( Clone ) ]
/// Result of looking up a `ref`: either the value itself, or the id of a slot.
///
/// `Debug` is derived so that this public type can be inspected and logged like the other
/// types in this module, which all derive `Debug`.
#[derive(Debug)]
pub enum ThingOrId<T> {
    /// the validator/serializer itself
    Thing(T),
    /// the id (index) of a slot
    Id(usize),
}
23+
24+ /// `BuildContext` is used to store extra information while building validators and type_serializers
25+ #[ derive( Clone , Debug ) ]
2326pub struct BuildContext < T > {
27+ /// set of used refs, useful to see if a `ref` is actually used elsewhere in the schema
2428 used_refs : AHashSet < String > ,
29+ /// holds validators/type_serializers which reference themselves and therefore can't be cloned and owned
30+ /// in one or multiple places.
2531 slots : Vec < Slot < T > > ,
32+ /// holds validators/type_serializers which need to be accessed from multiple other validators/type_serializers
33+ /// and therefore can't be owned by them directly.
34+ reusable : AHashMap < String , T > ,
2635}
2736
28- impl < T : Clone > BuildContext < T > {
29- pub fn new ( used_refs : AHashSet < String > ) -> Self {
30- Self {
31- used_refs,
32- slots : Vec :: new ( ) ,
33- }
34- }
35-
36- pub fn for_schema ( schema : & PyAny ) -> PyResult < Self > {
37+ impl < T : Clone + std:: fmt:: Debug > BuildContext < T > {
38+ pub fn new ( schema : & PyAny ) -> PyResult < Self > {
3739 let mut used_refs = AHashSet :: new ( ) ;
3840 extract_used_refs ( schema, & mut used_refs) ?;
3941 Ok ( Self {
4042 used_refs,
4143 slots : Vec :: new ( ) ,
44+ reusable : AHashMap :: new ( ) ,
4245 } )
4346 }
4447
4548 pub fn for_self_schema ( ) -> Self {
46- let mut used_refs = AHashSet :: new ( ) ;
49+ let mut used_refs = AHashSet :: with_capacity ( 3 ) ;
4750 // NOTE: we don't call `extract_used_refs` for performance reasons, if more recursive references
4851 // are used, they would need to be manually added here.
52+ // we use `2` as count to avoid `find_slot` pulling the validator out of slots and returning it directly
4953 used_refs. insert ( "root-schema" . to_string ( ) ) ;
5054 used_refs. insert ( "ser-schema" . to_string ( ) ) ;
5155 used_refs. insert ( "inc-ex-type" . to_string ( ) ) ;
5256 Self {
5357 used_refs,
5458 slots : Vec :: new ( ) ,
59+ reusable : AHashMap :: new ( ) ,
5560 }
5661 }
5762
63+ /// Check whether a ref is already in `reusable` or `slots`, we shouldn't allow repeated refs
64+ pub fn ref_already_used ( & self , ref_ : & str ) -> bool {
65+ self . reusable . contains_key ( ref_) || self . slots . iter ( ) . any ( |slot| slot. slot_ref == ref_)
66+ }
67+
5868 /// check if a ref is used elsewhere in the schema
5969 pub fn ref_used ( & self , ref_ : & str ) -> bool {
6070 self . used_refs . contains ( ref_)
6171 }
6272
73+ /// check if a ref is used within a given schema
74+ pub fn ref_used_within ( & self , schema_dict : & PyAny , ref_ : & str ) -> PyResult < bool > {
75+ check_ref_used ( schema_dict, ref_)
76+ }
77+
78+ /// add a validator/serializer to `reusable` so it can be cloned and used again elsewhere
79+ pub fn store_reusable ( & mut self , ref_ : String , val_ser : T ) {
80+ self . reusable . insert ( ref_, val_ser) ;
81+ }
82+
6383 /// First of two part process to add a new validator/serializer slot, we add the `slot_ref` to the array,
6484 /// but not the actual `validator`/`serializer`, we can't add that until it's built.
6585 /// But we need the `id` to build it, hence this two-step process.
@@ -89,15 +109,25 @@ impl<T: Clone> BuildContext<T> {
89109 }
90110 }
91111
92- /// find a slot by `slot_ref` - iterate over the slots until we find a matching reference - return the index
93- pub fn find_slot_id_answer ( & self , slot_ref : & str ) -> PyResult < ( usize , Option < Answers > ) > {
94- let is_match = |slot : & Slot < T > | slot. slot_ref == slot_ref;
95- match self . slots . iter ( ) . position ( is_match) {
96- Some ( id) => {
97- let slot = self . slots . get ( id) . unwrap ( ) ;
98- Ok ( ( id, slot. answers . clone ( ) ) )
99- }
100- None => py_err ! ( "Slots Error: ref '{}' not found" , slot_ref) ,
112+ /// find validator/serializer by `ref`, if the `ref` is in `resuable` return a clone of the validator/serializer,
113+ /// otherwise return the id of the slot.
114+ pub fn find ( & mut self , ref_ : & str ) -> PyResult < ThingOrId < T > > {
115+ if let Some ( val_ser) = self . reusable . get ( ref_) {
116+ Ok ( ThingOrId :: Thing ( val_ser. clone ( ) ) )
117+ } else {
118+ let id = match self . slots . iter ( ) . position ( |slot| slot. slot_ref == ref_) {
119+ Some ( id) => id,
120+ None => return py_err ! ( "Slots Error: ref '{}' not found" , ref_) ,
121+ } ;
122+ Ok ( ThingOrId :: Id ( id) )
123+ }
124+ }
125+
126+ /// get a slot answer by `id`
127+ pub fn get_slot_answer ( & self , slot_id : usize ) -> PyResult < Option < Answers > > {
128+ match self . slots . get ( slot_id) {
129+ Some ( slot) => Ok ( slot. answers . clone ( ) ) ,
130+ None => py_err ! ( "Slots Error: slot {} not found" , slot_id) ,
101131 }
102132 }
103133
@@ -147,9 +177,8 @@ impl BuildContext<CombinedSerializer> {
147177
148178fn extract_used_refs ( schema : & PyAny , refs : & mut AHashSet < String > ) -> PyResult < ( ) > {
149179 if let Ok ( dict) = schema. downcast :: < PyDict > ( ) {
150- let py = schema. py ( ) ;
151- if matches ! ( dict. get_as( intern!( py, "type" ) ) , Ok ( Some ( "recursive-ref" ) ) ) {
152- refs. insert ( dict. get_as_req ( intern ! ( py, "schema_ref" ) ) ?) ;
180+ if is_definition_ref ( dict) ? {
181+ refs. insert ( dict. get_as_req ( intern ! ( schema. py( ) , "schema_ref" ) ) ?) ;
153182 } else {
154183 for ( _, value) in dict. iter ( ) {
155184 extract_used_refs ( value, refs) ?;
@@ -162,3 +191,32 @@ fn extract_used_refs(schema: &PyAny, refs: &mut AHashSet<String>) -> PyResult<()
162191 }
163192 Ok ( ( ) )
164193}
194+
195+ fn check_ref_used ( schema : & PyAny , ref_ : & str ) -> PyResult < bool > {
196+ if let Ok ( dict) = schema. downcast :: < PyDict > ( ) {
197+ if is_definition_ref ( dict) ? {
198+ let key: & str = dict. get_as_req ( intern ! ( schema. py( ) , "schema_ref" ) ) ?;
199+ return Ok ( key == ref_) ;
200+ } else {
201+ for ( _, value) in dict. iter ( ) {
202+ if check_ref_used ( value, ref_) ? {
203+ return Ok ( true ) ;
204+ }
205+ }
206+ }
207+ } else if let Ok ( list) = schema. downcast :: < PyList > ( ) {
208+ for item in list. iter ( ) {
209+ if check_ref_used ( item, ref_) ? {
210+ return Ok ( true ) ;
211+ }
212+ }
213+ }
214+ Ok ( false )
215+ }
216+
217+ fn is_definition_ref ( dict : & PyDict ) -> PyResult < bool > {
218+ match dict. get_item ( intern ! ( dict. py( ) , "type" ) ) {
219+ Some ( type_value) => type_value. eq ( intern ! ( dict. py( ) , "definition-ref" ) ) ,
220+ None => Ok ( false ) ,
221+ }
222+ }
0 commit comments