Skip to content

Commit 5bc4689

Browse files
philhchensamuelcolvindmontagu
authored
Use ChoiceKey in TaggedUnionValidator for int keys (#405)
* Mostly working ChoiceKey for int keys * Update src/validators/union.rs Co-authored-by: Samuel Colvin <[email protected]> * Update src/validators/union.rs Co-authored-by: Samuel Colvin <[email protected]> * Update src/validators/union.rs Co-authored-by: Samuel Colvin <[email protected]> * address comments * address more comments * Address more comments and fix test * add tests for int choice keys * fix * add enum choices test * update repeated tag test * Set from_attributes to true by default for TaggedUnionValidator * Get test passing * fix tags_repr for ints * refactor repeated tags test * Use _extra.strict for literal validation of strs and ints * support i64 location keys, tests * change literal strict usage --------- Co-authored-by: Samuel Colvin <[email protected]> Co-authored-by: David Montague <[email protected]>
1 parent 3ea00bd commit 5bc4689

File tree

8 files changed

+290
-71
lines changed

8 files changed

+290
-71
lines changed

generate_self_schema.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,14 @@ def type_dict_schema(typed_dict) -> dict[str, Any]: # noqa: C901
129129
schema = {'type': 'list', 'items_schema': schema_ref_validator}
130130
elif fr_arg == 'Dict[str, CoreSchema]':
131131
schema = {'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': schema_ref_validator}
132-
elif fr_arg == 'Dict[str, Union[str, CoreSchema]]':
132+
elif fr_arg == 'Dict[Union[str, int], Union[str, int, CoreSchema]]':
133133
schema = {
134134
'type': 'dict',
135-
'keys_schema': {'type': 'str'},
136-
'values_schema': {'type': 'union', 'choices': [{'type': 'str'}, schema_ref_validator]},
135+
'keys_schema': {'type': 'union', 'choices': [{'type': 'str'}, {'type': 'int'}]},
136+
'values_schema': {
137+
'type': 'union',
138+
'choices': [{'type': 'str'}, {'type': 'int'}, schema_ref_validator],
139+
},
137140
}
138141
else:
139142
raise ValueError(f'Unknown Schema forward ref: {fr_arg}')

pydantic_core/core_schema.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,27 +1805,29 @@ def union_schema(
18051805

18061806
class TaggedUnionSchema(TypedDict, total=False):
18071807
type: Required[Literal['tagged-union']]
1808-
choices: Required[Dict[str, Union[str, CoreSchema]]]
1808+
choices: Required[Dict[Union[str, int], Union[str, int, CoreSchema]]]
18091809
discriminator: Required[
18101810
Union[str, List[Union[str, int]], List[List[Union[str, int]]], Callable[[Any], Optional[str]]]
18111811
]
18121812
custom_error_type: str
18131813
custom_error_message: str
18141814
custom_error_context: Dict[str, Union[str, int, float]]
18151815
strict: bool
1816+
from_attributes: bool # default: True
18161817
ref: str
18171818
metadata: Any
18181819
serialization: SerSchema
18191820

18201821

18211822
def tagged_union_schema(
1822-
choices: Dict[str, str | CoreSchema],
1823+
choices: Dict[Union[int, str], int | str | CoreSchema],
18231824
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], str | None],
18241825
*,
18251826
custom_error_type: str | None = None,
18261827
custom_error_message: str | None = None,
18271828
custom_error_context: dict[str, int | str | float] | None = None,
18281829
strict: bool | None = None,
1830+
from_attributes: bool | None = None,
18291831
ref: str | None = None,
18301832
metadata: Any = None,
18311833
serialization: SerSchema | None = None,
@@ -1873,6 +1875,7 @@ def tagged_union_schema(
18731875
custom_error_message: The custom error message to use if the validation fails
18741876
custom_error_context: The custom error context to use if the validation fails
18751877
strict: Whether the underlying schemas should be validated with strict mode
1878+
from_attributes: Whether to use the attributes of the object to retrieve the discriminator value
18761879
ref: See [TODO] for details
18771880
metadata: See [TODO] for details
18781881
serialization: Custom serialization schema
@@ -1885,6 +1888,7 @@ def tagged_union_schema(
18851888
custom_error_message=custom_error_message,
18861889
custom_error_context=custom_error_context,
18871890
strict=strict,
1891+
from_attributes=from_attributes,
18881892
ref=ref,
18891893
metadata=metadata,
18901894
serialization=serialization,

src/errors/location.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@ use pyo3::types::PyTuple;
1111
pub enum LocItem {
1212
/// string type key, used to identify items from a dict or anything that implements `__getitem__`
1313
S(String),
14-
/// integer key, used to get items from a list, tuple OR a dict with int keys `Dict[int, ...]` (python only)
15-
I(usize),
14+
/// integer key, used to get:
15+
/// * items from a list
16+
/// * items from a tuple
17+
/// * dict with int keys `Dict[int, ...]` (python only)
18+
/// * with integer keys in tagged unions
19+
I(i64),
1620
}
1721

1822
impl fmt::Display for LocItem {
@@ -36,12 +40,18 @@ impl From<&str> for LocItem {
3640
}
3741
}
3842

39-
impl From<usize> for LocItem {
40-
fn from(i: usize) -> Self {
43+
impl From<i64> for LocItem {
44+
fn from(i: i64) -> Self {
4145
Self::I(i)
4246
}
4347
}
4448

49+
impl From<usize> for LocItem {
50+
fn from(u: usize) -> Self {
51+
Self::I(u as i64)
52+
}
53+
}
54+
4555
impl ToPyObject for LocItem {
4656
fn to_object(&self, py: Python<'_>) -> PyObject {
4757
match self {

src/input/input_json.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ impl<'a> Input<'a> for JsonInput {
2323
#[cfg_attr(has_no_coverage, no_coverage)]
2424
fn as_loc_item(&self) -> LocItem {
2525
match self {
26-
JsonInput::Int(i) => LocItem::I(*i as usize),
26+
JsonInput::Int(i) => (*i).into(),
2727
JsonInput::String(s) => s.as_str().into(),
2828
v => format!("{v:?}").into(),
2929
}

src/validators/literal.rs

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@ impl Validator for LiteralSingleStringValidator {
7171
&'s self,
7272
py: Python<'data>,
7373
input: &'data impl Input<'data>,
74-
_extra: &Extra,
74+
extra: &Extra,
7575
_slots: &'data [CombinedValidator],
7676
_recursion_guard: &'s mut RecursionGuard,
7777
) -> ValResult<'data, PyObject> {
78-
let either_str = input.strict_str()?;
78+
let either_str = input.validate_str(extra.strict.unwrap_or(false))?;
7979
if either_str.as_cow()?.as_ref() == self.expected.as_str() {
8080
Ok(input.to_object(py))
8181
} else {
@@ -113,12 +113,12 @@ impl Validator for LiteralSingleIntValidator {
113113
&'s self,
114114
py: Python<'data>,
115115
input: &'data impl Input<'data>,
116-
_extra: &Extra,
116+
extra: &Extra,
117117
_slots: &'data [CombinedValidator],
118118
_recursion_guard: &'s mut RecursionGuard,
119119
) -> ValResult<'data, PyObject> {
120-
let str = input.strict_int()?;
121-
if str == self.expected {
120+
let int = input.validate_int(extra.strict.unwrap_or(false))?;
121+
if int == self.expected {
122122
Ok(input.to_object(py))
123123
} else {
124124
Err(ValError::new(
@@ -168,11 +168,11 @@ impl Validator for LiteralMultipleStringsValidator {
168168
&'s self,
169169
py: Python<'data>,
170170
input: &'data impl Input<'data>,
171-
_extra: &Extra,
171+
extra: &Extra,
172172
_slots: &'data [CombinedValidator],
173173
_recursion_guard: &'s mut RecursionGuard,
174174
) -> ValResult<'data, PyObject> {
175-
let either_str = input.strict_str()?;
175+
let either_str = input.validate_str(extra.strict.unwrap_or(false))?;
176176
if self.expected.contains(either_str.as_cow()?.as_ref()) {
177177
Ok(input.to_object(py))
178178
} else {
@@ -223,11 +223,11 @@ impl Validator for LiteralMultipleIntsValidator {
223223
&'s self,
224224
py: Python<'data>,
225225
input: &'data impl Input<'data>,
226-
_extra: &Extra,
226+
extra: &Extra,
227227
_slots: &'data [CombinedValidator],
228228
_recursion_guard: &'s mut RecursionGuard,
229229
) -> ValResult<'data, PyObject> {
230-
let int = input.strict_int()?;
230+
let int = input.validate_int(extra.strict.unwrap_or(false))?;
231231
if self.expected.contains(&int) {
232232
Ok(input.to_object(py))
233233
} else {
@@ -287,19 +287,20 @@ impl Validator for LiteralGeneralValidator {
287287
&'s self,
288288
py: Python<'data>,
289289
input: &'data impl Input<'data>,
290-
_extra: &Extra,
290+
extra: &Extra,
291291
_slots: &'data [CombinedValidator],
292292
_recursion_guard: &'s mut RecursionGuard,
293293
) -> ValResult<'data, PyObject> {
294+
let strict = extra.strict.unwrap_or(false);
294295
if !self.expected_int.is_empty() {
295-
if let Ok(int) = input.strict_int() {
296+
if let Ok(int) = input.validate_int(strict) {
296297
if self.expected_int.contains(&int) {
297298
return Ok(input.to_object(py));
298299
}
299300
}
300301
}
301302
if !self.expected_str.is_empty() {
302-
if let Ok(either_str) = input.strict_str() {
303+
if let Ok(either_str) = input.validate_str(strict) {
303304
if self.expected_str.contains(either_str.as_cow()?.as_ref()) {
304305
return Ok(input.to_object(py));
305306
}

src/validators/union.rs

Lines changed: 72 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
use std::borrow::Cow;
2+
use std::fmt;
23
use std::fmt::Write;
34

5+
use pyo3::exceptions::PyTypeError;
46
use pyo3::intern;
57
use pyo3::prelude::*;
68
use pyo3::types::{PyDict, PyList, PyString};
79

810
use ahash::AHashMap;
911

1012
use crate::build_tools::{is_strict, py_err, schema_or_config, SchemaDict};
11-
use crate::errors::{ErrorType, ValError, ValLineError, ValResult};
13+
use crate::errors::{ErrorType, LocItem, ValError, ValLineError, ValResult};
1214
use crate::input::{GenericMapping, Input};
1315
use crate::lookup_key::LookupKey;
1416
use crate::questions::Question;
@@ -189,10 +191,53 @@ impl Discriminator {
189191
}
190192
}
191193

194+
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
195+
enum ChoiceKey {
196+
Int(i64),
197+
Str(String),
198+
}
199+
200+
impl ChoiceKey {
201+
fn from_py(raw: &PyAny) -> PyResult<Self> {
202+
if let Ok(py_int) = raw.extract::<i64>() {
203+
Ok(Self::Int(py_int))
204+
} else if let Ok(py_str) = raw.downcast::<PyString>() {
205+
Ok(Self::Str(py_str.to_str()?.to_string()))
206+
} else {
207+
py_err!(PyTypeError; "Expected int or str, got {}", raw.get_type().name().unwrap_or("<unknown python object>"))
208+
}
209+
}
210+
211+
fn repr(&self) -> String {
212+
match self {
213+
Self::Int(i) => i.to_string(),
214+
Self::Str(s) => format!("'{s}'"),
215+
}
216+
}
217+
}
218+
219+
impl fmt::Display for ChoiceKey {
220+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221+
match self {
222+
Self::Int(i) => write!(f, "{i}"),
223+
Self::Str(s) => write!(f, "{s}"),
224+
}
225+
}
226+
}
227+
228+
impl From<&ChoiceKey> for LocItem {
229+
fn from(key: &ChoiceKey) -> Self {
230+
match key {
231+
ChoiceKey::Str(s) => s.as_str().into(),
232+
ChoiceKey::Int(i) => (*i).into(),
233+
}
234+
}
235+
}
236+
192237
#[derive(Debug, Clone)]
193238
pub struct TaggedUnionValidator {
194-
choices: AHashMap<String, CombinedValidator>,
195-
repeat_choices: Option<AHashMap<String, String>>,
239+
choices: AHashMap<ChoiceKey, CombinedValidator>,
240+
repeat_choices: Option<AHashMap<ChoiceKey, ChoiceKey>>,
196241
discriminator: Discriminator,
197242
from_attributes: bool,
198243
strict: bool,
@@ -216,25 +261,27 @@ impl BuildValidator for TaggedUnionValidator {
216261

217262
let schema_choices: &PyDict = schema.get_as_req(intern!(py, "choices"))?;
218263
let mut choices = AHashMap::with_capacity(schema_choices.len());
219-
let mut repeat_choices_vec: Vec<(String, String)> = Vec::new();
264+
let mut repeat_choices_vec: Vec<(ChoiceKey, ChoiceKey)> = Vec::new();
220265
let mut first = true;
221266
let mut tags_repr = String::with_capacity(50);
222267
let mut descr = String::with_capacity(50);
223268

224269
for (key, value) in schema_choices {
225-
let tag: String = key.extract()?;
226-
if let Ok(py_str) = value.downcast::<PyString>() {
227-
let repeat_tag = py_str.to_str()?.to_string();
270+
let tag = ChoiceKey::from_py(key)?;
271+
272+
if let Ok(repeat_tag) = ChoiceKey::from_py(value) {
228273
repeat_choices_vec.push((tag, repeat_tag));
229274
continue;
230275
}
276+
231277
let validator = build_validator(value, config, build_context)?;
278+
let tag_repr = tag.repr();
232279
if first {
233280
first = false;
234-
write!(tags_repr, "'{tag}'").unwrap();
281+
write!(tags_repr, "{tag_repr}").unwrap();
235282
descr.push_str(validator.get_name());
236283
} else {
237-
write!(tags_repr, ", '{tag}'").unwrap();
284+
write!(tags_repr, ", {tag_repr}").unwrap();
238285
// no spaces in get_name() output to make loc easy to read
239286
write!(descr, ",{}", validator.get_name()).unwrap();
240287
}
@@ -246,9 +293,10 @@ impl BuildValidator for TaggedUnionValidator {
246293
let mut wrong_values = Vec::with_capacity(repeat_choices_vec.len());
247294
let mut repeat_choices = AHashMap::with_capacity(repeat_choices_vec.len());
248295
for (tag, repeat_tag) in repeat_choices_vec {
249-
match choices.get(repeat_tag.as_str()) {
296+
match choices.get(&repeat_tag) {
250297
Some(validator) => {
251-
write!(tags_repr, ", '{tag}'").unwrap();
298+
let tag_repr = tag.repr();
299+
write!(tags_repr, ", {tag_repr}").unwrap();
252300
write!(descr, ",{}", validator.get_name()).unwrap();
253301
repeat_choices.insert(tag, repeat_tag);
254302
}
@@ -265,7 +313,7 @@ impl BuildValidator for TaggedUnionValidator {
265313
};
266314

267315
let key = intern!(py, "from_attributes");
268-
let from_attributes = schema_or_config(schema, config, key, key)?.unwrap_or(false);
316+
let from_attributes = schema_or_config(schema, config, key, key)?.unwrap_or(true);
269317

270318
let descr = match discriminator {
271319
Discriminator::SelfSchema => "self-schema".to_string(),
@@ -304,10 +352,10 @@ impl Validator for TaggedUnionValidator {
304352
// errors when getting attributes which should be "raised"
305353
match lookup_key.$get_method($( $dict ),+)? {
306354
Some((_, value)) => {
307-
if self.strict {
308-
value.strict_str()
355+
if let Ok(int) = value.validate_int(self.strict) {
356+
Ok(ChoiceKey::Int(int))
309357
} else {
310-
value.lax_str()
358+
Ok(ChoiceKey::Str(value.validate_str(self.strict)?.as_cow()?.as_ref().to_string()))
311359
}
312360
}
313361
None => Err(self.tag_not_found(input)),
@@ -321,20 +369,20 @@ impl Validator for TaggedUnionValidator {
321369
GenericMapping::PyMapping(mapping) => find_validator!(py_get_mapping_item, mapping),
322370
GenericMapping::JsonObject(mapping) => find_validator!(json_get, mapping),
323371
}?;
324-
self.find_call_validator(py, tag.as_cow()?, input, extra, slots, recursion_guard)
372+
self.find_call_validator(py, &tag, input, extra, slots, recursion_guard)
325373
}
326374
Discriminator::Function(ref func) => {
327375
let tag = func.call1(py, (input.to_object(py),))?;
328376
if tag.is_none(py) {
329377
Err(self.tag_not_found(input))
330378
} else {
331-
let tag: &PyString = tag.downcast(py)?;
332-
self.find_call_validator(py, tag.to_string_lossy(), input, extra, slots, recursion_guard)
379+
let tag: &PyAny = tag.downcast(py)?;
380+
self.find_call_validator(py, &(ChoiceKey::from_py(tag)?), input, extra, slots, recursion_guard)
333381
}
334382
}
335383
Discriminator::SelfSchema => self.find_call_validator(
336384
py,
337-
self.self_schema_tag(py, input)?,
385+
&ChoiceKey::Str(self.self_schema_tag(py, input)?.into_owned()),
338386
input,
339387
extra,
340388
slots,
@@ -407,23 +455,23 @@ impl TaggedUnionValidator {
407455
fn find_call_validator<'s, 'data>(
408456
&'s self,
409457
py: Python<'data>,
410-
tag: Cow<str>,
458+
tag: &ChoiceKey,
411459
input: &'data impl Input<'data>,
412460
extra: &Extra,
413461
slots: &'data [CombinedValidator],
414462
recursion_guard: &'s mut RecursionGuard,
415463
) -> ValResult<'data, PyObject> {
416-
if let Some(validator) = self.choices.get(tag.as_ref()) {
464+
if let Some(validator) = self.choices.get(tag) {
417465
return match validator.validate(py, input, extra, slots, recursion_guard) {
418466
Ok(res) => Ok(res),
419-
Err(err) => Err(err.with_outer_location(tag.as_ref().into())),
467+
Err(err) => Err(err.with_outer_location(tag.into())),
420468
};
421469
} else if let Some(ref repeat_choices) = self.repeat_choices {
422-
if let Some(choice_tag) = repeat_choices.get(tag.as_ref()) {
470+
if let Some(choice_tag) = repeat_choices.get(tag) {
423471
let validator = &self.choices[choice_tag];
424472
return match validator.validate(py, input, extra, slots, recursion_guard) {
425473
Ok(res) => Ok(res),
426-
Err(err) => Err(err.with_outer_location(tag.as_ref().into())),
474+
Err(err) => Err(err.with_outer_location(tag.into())),
427475
};
428476
}
429477
}

0 commit comments

Comments
 (0)