22
33import sys
44from datetime import date , datetime , time , timedelta
5- from typing import Any , Callable , Dict , List , Optional , Type , Union
5+ from typing import Any , Callable , Dict , List , Optional , Type , Union , overload
66
77if sys .version_info < (3 , 11 ):
88 from typing_extensions import NotRequired , Required
@@ -275,7 +275,7 @@ class LiteralSchema(TypedDict):
275275
276276
277277def literal_schema (* expected : Any , ref : str | None = None ) -> LiteralSchema :
278- return dict_not_none (type = 'literal' , expected = list ( expected ) , ref = ref )
278+ return dict_not_none (type = 'literal' , expected = expected , ref = ref )
279279
280280
281281class IsInstanceSchema (TypedDict ):
@@ -327,8 +327,7 @@ class TuplePositionalSchema(TypedDict, total=False):
327327
328328
329329def tuple_positional_schema (
330- items_schema : List [CoreSchema ],
331- * ,
330+ * items_schema : CoreSchema ,
332331 extra_schema : CoreSchema | None = None ,
333332 strict : bool | None = None ,
334333 ref : str | None = None ,
@@ -349,7 +348,7 @@ class TupleVariableSchema(TypedDict, total=False):
349348
350349
351350def tuple_variable_schema (
352- items_schema : CoreSchema ,
351+ items_schema : CoreSchema | None = None ,
353352 * ,
354353 min_length : int | None = None ,
355354 max_length : int | None = None ,
@@ -427,9 +426,9 @@ class DictSchema(TypedDict, total=False):
427426
428427
429428def dict_schema (
430- * ,
431429 keys_schema : CoreSchema | None = None ,
432430 values_schema : CoreSchema | None = None ,
431+ * ,
433432 min_length : int | None = None ,
434433 max_length : int | None = None ,
435434 strict : bool | None = None ,
@@ -456,16 +455,27 @@ class FunctionSchema(TypedDict):
456455 ref : NotRequired [str ]
457456
458457
459- def function_schema (
460- mode : Literal ['before' , 'after' , 'wrap' ],
461- function : Callable [..., Any ],
462- schema : CoreSchema ,
463- * ,
464- validator_instance : Any | None = None ,
465- ref : str | None = None ,
458+ def function_before_schema (
459+ function : Callable [..., Any ], schema : CoreSchema , * , validator_instance : Any | None = None , ref : str | None = None
466460) -> FunctionSchema :
467461 return dict_not_none (
468- type = 'function' , mode = mode , function = function , schema = schema , validator_instance = validator_instance , ref = ref
462+ type = 'function' , mode = 'before' , function = function , schema = schema , validator_instance = validator_instance , ref = ref
463+ )
464+
465+
466+ def function_after_schema (
467+ function : Callable [..., Any ], schema : CoreSchema , * , validator_instance : Any | None = None , ref : str | None = None
468+ ) -> FunctionSchema :
469+ return dict_not_none (
470+ type = 'function' , mode = 'after' , function = function , schema = schema , validator_instance = validator_instance , ref = ref
471+ )
472+
473+
474+ def function_wrap_schema (
475+ function : Callable [..., Any ], schema : CoreSchema , * , validator_instance : Any | None = None , ref : str | None = None
476+ ) -> FunctionSchema :
477+ return dict_not_none (
478+ type = 'function' , mode = 'wrap' , function = function , schema = schema , validator_instance = validator_instance , ref = ref
469479 )
470480
471481
@@ -496,24 +506,24 @@ class WithDefaultSchema(TypedDict, total=False):
496506 ref : str
497507
498508
509+ Omitted = object ()
510+
511+
499512def with_default_schema (
500513 schema : CoreSchema ,
501514 * ,
502- default : Any | None = None ,
515+ default : Any = Omitted ,
503516 default_factory : Callable [[], Any ] | None = None ,
504517 on_error : Literal ['raise' , 'omit' , 'default' ] | None = None ,
505518 strict : bool | None = None ,
506519 ref : str | None = None ,
507520) -> WithDefaultSchema :
508- return dict_not_none (
509- type = 'default' ,
510- schema = schema ,
511- default = default ,
512- default_factory = default_factory ,
513- on_error = on_error ,
514- strict = strict ,
515- ref = ref ,
521+ s = dict_not_none (
522+ type = 'default' , schema = schema , default_factory = default_factory , on_error = on_error , strict = strict , ref = ref
516523 )
524+ if default is not Omitted :
525+ s ['default' ] = default
526+ return s
517527
518528
519529class NullableSchema (TypedDict , total = False ):
@@ -532,6 +542,14 @@ class CustomError(TypedDict):
532542 message : str
533543
534544
545+ def _custom_error (kind : str | None , message : str | None ) -> CustomError | None :
546+ if kind is None and message is None :
547+ return None
548+ else :
549+ # let schema validation raise the error
550+ return CustomError (kind = kind , message = message ) # type: ignore
551+
552+
535553class UnionSchema (TypedDict , total = False ):
536554 type : Required [Literal ['union' ]]
537555 choices : Required [List [CoreSchema ]]
@@ -540,8 +558,36 @@ class UnionSchema(TypedDict, total=False):
540558 ref : str
541559
542560
543- def union_schema (choices : List [CoreSchema ], * , strict : bool | None = None , ref : str | None = None ) -> UnionSchema :
544- return dict_not_none (type = 'union' , choices = choices , strict = strict , ref = ref )
561+ @overload
562+ def union_schema (
563+ * choices : CoreSchema ,
564+ custom_error_kind : str ,
565+ custom_error_message : str ,
566+ strict : bool | None = None ,
567+ ref : str | None = None ,
568+ ) -> UnionSchema :
569+ ...
570+
571+
572+ @overload
573+ def union_schema (* choices : CoreSchema , strict : bool | None = None , ref : str | None = None ) -> UnionSchema :
574+ ...
575+
576+
577+ def union_schema (
578+ * choices : CoreSchema ,
579+ custom_error_kind : str | None = None ,
580+ custom_error_message : str | None = None ,
581+ strict : bool | None = None ,
582+ ref : str | None = None ,
583+ ) -> UnionSchema :
584+ return dict_not_none (
585+ type = 'union' ,
586+ choices = choices ,
587+ custom_error = _custom_error (custom_error_kind , custom_error_message ),
588+ strict = strict ,
589+ ref = ref ,
590+ )
545591
546592
547593class TaggedUnionSchema (TypedDict ):
@@ -553,14 +599,47 @@ class TaggedUnionSchema(TypedDict):
553599 ref : NotRequired [str ]
554600
555601
602+ @overload
603+ def tagged_union_schema (
604+ choices : Dict [str , CoreSchema ],
605+ discriminator : str | list [str | int ] | list [list [str | int ]] | Callable [[Any ], str | None ],
606+ * ,
607+ custom_error_kind : str ,
608+ custom_error_message : str ,
609+ strict : bool | None = None ,
610+ ref : str | None = None ,
611+ ) -> TaggedUnionSchema :
612+ ...
613+
614+
615+ @overload
556616def tagged_union_schema (
557617 choices : Dict [str , CoreSchema ],
558618 discriminator : str | list [str | int ] | list [list [str | int ]] | Callable [[Any ], str | None ],
559619 * ,
560620 strict : bool | None = None ,
561621 ref : str | None = None ,
562622) -> TaggedUnionSchema :
563- return dict_not_none (type = 'tagged-union' , choices = choices , discriminator = discriminator , strict = strict , ref = ref )
623+ ...
624+
625+
626+ def tagged_union_schema (
627+ choices : Dict [str , CoreSchema ],
628+ discriminator : str | list [str | int ] | list [list [str | int ]] | Callable [[Any ], str | None ],
629+ * ,
630+ custom_error_kind : str | None = None ,
631+ custom_error_message : str | None = None ,
632+ strict : bool | None = None ,
633+ ref : str | None = None ,
634+ ) -> TaggedUnionSchema :
635+ return dict_not_none (
636+ type = 'tagged-union' ,
637+ choices = choices ,
638+ discriminator = discriminator ,
639+ custom_error = _custom_error (custom_error_kind , custom_error_message ),
640+ strict = strict ,
641+ ref = ref ,
642+ )
564643
565644
566645class ChainSchema (TypedDict ):
@@ -569,7 +648,7 @@ class ChainSchema(TypedDict):
569648 ref : NotRequired [str ]
570649
571650
572- def chain_schema (steps : List [ CoreSchema ], * , ref : str | None = None ) -> ChainSchema :
651+ def chain_schema (* steps : CoreSchema , ref : str | None = None ) -> ChainSchema :
573652 return dict_not_none (type = 'chain' , steps = steps , ref = ref )
574653
575654
@@ -681,16 +760,15 @@ class ArgumentsSchema(TypedDict, total=False):
681760
682761
683762def arguments_schema (
684- arguments_schema : List [ArgumentsParameter ],
685- * ,
763+ * arguments : ArgumentsParameter ,
686764 populate_by_name : bool | None = None ,
687765 var_args_schema : CoreSchema | None = None ,
688766 var_kwargs_schema : CoreSchema | None = None ,
689767 ref : str | None = None ,
690768) -> ArgumentsSchema :
691769 return dict_not_none (
692770 type = 'arguments' ,
693- arguments_schema = arguments_schema ,
771+ arguments_schema = arguments ,
694772 populate_by_name = populate_by_name ,
695773 var_args_schema = var_args_schema ,
696774 var_kwargs_schema = var_kwargs_schema ,
@@ -700,21 +778,21 @@ def arguments_schema(
700778
701779class CallSchema (TypedDict ):
702780 type : Literal ['call' ]
703- function : Callable [..., Any ]
704781 arguments_schema : CoreSchema
782+ function : Callable [..., Any ]
705783 return_schema : NotRequired [CoreSchema ]
706784 ref : NotRequired [str ]
707785
708786
709787def call_schema (
788+ arguments : CoreSchema ,
710789 function : Callable [..., Any ],
711- arguments_schema : CoreSchema ,
712790 * ,
713791 return_schema : CoreSchema | None = None ,
714792 ref : str | None = None ,
715793) -> CallSchema :
716794 return dict_not_none (
717- type = 'call' , function = function , arguments_schema = arguments_schema , return_schema = return_schema , ref = ref
795+ type = 'call' , arguments_schema = arguments , function = function , return_schema = return_schema , ref = ref
718796 )
719797
720798
0 commit comments