Skip to content

Commit

Permalink
feat: support frozen dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
msto committed Oct 16, 2024
1 parent b0f32e7 commit ca1cee2
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 23 deletions.
14 changes: 13 additions & 1 deletion dataclass_io/_lib/dataclass_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,23 @@ def row_to_dataclass(
# version of the dataclass with validation. We instantiate from this version to take
# advantage of pydantic's validation, but then unpack the validated data in order to return
# an instance of the user-specified dataclass.
pydantic_cls = pydantic_dataclass(dataclass_type)

params = dataclass_type.__dataclass_params__ # type:ignore[attr-defined]

pydantic_cls = pydantic_dataclass(
_cls=dataclass_type,
repr=params.repr,
eq=params.eq,
order=params.order,
unsafe_hash=params.unsafe_hash,
frozen=params.frozen,
)

validated_data = pydantic_cls(**row)
unpacked_data = {
field.name: getattr(validated_data, field.name) for field in fields(dataclass_type)
}

data = dataclass_type(**unpacked_data)

return data
43 changes: 21 additions & 22 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,36 @@
from dataclasses import dataclass
from pathlib import Path
from typing import cast

from dataclass_io.reader import DataclassReader

import pytest

@dataclass(kw_only=True, eq=True)
class FakeDataclass:
foo: str
bar: int
from dataclass_io.reader import DataclassReader


def test_reader(tmp_path: Path) -> None:
@pytest.mark.parametrize("kw_only", [True, False])
@pytest.mark.parametrize("eq", [True, False])
@pytest.mark.parametrize("frozen", [True, False])
def test_reader(kw_only: bool, eq: bool, frozen: bool, tmp_path: Path) -> None:
fpath = tmp_path / "test.txt"

@dataclass(frozen=frozen, eq=eq, kw_only=kw_only) # type: ignore[literal-required]
class FakeDataclass:
foo: str
bar: int

with fpath.open("w") as f:
f.write("foo\tbar\n")
f.write("abc\t1\n")

rows: list[FakeDataclass]
with DataclassReader.open(filename=fpath, dataclass_type=FakeDataclass) as reader:
rows = [row for row in reader]

assert rows[0] == FakeDataclass(foo="abc", bar=1)


def test_reader_from_str(tmp_path: Path) -> None:
"""Test that we can create a reader when `filename` is a `str`."""
fpath = tmp_path / "test.txt"

with fpath.open("w") as f:
f.write("foo\tbar\n")
f.write("abc\t1\n")
# TODO make `DataclassReader` generic
rows = cast(list[FakeDataclass], [row for row in reader])

with DataclassReader.open(filename=str(fpath), dataclass_type=FakeDataclass) as reader:
rows = [row for row in reader]
assert len(rows) == 1

assert rows[0] == FakeDataclass(foo="abc", bar=1)
if eq:
assert rows[0] == FakeDataclass(foo="abc", bar=1)
else:
assert rows[0].foo == "abc"
assert rows[0].bar == 1

0 comments on commit ca1cee2

Please sign in to comment.