Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add forbid_unknown_fields as optional bool argument to all decoders. #796

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions msgspec/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def convert(
type: Type[T],
*,
strict: bool = True,
forbid_unknown_fields: bool = False,
from_attributes: bool = False,
dec_hook: Optional[Callable[[type, Any], Any]] = None,
builtin_types: Union[Iterable[type], None] = None,
Expand All @@ -171,6 +172,7 @@ def convert(
type: Any,
*,
strict: bool = True,
forbid_unknown_fields: bool = False,
from_attributes: bool = False,
dec_hook: Optional[Callable[[type, Any], Any]] = None,
builtin_types: Union[Iterable[type], None] = None,
Expand Down
167 changes: 139 additions & 28 deletions msgspec/_core.c

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions msgspec/json.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Encoder:
class Decoder(Generic[T]):
type: Type[T]
strict: bool
forbid_unknown_fields: bool
dec_hook: dec_hook_sig
float_hook: float_hook_sig

Expand All @@ -54,6 +55,7 @@ class Decoder(Generic[T]):
self: Decoder[Any],
*,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: dec_hook_sig = None,
float_hook: float_hook_sig = None,
) -> None: ...
Expand All @@ -63,6 +65,7 @@ class Decoder(Generic[T]):
type: Type[T] = ...,
*,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: dec_hook_sig = None,
float_hook: float_hook_sig = None,
) -> None: ...
Expand All @@ -72,6 +75,7 @@ class Decoder(Generic[T]):
type: Any = ...,
*,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: dec_hook_sig = None,
float_hook: float_hook_sig = None,
) -> None: ...
Expand All @@ -84,6 +88,7 @@ def decode(
/,
*,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: dec_hook_sig = None,
) -> Any: ...
@overload
Expand All @@ -93,6 +98,7 @@ def decode(
*,
type: Type[T] = ...,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: dec_hook_sig = None,
) -> T: ...
@overload
Expand All @@ -102,6 +108,7 @@ def decode(
*,
type: Any = ...,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: dec_hook_sig = None,
) -> Any: ...
def encode(obj: Any, /, *, enc_hook: enc_hook_sig = None, order: Literal[None, "deterministic", "sorted"] = None) -> bytes: ...
Expand Down
7 changes: 7 additions & 0 deletions msgspec/msgpack.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ class Ext:
class Decoder(Generic[T]):
type: Type[T]
strict: bool
forbid_unknown_fields: bool
dec_hook: dec_hook_sig
ext_hook: ext_hook_sig
@overload
def __init__(
self: Decoder[Any],
*,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: dec_hook_sig = None,
ext_hook: ext_hook_sig = None,
) -> None: ...
Expand All @@ -45,6 +47,7 @@ class Decoder(Generic[T]):
type: Type[T] = ...,
*,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: dec_hook_sig = None,
ext_hook: ext_hook_sig = None,
) -> None: ...
Expand All @@ -54,6 +57,7 @@ class Decoder(Generic[T]):
type: Any = ...,
*,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: dec_hook_sig = None,
ext_hook: ext_hook_sig = None,
) -> None: ...
Expand Down Expand Up @@ -83,6 +87,7 @@ def decode(
/,
*,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: dec_hook_sig = None,
ext_hook: ext_hook_sig = None,
) -> Any: ...
Expand All @@ -93,6 +98,7 @@ def decode(
*,
type: Type[T] = ...,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: dec_hook_sig = None,
ext_hook: ext_hook_sig = None,
) -> T: ...
Expand All @@ -103,6 +109,7 @@ def decode(
*,
type: Any = ...,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: dec_hook_sig = None,
ext_hook: ext_hook_sig = None,
) -> Any: ...
Expand Down
11 changes: 10 additions & 1 deletion msgspec/toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def decode(
buf: Union[Buffer, str],
*,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: Optional[Callable[[type, Any], Any]] = None,
) -> Any:
pass
Expand All @@ -124,6 +125,7 @@ def decode(
*,
type: Type[T] = ...,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: Optional[Callable[[type, Any], Any]] = None,
) -> T:
pass
Expand All @@ -135,12 +137,13 @@ def decode(
*,
type: Any = ...,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: Optional[Callable[[type, Any], Any]] = None,
) -> Any:
pass


