Skip to content

Commit 9bbe0d9

Browse files
authored
Add IsSubclassValidator (#301)
* add IsSubclassValidator * tweak type hint
1 parent 72caae4 commit 9bbe0d9

File tree

9 files changed

+194
-9
lines changed

9 files changed

+194
-9
lines changed

pydantic_core/core_schema.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,20 @@ def is_instance_schema(
378378
)
379379

380380

381+
class IsSubclassSchema(TypedDict, total=False):
382+
type: Required[Literal['is-subclass']]
383+
cls: Required[Type[Any]]
384+
cls_repr: str
385+
ref: str
386+
extra: Any
387+
388+
389+
def is_subclass_schema(
390+
cls: Type[Any], *, cls_repr: str | None = None, ref: str | None = None, extra: Any = None
391+
) -> IsInstanceSchema:
392+
return dict_not_none(type='is-subclass', cls=cls, cls_repr=cls_repr, ref=ref, extra=extra)
393+
394+
381395
class CallableSchema(TypedDict, total=False):
382396
type: Required[Literal['callable']]
383397
ref: str
@@ -1027,6 +1041,7 @@ def json_schema(schema: CoreSchema | None = None, *, ref: str | None = None, ext
10271041
TimedeltaSchema,
10281042
LiteralSchema,
10291043
IsInstanceSchema,
1044+
IsSubclassSchema,
10301045
CallableSchema,
10311046
ListSchema,
10321047
TuplePositionalSchema,
@@ -1118,6 +1133,7 @@ def json_schema(schema: CoreSchema | None = None, *, ref: str | None = None, ext
11181133
'time_delta_parsing',
11191134
'frozen_set_type',
11201135
'is_instance_of',
1136+
'is_subclass_of',
11211137
'callable_type',
11221138
'union_tag_invalid',
11231139
'union_tag_not_found',

src/errors/kinds.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,10 @@ pub enum ErrorKind {
279279
IsInstanceOf {
280280
class: String,
281281
},
282+
#[strum(message = "Input should be a subclass of {class}")]
283+
IsSubclassOf {
284+
class: String,
285+
},
282286
#[strum(message = "Input should be callable")]
283287
CallableType,
284288
// ---------------------
@@ -432,6 +436,7 @@ impl ErrorKind {
432436
Self::DatetimeObjectInvalid { .. } => extract_context!(DatetimeObjectInvalid, ctx, error: String),
433437
Self::TimeDeltaParsing { .. } => extract_context!(Cow::Owned, TimeDeltaParsing, ctx, error: String),
434438
Self::IsInstanceOf { .. } => extract_context!(IsInstanceOf, ctx, class: String),
439+
Self::IsSubclassOf { .. } => extract_context!(IsSubclassOf, ctx, class: String),
435440
Self::UnionTagInvalid { .. } => extract_context!(
436441
UnionTagInvalid,
437442
ctx,
@@ -520,6 +525,7 @@ impl ErrorKind {
520525
Self::DatetimeObjectInvalid { error } => render!(self, error),
521526
Self::TimeDeltaParsing { error } => render!(self, error),
522527
Self::IsInstanceOf { class } => render!(self, class),
528+
Self::IsSubclassOf { class } => render!(self, class),
523529
Self::UnionTagInvalid {
524530
discriminator,
525531
tag,
@@ -568,6 +574,7 @@ impl ErrorKind {
568574
Self::DatetimeObjectInvalid { error } => py_dict!(py, error),
569575
Self::TimeDeltaParsing { error } => py_dict!(py, error),
570576
Self::IsInstanceOf { class } => py_dict!(py, class),
577+
Self::IsSubclassOf { class } => py_dict!(py, class),
571578
Self::UnionTagInvalid {
572579
discriminator,
573580
tag,

src/input/input_abstract.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
3838

3939
fn is_none(&self) -> bool;
4040

41-
fn is_type(&self, _class: &PyType) -> ValResult<bool> {
42-
Ok(false)
43-
}
44-
4541
#[cfg_attr(has_no_coverage, no_coverage)]
4642
fn get_attr(&self, _name: &PyString) -> Option<&PyAny> {
4743
None
@@ -50,6 +46,14 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
5046
// input_ prefix to differentiate from the function on PyAny
5147
fn input_is_instance(&self, class: &PyAny, json_mask: u8) -> PyResult<bool>;
5248

49+
fn is_exact_instance(&self, _class: &PyType) -> PyResult<bool> {
50+
Ok(false)
51+
}
52+
53+
fn input_is_subclass(&self, _class: &PyType) -> PyResult<bool> {
54+
Ok(false)
55+
}
56+
5357
fn callable(&self) -> bool {
5458
false
5559
}

src/input/input_python.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,6 @@ impl<'a> Input<'a> for PyAny {
8585
self.is_none()
8686
}
8787

88-
fn is_type(&self, class: &PyType) -> ValResult<bool> {
89-
Ok(self.get_type().eq(class)?)
90-
}
91-
9288
fn get_attr(&self, name: &PyString) -> Option<&PyAny> {
9389
self.getattr(name).ok()
9490
}
@@ -101,6 +97,18 @@ impl<'a> Input<'a> for PyAny {
10197
Ok(result == 1)
10298
}
10399

100+
fn is_exact_instance(&self, class: &PyType) -> PyResult<bool> {
101+
self.get_type().eq(class)
102+
}
103+
104+
fn input_is_subclass(&self, class: &PyType) -> PyResult<bool> {
105+
if let Ok(py_type) = self.cast_as::<PyType>() {
106+
py_type.is_subclass(class)
107+
} else {
108+
Ok(false)
109+
}
110+
}
111+
104112
fn callable(&self) -> bool {
105113
self.is_callable()
106114
}

src/validators/is_subclass.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
use pyo3::intern;
2+
use pyo3::prelude::*;
3+
use pyo3::types::{PyDict, PyType};
4+
5+
use crate::build_tools::SchemaDict;
6+
use crate::errors::{ErrorKind, ValError, ValResult};
7+
use crate::input::Input;
8+
use crate::recursion_guard::RecursionGuard;
9+
10+
use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
11+
12+
#[derive(Debug, Clone)]
13+
pub struct IsSubclassValidator {
14+
class: Py<PyType>,
15+
class_repr: String,
16+
name: String,
17+
}
18+
19+
impl BuildValidator for IsSubclassValidator {
20+
const EXPECTED_TYPE: &'static str = "is-subclass";
21+
22+
fn build(
23+
schema: &PyDict,
24+
_config: Option<&PyDict>,
25+
_build_context: &mut BuildContext,
26+
) -> PyResult<CombinedValidator> {
27+
let py = schema.py();
28+
let class: &PyType = schema.get_as_req(intern!(py, "cls"))?;
29+
30+
let class_repr = match schema.get_as(intern!(py, "cls_repr"))? {
31+
Some(s) => s,
32+
None => class.name()?.to_string(),
33+
};
34+
let name = format!("{}[{}]", Self::EXPECTED_TYPE, class_repr);
35+
Ok(Self {
36+
class: class.into(),
37+
class_repr,
38+
name,
39+
}
40+
.into())
41+
}
42+
}
43+
44+
impl Validator for IsSubclassValidator {
45+
fn validate<'s, 'data>(
46+
&'s self,
47+
py: Python<'data>,
48+
input: &'data impl Input<'data>,
49+
_extra: &Extra,
50+
_slots: &'data [CombinedValidator],
51+
_recursion_guard: &'s mut RecursionGuard,
52+
) -> ValResult<'data, PyObject> {
53+
match input.input_is_subclass(self.class.as_ref(py))? {
54+
true => Ok(input.to_object(py)),
55+
false => Err(ValError::new(
56+
ErrorKind::IsSubclassOf {
57+
class: self.class_repr.clone(),
58+
},
59+
input,
60+
)),
61+
}
62+
}
63+
64+
fn get_name(&self) -> &str {
65+
&self.name
66+
}
67+
}

src/validators/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ mod function;
3232
mod generator;
3333
mod int;
3434
mod is_instance;
35+
mod is_subclass;
3536
mod json;
3637
mod list;
3738
mod literal;
@@ -370,6 +371,7 @@ pub fn build_validator<'a>(
370371
timedelta::TimeDeltaValidator,
371372
// introspection types
372373
is_instance::IsInstanceValidator,
374+
is_subclass::IsSubclassValidator,
373375
callable::CallableValidator,
374376
// arguments
375377
arguments::ArgumentsValidator,
@@ -488,6 +490,7 @@ pub enum CombinedValidator {
488490
Timedelta(timedelta::TimeDeltaValidator),
489491
// introspection types
490492
IsInstance(is_instance::IsInstanceValidator),
493+
IsSubclass(is_subclass::IsSubclassValidator),
491494
Callable(callable::CallableValidator),
492495
// arguments
493496
Arguments(arguments::ArgumentsValidator),

src/validators/new_class.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ impl Validator for NewClassValidator {
7474
recursion_guard: &'s mut RecursionGuard,
7575
) -> ValResult<'data, PyObject> {
7676
let class = self.class.as_ref(py);
77-
if input.is_type(class)? {
77+
if input.is_exact_instance(class)? {
7878
if self.revalidate {
7979
let fields_set = input.get_attr(intern!(py, "__fields_set__"));
8080
let output = self.validator.validate(py, input, extra, slots, recursion_guard)?;

tests/test_errors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def f(input_value, **kwargs):
239239
('time_delta_parsing', 'Input should be a valid timedelta, foobar', {'error': 'foobar'}),
240240
('frozen_set_type', 'Input should be a valid frozenset', None),
241241
('is_instance_of', 'Input should be an instance of Foo', {'class': 'Foo'}),
242+
('is_subclass_of', 'Input should be a subclass of Foo', {'class': 'Foo'}),
242243
('callable_type', 'Input should be callable', None),
243244
(
244245
'union_tag_invalid',
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import pytest
2+
3+
from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema
4+
5+
6+
class Foo:
7+
pass
8+
9+
10+
class Foobar(Foo):
11+
pass
12+
13+
14+
class Bar:
15+
pass
16+
17+
18+
def test_is_subclass_basic():
19+
v = SchemaValidator(core_schema.is_subclass_schema(Foo))
20+
assert v.validate_python(Foo) == Foo
21+
with pytest.raises(ValidationError) as exc_info:
22+
v.validate_python(Bar)
23+
# insert_assert(exc_info.value.errors())
24+
assert exc_info.value.errors() == [
25+
{
26+
'kind': 'is_subclass_of',
27+
'loc': [],
28+
'message': 'Input should be a subclass of Foo',
29+
'input_value': Bar,
30+
'context': {'class': 'Foo'},
31+
}
32+
]
33+
34+
35+
@pytest.mark.parametrize(
36+
'input_value,valid',
37+
[
38+
(Foo, True),
39+
(Foobar, True),
40+
(Bar, False),
41+
(type, False),
42+
(1, False),
43+
('foo', False),
44+
(Foo(), False),
45+
(Foobar(), False),
46+
(Bar(), False),
47+
],
48+
)
49+
def test_is_subclass(input_value, valid):
50+
v = SchemaValidator(core_schema.is_subclass_schema(Foo))
51+
assert v.isinstance_python(input_value) == valid
52+
53+
54+
def test_not_parent():
55+
v = SchemaValidator(core_schema.is_subclass_schema(Foobar))
56+
assert v.isinstance_python(Foobar)
57+
assert not v.isinstance_python(Foo)
58+
59+
60+
def test_invalid_type():
61+
with pytest.raises(SchemaError, match="TypeError: 'Foo' object cannot be converted to 'PyType"):
62+
SchemaValidator(core_schema.is_subclass_schema(Foo()))
63+
64+
65+
def test_custom_repr():
66+
v = SchemaValidator(core_schema.is_subclass_schema(Foo, cls_repr='Spam'))
67+
assert v.validate_python(Foo) == Foo
68+
with pytest.raises(ValidationError) as exc_info:
69+
v.validate_python(Bar)
70+
# insert_assert(exc_info.value.errors())
71+
assert exc_info.value.errors() == [
72+
{
73+
'kind': 'is_subclass_of',
74+
'loc': [],
75+
'message': 'Input should be a subclass of Spam',
76+
'input_value': Bar,
77+
'context': {'class': 'Spam'},
78+
}
79+
]

0 commit comments

Comments
 (0)