@@ -93,6 +93,12 @@ def __repr__(self) -> str:
9393 ...
9494
9595
96+ class FieldSerializationInfo (SerializationInfo , Protocol ):
97+ @property
98+ def field_name (self ) -> str :
99+ ...
100+
101+
96102class ValidationInfo (Protocol ):
97103 """
98104 Argument passed to validation functions.
@@ -109,7 +115,7 @@ def config(self) -> CoreConfig | None:
109115 ...
110116
111117
112- class ModelFieldValidationInfo (ValidationInfo , Protocol ):
118+ class FieldValidationInfo (ValidationInfo , Protocol ):
113119 """
114120 Argument passed to model field validation functions.
115121 """
@@ -166,11 +172,26 @@ def simple_ser_schema(type: ExpectedSerializationTypes) -> SimpleSerSchema:
166172 return SimpleSerSchema (type = type )
167173
168174
169- class SerializePlainFunction (Protocol ): # pragma: no cover
175+ class GeneralSerializePlainFunction (Protocol ): # pragma: no cover
170176 def __call__ (self , __input_value : Any , __info : SerializationInfo ) -> Any :
171177 ...
172178
173179
180+ class FieldSerializePlainFunction (Protocol ): # pragma: no cover
181+ def __call__ (self , __model : Any , __input_value : Any , __info : FieldSerializationInfo ) -> Any :
182+ ...
183+
184+
185+ class GeneralSerializePlainFunctionSchema (TypedDict ):
186+ type : Literal ['general' ]
187+ function : GeneralSerializePlainFunction
188+
189+
190+ class FieldSerializePlainFunctionSchema (TypedDict ):
191+ type : Literal ['field' ]
192+ function : FieldSerializePlainFunction
193+
194+
174195# must match `src/serializers/ob_type.rs::ObType`
175196JsonReturnTypes = Literal [
176197 'int' ,
@@ -212,13 +233,16 @@ def __call__(self, __input_value: Any, __info: SerializationInfo) -> Any:
212233
213234class FunctionPlainSerSchema (TypedDict , total = False ):
214235 type : Required [Literal ['function-plain' ]]
215- function : Required [SerializePlainFunction ]
236+ function : Required [Union [ GeneralSerializePlainFunctionSchema , FieldSerializePlainFunctionSchema ] ]
216237 json_return_type : JsonReturnTypes
217238 when_used : WhenUsed # default: 'always'
218239
219240
220- def function_plain_ser_schema (
221- function : SerializePlainFunction , * , json_return_type : JsonReturnTypes | None = None , when_used : WhenUsed = 'always'
241+ def general_function_plain_ser_schema (
242+ function : GeneralSerializePlainFunction ,
243+ * ,
244+ json_return_type : JsonReturnTypes | None = None ,
245+ when_used : WhenUsed = 'always' ,
222246) -> FunctionPlainSerSchema :
223247 """
224248 Returns a schema for serialization with a function.
@@ -232,7 +256,35 @@ def function_plain_ser_schema(
232256 # just to avoid extra elements in schema, and to use the actual default defined in rust
233257 when_used = None # type: ignore
234258 return dict_not_none (
235- type = 'function-plain' , function = function , json_return_type = json_return_type , when_used = when_used
259+ type = 'function-plain' ,
260+ function = {'type' : 'general' , 'function' : function },
261+ json_return_type = json_return_type ,
262+ when_used = when_used ,
263+ )
264+
265+
266+ def field_function_plain_ser_schema (
267+ function : FieldSerializePlainFunction ,
268+ * ,
269+ json_return_type : JsonReturnTypes | None = None ,
270+ when_used : WhenUsed = 'always' ,
271+ ) -> FunctionPlainSerSchema :
272+ """
273+ Returns a schema to serialize a field from a model, TypedDict or dataclass.
274+
275+ Args:
276+ function: The function to use for serialization
277+ json_return_type: The type that the function returns if `mode='json'`
278+ when_used: When the function should be called
279+ """
280+ if when_used == 'always' :
281+ # just to avoid extra elements in schema, and to use the actual default defined in rust
282+ when_used = None # type: ignore
283+ return dict_not_none (
284+ type = 'function-plain' ,
285+ function = {'type' : 'field' , 'function' : function },
286+ json_return_type = json_return_type ,
287+ when_used = when_used ,
236288 )
237289
238290
@@ -241,21 +293,38 @@ def __call__(self, __input_value: Any, __index_key: int | str | None = None) ->
241293 ...
242294
243295
244- class SerializeWrapFunction (Protocol ): # pragma: no cover
296+ class GeneralSerializeWrapFunction (Protocol ): # pragma: no cover
245297 def __call__ (self , __input_value : Any , __serializer : SerializeWrapHandler , __info : SerializationInfo ) -> Any :
246298 ...
247299
248300
301+ class FieldSerializeWrapFunction (Protocol ): # pragma: no cover
302+ def __call__ (
303+ self , __model : Any , __input_value : Any , __serializer : SerializeWrapHandler , __info : FieldSerializationInfo
304+ ) -> Any :
305+ ...
306+
307+
308+ class GeneralSerializeWrapFunctionSchema (TypedDict ):
309+ type : Literal ['general' ]
310+ function : GeneralSerializeWrapFunction
311+
312+
313+ class FieldSerializeWrapFunctionSchema (TypedDict ):
314+ type : Literal ['field' ]
315+ function : FieldSerializeWrapFunction
316+
317+
249318class FunctionWrapSerSchema (TypedDict , total = False ):
250319 type : Required [Literal ['function-wrap' ]]
251- function : Required [SerializeWrapFunction ]
320+ function : Required [Union [ GeneralSerializeWrapFunctionSchema , FieldSerializeWrapFunctionSchema ] ]
252321 schema : Required [CoreSchema ]
253322 json_return_type : JsonReturnTypes
254323 when_used : WhenUsed # default: 'always'
255324
256325
257- def function_wrap_ser_schema (
258- function : SerializeWrapFunction ,
326+ def general_function_wrap_ser_schema (
327+ function : GeneralSerializeWrapFunction ,
259328 schema : CoreSchema ,
260329 * ,
261330 json_return_type : JsonReturnTypes | None = None ,
@@ -274,7 +343,39 @@ def function_wrap_ser_schema(
274343 # just to avoid extra elements in schema, and to use the actual default defined in rust
275344 when_used = None # type: ignore
276345 return dict_not_none (
277- type = 'function-wrap' , schema = schema , function = function , json_return_type = json_return_type , when_used = when_used
346+ type = 'function-wrap' ,
347+ schema = schema ,
348+ function = {'type' : 'general' , 'function' : function },
349+ json_return_type = json_return_type ,
350+ when_used = when_used ,
351+ )
352+
353+
354+ def field_function_wrap_ser_schema (
355+ function : FieldSerializeWrapFunction ,
356+ schema : CoreSchema ,
357+ * ,
358+ json_return_type : JsonReturnTypes | None = None ,
359+ when_used : WhenUsed = 'always' ,
360+ ) -> FunctionWrapSerSchema :
361+ """
362+ Returns a schema to serialize a field from a model, TypedDict or dataclass.
363+
364+ Args:
365+ function: The function to use for serialization
366+ schema: The schema to use for the inner serialization
367+ json_return_type: The type that the function returns if `mode='json'`
368+ when_used: When the function should be called
369+ """
370+ if when_used == 'always' :
371+ # just to avoid extra elements in schema, and to use the actual default defined in rust
372+ when_used = None # type: ignore
373+ return dict_not_none (
374+ type = 'function-wrap' ,
375+ schema = schema ,
376+ function = {'type' : 'field' , 'function' : function },
377+ json_return_type = json_return_type ,
378+ when_used = when_used ,
278379 )
279380
280381
@@ -290,7 +391,7 @@ def format_ser_schema(formatting_string: str, *, when_used: WhenUsed = 'json-unl
290391
291392 Args:
292393 formatting_string: String defining the format to use
293- when_used: Same meaning as for [function_plain_ser_schema ], but with a different default
394+ when_used: Same meaning as for [general_function_plain_ser_schema ], but with a different default
294395 """
295396 if when_used == 'json-unless-none' :
296397 # just to avoid extra elements in schema, and to use the actual default defined in rust
@@ -308,7 +409,7 @@ def to_string_ser_schema(*, when_used: WhenUsed = 'json-unless-none') -> ToStrin
308409 Returns a schema for serialization using python's `str()` / `__str__` method.
309410
310411 Args:
311- when_used: Same meaning as for [function_plain_ser_schema ], but with a different default
412+ when_used: Same meaning as for [general_function_plain_ser_schema ], but with a different default
312413 """
313414 s = dict (type = 'to-string' )
314415 if when_used != 'json-unless-none' :
@@ -1491,7 +1592,7 @@ def __call__(self, __input_value: Any, __info: ValidationInfo) -> Any: # pragma
14911592
14921593
14931594class FieldValidatorFunction (Protocol ):
1494- def __call__ (self , __input_value : Any , __info : ModelFieldValidationInfo ) -> Any : # pragma: no cover
1595+ def __call__ (self , __input_value : Any , __info : FieldValidationInfo ) -> Any : # pragma: no cover
14951596 ...
14961597
14971598
@@ -1532,7 +1633,7 @@ def field_before_validation_function(
15321633 ```py
15331634 from pydantic_core import SchemaValidator, core_schema
15341635
1535- def fn(v: bytes, info: core_schema.ModelFieldValidationInfo ) -> str:
1636+ def fn(v: bytes, info: core_schema.FieldValidationInfo ) -> str:
15361637 assert info.data is not None
15371638 assert info.field_name is not None
15381639 return v.decode() + 'world'
@@ -1624,7 +1725,7 @@ def field_after_validation_function(
16241725 ```py
16251726 from pydantic_core import SchemaValidator, core_schema
16261727
1627- def fn(v: str, info: core_schema.ModelFieldValidationInfo ) -> str:
1728+ def fn(v: str, info: core_schema.FieldValidationInfo ) -> str:
16281729 assert info.data is not None
16291730 assert info.field_name is not None
16301731 return v + 'world'
@@ -1709,7 +1810,7 @@ def __call__(
17091810
17101811class FieldWrapValidatorFunction (Protocol ):
17111812 def __call__ (
1712- self , __input_value : Any , __validator : CallableValidator , __info : ModelFieldValidationInfo
1813+ self , __input_value : Any , __validator : CallableValidator , __info : FieldValidationInfo
17131814 ) -> Any : # pragma: no cover
17141815 ...
17151816
@@ -1791,7 +1892,7 @@ def field_wrap_validation_function(
17911892 ```py
17921893 from pydantic_core import SchemaValidator, core_schema
17931894
1794- def fn(v: bytes, validator: core_schema.CallableValidator, info: core_schema.ModelFieldValidationInfo ) -> str:
1895+ def fn(v: bytes, validator: core_schema.CallableValidator, info: core_schema.FieldValidationInfo ) -> str:
17951896 assert info.data is not None
17961897 assert info.field_name is not None
17971898 return validator(v) + 'world'
@@ -1881,7 +1982,7 @@ def field_plain_validation_function(
18811982 from typing import Any
18821983 from pydantic_core import SchemaValidator, core_schema
18831984
1884- def fn(v: Any, info: core_schema.ModelFieldValidationInfo ) -> str:
1985+ def fn(v: Any, info: core_schema.FieldValidationInfo ) -> str:
18851986 assert info.data is not None
18861987 assert info.field_name is not None
18871988 return str(v) + 'world'
0 commit comments