def decode(buf, *, type=Any, strict=True, dec_hook=None):
def decode(buf, *, type=Any, strict=True, forbid_unknown_fields=False, dec_hook=None):
"""Deserialize an object from TOML.

Parameters
Expand All @@ -156,6 +159,11 @@ def decode(buf, *, type=Any, strict=True, dec_hook=None):
Whether type coercion rules should be strict. Setting to False enables
a wider set of coercion rules from string to non-string types for all
values. Default is True.
forbid_unknown_fields : bool, optional
If True, an error is raised if an unknown field is encountered at any point
in the decoding process. If False (the default), no error is raised and the
unknown field is skipped, unless the unknown field is for a Struct with
``forbid_unknown_fields=True``.
dec_hook : callable, optional
An optional callback for handling decoding custom types. Should have
the signature ``dec_hook(type: Type, obj: Any) -> Any``, where ``type``
Expand Down Expand Up @@ -193,5 +201,6 @@ def decode(buf, *, type=Any, strict=True, dec_hook=None):
builtin_types=(_datetime.datetime, _datetime.date, _datetime.time),
str_keys=True,
strict=strict,
forbid_unknown_fields=forbid_unknown_fields,
dec_hook=dec_hook,
)
10 changes: 9 additions & 1 deletion msgspec/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def decode(
buf: Union[Buffer, str],
*,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: Optional[Callable[[type, Any], Any]] = None,
) -> Any:
pass
Expand All @@ -128,12 +129,13 @@ def decode(
*,
type: Any = ...,
strict: bool = True,
forbid_unknown_fields: bool = False,
dec_hook: Optional[Callable[[type, Any], Any]] = None,
) -> Any:
pass


def decode(buf, *, type=Any, strict=True, dec_hook=None):
def decode(buf, *, type=Any, strict=True, forbid_unknown_fields=False, dec_hook=None):
"""Deserialize an object from YAML.

Parameters
Expand All @@ -149,6 +151,11 @@ def decode(buf, *, type=Any, strict=True, dec_hook=None):
Whether type coercion rules should be strict. Setting to False enables
a wider set of coercion rules from string to non-string types for all
values. Default is True.
forbid_unknown_fields : bool, optional
If True, an error is raised if an unknown field is encountered at any point
in the decoding process. If False (the default), no error is raised and the
unknown field is skipped, unless the unknown field is for a Struct with
``forbid_unknown_fields=True``.
dec_hook : callable, optional
An optional callback for handling decoding custom types. Should have
the signature ``dec_hook(type: Type, obj: Any) -> Any``, where ``type``
Expand Down Expand Up @@ -188,5 +195,6 @@ def decode(buf, *, type=Any, strict=True, dec_hook=None):
type,
builtin_types=(_datetime.datetime, _datetime.date),
strict=strict,
forbid_unknown_fields=forbid_unknown_fields,
dec_hook=dec_hook,
)
64 changes: 51 additions & 13 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1809,31 +1809,49 @@ class Test(Struct, omit_defaults=True, array_like=True):


class TestStructForbidUnknownFields:
def test_forbid_unknown_fields(self, proto):
class Test(Struct, forbid_unknown_fields=True):
@pytest.mark.parametrize("forbid_on_cls", [True, False])
def test_forbid_unknown_fields(self, proto, forbid_on_cls):
class Test(Struct, forbid_unknown_fields=forbid_on_cls):
x: int
y: int

good = Test(1, 2)
assert proto.decode(proto.encode(good), type=Test) == good
assert (
proto.decode(
proto.encode(good), type=Test, forbid_unknown_fields=not forbid_on_cls
)
== good
)

bad = proto.encode({"x": 1, "y": 2, "z": 3})
with pytest.raises(ValidationError, match="Object contains unknown field `z`"):
proto.decode(bad, type=Test)
proto.decode(
bad,
type=Test,
forbid_unknown_fields=not forbid_on_cls,
)

