11use std:: borrow:: Cow ;
2+ use std:: fmt;
23use std:: fmt:: Write ;
34
5+ use pyo3:: exceptions:: PyTypeError ;
46use pyo3:: intern;
57use pyo3:: prelude:: * ;
68use pyo3:: types:: { PyDict , PyList , PyString } ;
79
810use ahash:: AHashMap ;
911
1012use crate :: build_tools:: { is_strict, py_err, schema_or_config, SchemaDict } ;
11- use crate :: errors:: { ErrorType , ValError , ValLineError , ValResult } ;
13+ use crate :: errors:: { ErrorType , LocItem , ValError , ValLineError , ValResult } ;
1214use crate :: input:: { GenericMapping , Input } ;
1315use crate :: lookup_key:: LookupKey ;
1416use crate :: questions:: Question ;
@@ -189,10 +191,53 @@ impl Discriminator {
189191 }
190192}
191193
194+ #[ derive( Debug , Clone , Eq , PartialEq , Hash ) ]
195+ enum ChoiceKey {
196+ Int ( i64 ) ,
197+ Str ( String ) ,
198+ }
199+
200+ impl ChoiceKey {
201+ fn from_py ( raw : & PyAny ) -> PyResult < Self > {
202+ if let Ok ( py_int) = raw. extract :: < i64 > ( ) {
203+ Ok ( Self :: Int ( py_int) )
204+ } else if let Ok ( py_str) = raw. downcast :: < PyString > ( ) {
205+ Ok ( Self :: Str ( py_str. to_str ( ) ?. to_string ( ) ) )
206+ } else {
207+ py_err ! ( PyTypeError ; "Expected int or str, got {}" , raw. get_type( ) . name( ) . unwrap_or( "<unknown python object>" ) )
208+ }
209+ }
210+
211+ fn repr ( & self ) -> String {
212+ match self {
213+ Self :: Int ( i) => i. to_string ( ) ,
214+ Self :: Str ( s) => format ! ( "'{s}'" ) ,
215+ }
216+ }
217+ }
218+
219+ impl fmt:: Display for ChoiceKey {
220+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
221+ match self {
222+ Self :: Int ( i) => write ! ( f, "{i}" ) ,
223+ Self :: Str ( s) => write ! ( f, "{s}" ) ,
224+ }
225+ }
226+ }
227+
228+ impl From < & ChoiceKey > for LocItem {
229+ fn from ( key : & ChoiceKey ) -> Self {
230+ match key {
231+ ChoiceKey :: Str ( s) => s. as_str ( ) . into ( ) ,
232+ ChoiceKey :: Int ( i) => ( * i) . into ( ) ,
233+ }
234+ }
235+ }
236+
192237#[ derive( Debug , Clone ) ]
193238pub struct TaggedUnionValidator {
194- choices : AHashMap < String , CombinedValidator > ,
195- repeat_choices : Option < AHashMap < String , String > > ,
239+ choices : AHashMap < ChoiceKey , CombinedValidator > ,
240+ repeat_choices : Option < AHashMap < ChoiceKey , ChoiceKey > > ,
196241 discriminator : Discriminator ,
197242 from_attributes : bool ,
198243 strict : bool ,
@@ -216,25 +261,27 @@ impl BuildValidator for TaggedUnionValidator {
216261
217262 let schema_choices: & PyDict = schema. get_as_req ( intern ! ( py, "choices" ) ) ?;
218263 let mut choices = AHashMap :: with_capacity ( schema_choices. len ( ) ) ;
219- let mut repeat_choices_vec: Vec < ( String , String ) > = Vec :: new ( ) ;
264+ let mut repeat_choices_vec: Vec < ( ChoiceKey , ChoiceKey ) > = Vec :: new ( ) ;
220265 let mut first = true ;
221266 let mut tags_repr = String :: with_capacity ( 50 ) ;
222267 let mut descr = String :: with_capacity ( 50 ) ;
223268
224269 for ( key, value) in schema_choices {
225- let tag: String = key . extract ( ) ?;
226- if let Ok ( py_str ) = value . downcast :: < PyString > ( ) {
227- let repeat_tag = py_str . to_str ( ) ? . to_string ( ) ;
270+ let tag = ChoiceKey :: from_py ( key ) ?;
271+
272+ if let Ok ( repeat_tag) = ChoiceKey :: from_py ( value ) {
228273 repeat_choices_vec. push ( ( tag, repeat_tag) ) ;
229274 continue ;
230275 }
276+
231277 let validator = build_validator ( value, config, build_context) ?;
278+ let tag_repr = tag. repr ( ) ;
232279 if first {
233280 first = false ;
234- write ! ( tags_repr, "'{tag}' " ) . unwrap ( ) ;
281+ write ! ( tags_repr, "{tag_repr} " ) . unwrap ( ) ;
235282 descr. push_str ( validator. get_name ( ) ) ;
236283 } else {
237- write ! ( tags_repr, ", '{tag}' " ) . unwrap ( ) ;
284+ write ! ( tags_repr, ", {tag_repr} " ) . unwrap ( ) ;
238285 // no spaces in get_name() output to make loc easy to read
239286 write ! ( descr, ",{}" , validator. get_name( ) ) . unwrap ( ) ;
240287 }
@@ -246,9 +293,10 @@ impl BuildValidator for TaggedUnionValidator {
246293 let mut wrong_values = Vec :: with_capacity ( repeat_choices_vec. len ( ) ) ;
247294 let mut repeat_choices = AHashMap :: with_capacity ( repeat_choices_vec. len ( ) ) ;
248295 for ( tag, repeat_tag) in repeat_choices_vec {
249- match choices. get ( repeat_tag. as_str ( ) ) {
296+ match choices. get ( & repeat_tag) {
250297 Some ( validator) => {
251- write ! ( tags_repr, ", '{tag}'" ) . unwrap ( ) ;
298+ let tag_repr = tag. repr ( ) ;
299+ write ! ( tags_repr, ", {tag_repr}" ) . unwrap ( ) ;
252300 write ! ( descr, ",{}" , validator. get_name( ) ) . unwrap ( ) ;
253301 repeat_choices. insert ( tag, repeat_tag) ;
254302 }
@@ -265,7 +313,7 @@ impl BuildValidator for TaggedUnionValidator {
265313 } ;
266314
267315 let key = intern ! ( py, "from_attributes" ) ;
268- let from_attributes = schema_or_config ( schema, config, key, key) ?. unwrap_or ( false ) ;
316+ let from_attributes = schema_or_config ( schema, config, key, key) ?. unwrap_or ( true ) ;
269317
270318 let descr = match discriminator {
271319 Discriminator :: SelfSchema => "self-schema" . to_string ( ) ,
@@ -304,10 +352,10 @@ impl Validator for TaggedUnionValidator {
304352 // errors when getting attributes which should be "raised"
305353 match lookup_key. $get_method( $( $dict ) ,+) ? {
306354 Some ( ( _, value) ) => {
307- if self . strict {
308- value . strict_str ( )
355+ if let Ok ( int ) = value . validate_int ( self . strict) {
356+ Ok ( ChoiceKey :: Int ( int ) )
309357 } else {
310- value. lax_str ( )
358+ Ok ( ChoiceKey :: Str ( value. validate_str ( self . strict ) ? . as_cow ( ) ? . as_ref ( ) . to_string ( ) ) )
311359 }
312360 }
313361 None => Err ( self . tag_not_found( input) ) ,
@@ -321,20 +369,20 @@ impl Validator for TaggedUnionValidator {
321369 GenericMapping :: PyMapping ( mapping) => find_validator ! ( py_get_mapping_item, mapping) ,
322370 GenericMapping :: JsonObject ( mapping) => find_validator ! ( json_get, mapping) ,
323371 } ?;
324- self . find_call_validator ( py, tag. as_cow ( ) ? , input, extra, slots, recursion_guard)
372+ self . find_call_validator ( py, & tag, input, extra, slots, recursion_guard)
325373 }
326374 Discriminator :: Function ( ref func) => {
327375 let tag = func. call1 ( py, ( input. to_object ( py) , ) ) ?;
328376 if tag. is_none ( py) {
329377 Err ( self . tag_not_found ( input) )
330378 } else {
331- let tag: & PyString = tag. downcast ( py) ?;
332- self . find_call_validator ( py, tag . to_string_lossy ( ) , input, extra, slots, recursion_guard)
379+ let tag: & PyAny = tag. downcast ( py) ?;
380+ self . find_call_validator ( py, & ( ChoiceKey :: from_py ( tag ) ? ) , input, extra, slots, recursion_guard)
333381 }
334382 }
335383 Discriminator :: SelfSchema => self . find_call_validator (
336384 py,
337- self . self_schema_tag ( py, input) ?,
385+ & ChoiceKey :: Str ( self . self_schema_tag ( py, input) ?. into_owned ( ) ) ,
338386 input,
339387 extra,
340388 slots,
@@ -407,23 +455,23 @@ impl TaggedUnionValidator {
407455 fn find_call_validator < ' s , ' data > (
408456 & ' s self ,
409457 py : Python < ' data > ,
410- tag : Cow < str > ,
458+ tag : & ChoiceKey ,
411459 input : & ' data impl Input < ' data > ,
412460 extra : & Extra ,
413461 slots : & ' data [ CombinedValidator ] ,
414462 recursion_guard : & ' s mut RecursionGuard ,
415463 ) -> ValResult < ' data , PyObject > {
416- if let Some ( validator) = self . choices . get ( tag. as_ref ( ) ) {
464+ if let Some ( validator) = self . choices . get ( tag) {
417465 return match validator. validate ( py, input, extra, slots, recursion_guard) {
418466 Ok ( res) => Ok ( res) ,
419- Err ( err) => Err ( err. with_outer_location ( tag. as_ref ( ) . into ( ) ) ) ,
467+ Err ( err) => Err ( err. with_outer_location ( tag. into ( ) ) ) ,
420468 } ;
421469 } else if let Some ( ref repeat_choices) = self . repeat_choices {
422- if let Some ( choice_tag) = repeat_choices. get ( tag. as_ref ( ) ) {
470+ if let Some ( choice_tag) = repeat_choices. get ( tag) {
423471 let validator = & self . choices [ choice_tag] ;
424472 return match validator. validate ( py, input, extra, slots, recursion_guard) {
425473 Ok ( res) => Ok ( res) ,
426- Err ( err) => Err ( err. with_outer_location ( tag. as_ref ( ) . into ( ) ) ) ,
474+ Err ( err) => Err ( err. with_outer_location ( tag. into ( ) ) ) ,
427475 } ;
428476 }
429477 }
0 commit comments