@@ -8,7 +8,7 @@ use pyo3::types::{PyDict, PyList, PyString};
88use ahash:: AHashMap ;
99
1010use crate :: build_tools:: { is_strict, schema_or_config, SchemaDict } ;
11- use crate :: errors:: { ErrorKind , ValError , ValLineError , ValResult } ;
11+ use crate :: errors:: { ErrorKind , PydanticValueError , ValError , ValLineError , ValResult } ;
1212use crate :: input:: { GenericMapping , Input } ;
1313use crate :: lookup_key:: LookupKey ;
1414use crate :: questions:: Question ;
@@ -19,6 +19,7 @@ use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Ex
1919#[ derive( Debug , Clone ) ]
2020pub struct UnionValidator {
2121 choices : Vec < CombinedValidator > ,
22+ custom_error : Option < PydanticValueError > ,
2223 strict : bool ,
2324 name : String ,
2425}
@@ -31,8 +32,9 @@ impl BuildValidator for UnionValidator {
3132 config : Option < & PyDict > ,
3233 build_context : & mut BuildContext ,
3334 ) -> PyResult < CombinedValidator > {
35+ let py = schema. py ( ) ;
3436 let choices: Vec < CombinedValidator > = schema
35- . get_as_req :: < & PyList > ( intern ! ( schema . py ( ) , "choices" ) ) ?
37+ . get_as_req :: < & PyList > ( intern ! ( py , "choices" ) ) ?
3638 . iter ( )
3739 . map ( |choice| build_validator ( choice, config, build_context) )
3840 . collect :: < PyResult < Vec < CombinedValidator > > > ( ) ?;
@@ -41,13 +43,41 @@ impl BuildValidator for UnionValidator {
4143
4244 Ok ( Self {
4345 choices,
46+ custom_error : get_custom_error ( py, schema) ?,
4447 strict : is_strict ( schema, config) ?,
4548 name : format ! ( "{}[{}]" , Self :: EXPECTED_TYPE , descr) ,
4649 }
4750 . into ( ) )
4851 }
4952}
5053
54+ fn get_custom_error ( py : Python , schema : & PyDict ) -> PyResult < Option < PydanticValueError > > {
55+ match schema. get_as :: < & PyDict > ( intern ! ( py, "custom_error" ) ) ? {
56+ Some ( custom_error) => Ok ( Some ( PydanticValueError :: py_new (
57+ py,
58+ custom_error. get_as_req :: < String > ( intern ! ( py, "kind" ) ) ?,
59+ custom_error. get_as_req :: < String > ( intern ! ( py, "message" ) ) ?,
60+ None ,
61+ ) ) ) ,
62+ None => Ok ( None ) ,
63+ }
64+ }
65+
66+ impl UnionValidator {
67+ fn or_custom_error < ' s , ' data > (
68+ & ' s self ,
69+ errors : Option < Vec < ValLineError < ' data > > > ,
70+ input : & ' data impl Input < ' data > ,
71+ ) -> ValError < ' data > {
72+ if let Some ( errors) = errors {
73+ ValError :: LineErrors ( errors)
74+ } else {
75+ let value_error = self . custom_error . as_ref ( ) . unwrap ( ) ;
76+ value_error. clone ( ) . into_val_error ( input)
77+ }
78+ }
79+ }
80+
5181impl Validator for UnionValidator {
5282 fn validate < ' s , ' data > (
5383 & ' s self ,
@@ -58,7 +88,10 @@ impl Validator for UnionValidator {
5888 recursion_guard : & ' s mut RecursionGuard ,
5989 ) -> ValResult < ' data , PyObject > {
6090 if extra. strict . unwrap_or ( self . strict ) {
61- let mut errors: Vec < ValLineError > = Vec :: with_capacity ( self . choices . len ( ) ) ;
91+ let mut errors: Option < Vec < ValLineError > > = match self . custom_error {
92+ Some ( _) => None ,
93+ None => Some ( Vec :: with_capacity ( self . choices . len ( ) ) ) ,
94+ } ;
6295 let strict_extra = extra. as_strict ( ) ;
6396
6497 for validator in & self . choices {
@@ -67,14 +100,16 @@ impl Validator for UnionValidator {
67100 otherwise => return otherwise,
68101 } ;
69102
70- errors. extend (
71- line_errors
72- . into_iter ( )
73- . map ( |err| err. with_outer_location ( validator. get_name ( ) . into ( ) ) ) ,
74- ) ;
103+ if let Some ( ref mut errors) = errors {
104+ errors. extend (
105+ line_errors
106+ . into_iter ( )
107+ . map ( |err| err. with_outer_location ( validator. get_name ( ) . into ( ) ) ) ,
108+ ) ;
109+ }
75110 }
76111
77- Err ( ValError :: LineErrors ( errors) )
112+ Err ( self . or_custom_error ( errors, input ) )
78113 } else {
79114 // 1st pass: check if the value is an exact instance of one of the Union types,
80115 // e.g. use validate in strict mode
@@ -88,7 +123,10 @@ impl Validator for UnionValidator {
88123 return res;
89124 }
90125
91- let mut errors: Vec < ValLineError > = Vec :: with_capacity ( self . choices . len ( ) ) ;
126+ let mut errors: Option < Vec < ValLineError > > = match self . custom_error {
127+ Some ( _) => None ,
128+ None => Some ( Vec :: with_capacity ( self . choices . len ( ) ) ) ,
129+ } ;
92130
93131 // 2nd pass: check if the value can be coerced into one of the Union types, e.g. use validate
94132 for validator in & self . choices {
@@ -97,14 +135,16 @@ impl Validator for UnionValidator {
97135 success => return success,
98136 } ;
99137
100- errors. extend (
101- line_errors
102- . into_iter ( )
103- . map ( |err| err. with_outer_location ( validator. get_name ( ) . into ( ) ) ) ,
104- ) ;
138+ if let Some ( ref mut errors) = errors {
139+ errors. extend (
140+ line_errors
141+ . into_iter ( )
142+ . map ( |err| err. with_outer_location ( validator. get_name ( ) . into ( ) ) ) ,
143+ ) ;
144+ }
105145 }
106146
107- Err ( ValError :: LineErrors ( errors) )
147+ Err ( self . or_custom_error ( errors, input ) )
108148 }
109149 }
110150
@@ -160,6 +200,7 @@ pub struct TaggedUnionValidator {
160200 discriminator : Discriminator ,
161201 from_attributes : bool ,
162202 strict : bool ,
203+ custom_error : Option < PydanticValueError > ,
163204 tags_repr : String ,
164205 discriminator_repr : String ,
165206 name : String ,
@@ -206,6 +247,7 @@ impl BuildValidator for TaggedUnionValidator {
206247 discriminator,
207248 from_attributes,
208249 strict : is_strict ( schema, config) ?,
250+ custom_error : get_custom_error ( py, schema) ?,
209251 tags_repr,
210252 discriminator_repr,
211253 name : format ! ( "{}[{}]" , Self :: EXPECTED_TYPE , descr) ,
@@ -341,6 +383,8 @@ impl TaggedUnionValidator {
341383 Ok ( res) => Ok ( res) ,
342384 Err ( err) => Err ( err. with_outer_location ( tag. as_ref ( ) . into ( ) ) ) ,
343385 }
386+ } else if let Some ( ref custom_error) = self . custom_error {
387+ Err ( custom_error. clone ( ) . into_val_error ( input) )
344388 } else {
345389 Err ( ValError :: new (
346390 ErrorKind :: UnionTagInvalid {
@@ -354,11 +398,15 @@ impl TaggedUnionValidator {
354398 }
355399
356400 fn tag_not_found < ' s , ' data > ( & ' s self , input : & ' data impl Input < ' data > ) -> ValError < ' data > {
357- ValError :: new (
358- ErrorKind :: UnionTagNotFound {
359- discriminator : self . discriminator_repr . clone ( ) ,
360- } ,
361- input,
362- )
401+ if let Some ( ref custom_error) = self . custom_error {
402+ custom_error. clone ( ) . into_val_error ( input)
403+ } else {
404+ ValError :: new (
405+ ErrorKind :: UnionTagNotFound {
406+ discriminator : self . discriminator_repr . clone ( ) ,
407+ } ,
408+ input,
409+ )
410+ }
363411 }
364412}
0 commit comments