def test_forbid_unknown_fields_array_like(self, proto):
class Test(Struct, forbid_unknown_fields=True, array_like=True):
@pytest.mark.parametrize("forbid_on_cls", [True, False])
def test_forbid_unknown_fields_array_like(self, proto, forbid_on_cls):
class Test(Struct, forbid_unknown_fields=forbid_on_cls, array_like=True):
x: int
y: int

good = Test(1, 2)
assert proto.decode(proto.encode(good), type=Test) == good
assert (
proto.decode(
proto.encode(good),
type=Test,
forbid_unknown_fields=not forbid_on_cls,
)
== good
)

bad = proto.encode([1, 2, 3])
with pytest.raises(
ValidationError, match="Expected `array` of at most length 2"
):
proto.decode(bad, type=Test)
proto.decode(bad, type=Test, forbid_unknown_fields=not forbid_on_cls)


class PointUpper(Struct, rename="upper"):
Expand Down Expand Up @@ -2066,6 +2084,17 @@ class Ex(TypedDict):
dec.decode(msg)
assert "Object missing required field `b`" == str(rec.value)

def test_forbid_unknown_fields(self, proto):
    """Check decoder-level ``forbid_unknown_fields`` against a TypedDict.

    The encoded message carries a key ``c`` that ``Ex`` does not declare.
    With ``forbid_unknown_fields=True`` decoding is expected to raise a
    ``ValidationError``; with the default it is expected to succeed and
    return only the declared keys.
    """

    class Ex(TypedDict):
        a: int
        b: str

    msg = proto.encode({"a": 1, "b": "two", "c": 3})

    with pytest.raises(ValidationError, match="Object contains unknown field"):
        proto.decode(msg, type=Ex, forbid_unknown_fields=True)

    # This comparison was previously a bare expression whose result was
    # discarded, so the default (non-forbidding) path was never verified.
    assert proto.decode(msg, type=Ex) == {"a": 1, "b": "two"}

def test_total_false(self, proto):
class Ex(TypedDict, total=False):
a: int
Expand Down Expand Up @@ -2672,6 +2701,9 @@ class Sub(Base):
assert proto.decode(msg, type=Base) == Base(1)
assert proto.decode(msg, type=Sub) == Sub(1, 2)

with pytest.raises(ValidationError, match="Object contains unknown field `y`"):
proto.decode(msg, type=Base, forbid_unknown_fields=True)

def test_multiple_dataclasses_errors(self, proto):
@dataclass
class Ex1:
Expand Down Expand Up @@ -2749,7 +2781,8 @@ class Ex:
proto.Decoder(mod.Ex)

@pytest.mark.parametrize("slots", [False, True])
def test_decode_dataclass(self, proto, slots):
@pytest.mark.parametrize("forbid_unknown_fields", [False, True])
def test_decode_dataclass(self, proto, slots, forbid_unknown_fields):
if slots:
if not PY310:
pytest.skip(reason="Python 3.10+ required")
Expand All @@ -2763,16 +2796,21 @@ class Example:
b: int
c: int

dec = proto.Decoder(Example)
dec = proto.Decoder(Example, forbid_unknown_fields=forbid_unknown_fields)
msg = Example(1, 2, 3)
res = dec.decode(proto.encode(msg))
assert res == msg

# Extra fields are ignored unless forbid_unknown_fields is set, in which case decoding raises
res = dec.decode(
proto.encode({"x": -1, "a": 1, "y": -2, "b": 2, "z": -3, "c": 3, "": -4})
encoded = proto.encode(
{"x": -1, "a": 1, "y": -2, "b": 2, "z": -3, "c": 3, "": -4}
)
assert res == msg
if forbid_unknown_fields:
with pytest.raises(ValidationError, match="Object contains unknown field"):
dec.decode(encoded)
else:
res = dec.decode(encoded)
assert res == msg

# Missing fields error
with pytest.raises(ValidationError, match="missing required field `b`"):
Expand Down
Loading