Skip to content

Commit

Permalink
wip: dataclass coercion
Browse files Browse the repository at this point in the history
  • Loading branch information
msto committed Mar 23, 2024
1 parent 048a664 commit 475de0d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 10 deletions.
58 changes: 51 additions & 7 deletions dataclass_io/reader.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,62 @@
from csv import DictReader
from dataclasses import fields
from dataclasses import is_dataclass
from pathlib import Path
from types import TracebackType
from typing import Any
from typing import ClassVar
from typing import Protocol
from typing import Type

from dataclass_io.lib import assert_readable_dataclass
from dataclass_io.lib import assert_readable_file


class DataclassInstance(Protocol):
    """
    Structural type for "an instance of some dataclass".

    `DataclassReader` yields instances of whichever dataclass type it was constructed with, but
    that concrete type is only known once the reader is instantiated. This `Protocol` stands in
    for it in the return annotation of `DataclassReader.__next__`.

    https://stackoverflow.com/a/55240861
    """

    # The single attribute this protocol requires of conforming classes.
    __dataclass_fields__: ClassVar[dict[str, Any]]


class DataclassReader:
def __init__(
    self,
    path: Path,
    dataclass_type: type,
    **kwds: Any,
) -> None:
    """
    Open `path` and prepare to read its rows as instances of `dataclass_type`.

    Args:
        path: Path to the file to read.
        dataclass_type: Dataclass type.
        **kwds: Additional keyword arguments forwarded unchanged to `csv.DictReader`
            (e.g. `fieldnames`, `delimiter`).

    Raises:
        FileNotFoundError: If the input file does not exist.
        IsADirectoryError: If the input file path is a directory.
        PermissionError: If the input file is not readable.
        TypeError: If the provided type is not a dataclass.
    """

    assert_readable_file(path)
    assert_readable_dataclass(dataclass_type)

    # NB: Somewhat annoyingly, when this validation is extracted into an external helper,
    # mypy can no longer recognize that `self._dataclass_type` is a dataclass, and complains
    # about the return type on `_row_to_dataclass`.
    #
    # I'm leaving `assert_readable_dataclass` in case we want to extend the definition of what
    # it means to be a valid dataclass, but this is needed here to satisfy type checking.
    if not is_dataclass(dataclass_type):
        raise TypeError(f"The provided type must be a dataclass: {dataclass_type}")

    self._dataclass_type = dataclass_type

    # The `csv` module expects file objects opened with `newline=""` so that it can handle
    # line endings (including newlines embedded in quoted fields) itself.
    self._fin = path.open("r", newline="")
    try:
        self._reader = DictReader(self._fin, **kwds)
    except Exception:
        # Do not leak the open handle if `DictReader` rejects the supplied keyword arguments.
        self._fin.close()
        raise
Expand All @@ -45,7 +75,21 @@ def __exit__(
def __iter__(self) -> "DataclassReader":
return self

def __next__(self) -> dict[str, str]:
def __next__(self) -> DataclassInstance:
row = next(self._reader)
return row

return self._row_to_dataclass(row)

def _row_to_dataclass(self, row: dict[str, str]) -> DataclassInstance:
"""
Convert a row of a CSV file into a dataclass instance.
"""

coerced_values: dict[str, Any] = {}

# Coerce each value in the row to the type of the corresponding field
for field in fields(self._dataclass_type):
value = row[field.name]
coerced_values[field.name] = field.type(value)

return self._dataclass_type(**coerced_values)
6 changes: 3 additions & 3 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def test_reader(tmp_path: Path) -> None:
with open(fpath, "w") as f:
f.write("abc\t1\n")

dictreader_kwargs = {"fieldnames": ["foo", "bar"], "delimiter": "\t"}
with DataclassReader(path=fpath, dc_type=FakeDataclass, **dictreader_kwargs) as reader:
dictreader_kwds = {"fieldnames": ["foo", "bar"], "delimiter": "\t"}
with DataclassReader(path=fpath, dataclass_type=FakeDataclass, **dictreader_kwds) as reader:
rows = [row for row in reader]

assert rows[0] == {"foo": "abc", "bar": "1"}
assert rows[0] == FakeDataclass(foo="abc", bar=1)

0 comments on commit 475de0d

Please sign in to comment.