Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Instantiate reader/writer from file handle instead of path #16

Merged
merged 4 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions dataclass_io/_lib/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from dataclass_io._lib.dataclass_extensions import DataclassInstance
from dataclass_io._lib.dataclass_extensions import fieldnames
from dataclass_io._lib.file import FileFormat
from dataclass_io._lib.file import FileHeader
from dataclass_io._lib.file import ReadableFileHandle
from dataclass_io._lib.file import get_header


Expand Down Expand Up @@ -96,26 +96,34 @@ def assert_file_is_appendable(


def assert_file_header_matches_dataclass(
path: Path,
file: Path | ReadableFileHandle,
dataclass_type: type[DataclassInstance],
file_format: FileFormat,
delimiter: str,
comment_prefix: str,
) -> None:
"""
Check that the specified file has a header and its fields match those of the provided dataclass.
"""
with path.open("r") as fin:
header: FileHeader = get_header(fin, file_format=file_format)
header: FileHeader | None
if isinstance(file, Path):
with file.open("r") as fin:
header = get_header(fin, delimiter=delimiter, comment_prefix=comment_prefix)
else:
pos = file.tell()
try:
header = get_header(file, delimiter=delimiter, comment_prefix=comment_prefix)
finally:
file.seek(pos)

if header is None:
raise ValueError(f"Could not find a header in the provided file: {path}")
raise ValueError("Could not find a header in the provided file")

if header.fieldnames != fieldnames(dataclass_type):
raise ValueError(
"The provided file does not have the same field names as the provided dataclass:\n"
f"\tDataclass: {dataclass_type.__name__}\n"
f"\tFile: {path}\n"
f"\tDataclass fields: {', '.join(fieldnames(dataclass_type))}\n"
f"\tFile: {', '.join(header.fieldnames)}\n"
f"\tFile fields: {', '.join(header.fieldnames)}\n"
)


Expand Down
31 changes: 12 additions & 19 deletions dataclass_io/_lib/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
from enum import unique
from io import TextIOWrapper
from typing import IO
from typing import Any
from typing import Optional
from typing import TextIO
from typing import TypeAlias

ReadableFileHandle: TypeAlias = TextIOWrapper | IO | TextIO
WritableFileHandle: TypeAlias = TextIOWrapper | IO | TextIO
ReadableFileHandle: TypeAlias = TextIOWrapper
"""A file handle open for reading."""

WritableFileHandle: TypeAlias = TextIOWrapper | IO[Any]
"""A file handle open for writing."""


@unique
Expand Down Expand Up @@ -44,19 +47,6 @@ def __init__(self, _: str, abbreviation: str = None):
"""Append to an existing file."""


@dataclass(kw_only=True)
class FileFormat:
"""
Parameters describing the format and configuration of the dataclass file.

Most of these parameters, if specified, are passed through to `csv.DictReader`/`csv.DictWriter`
or `csv.reader`/`csv.writer`.
"""

delimiter: str = "\t"
comment: str = "#"


@dataclass(frozen=True, kw_only=True)
class FileHeader:
"""
Expand All @@ -76,7 +66,8 @@ class FileHeader:

def get_header(
reader: ReadableFileHandle,
file_format: FileFormat,
delimiter: str,
comment_prefix: str,
) -> Optional[FileHeader]:
"""
Read the header from an open file.
Expand All @@ -100,16 +91,18 @@ def get_header(
None if the file was empty or contained only comments or empty lines.
"""

# TODO: optionally reset file handle to the original position after reading the header

preface: list[str] = []

for line in reader:
if line.startswith(file_format.comment) or line.strip() == "":
if line.startswith(comment_prefix) or line.strip() == "":
preface.append(line.strip())
else:
break
else:
return None

fieldnames = line.strip().split(file_format.delimiter)
fieldnames = line.strip().split(delimiter)

return FileHeader(preface=preface, fieldnames=fieldnames)
97 changes: 62 additions & 35 deletions dataclass_io/reader.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from contextlib import contextmanager
from csv import DictReader
from pathlib import Path
from types import TracebackType
from typing import Any
from typing import Type
from typing import Iterator

from dataclass_io._lib.assertions import assert_dataclass_is_valid
from dataclass_io._lib.assertions import assert_file_header_matches_dataclass
from dataclass_io._lib.assertions import assert_file_is_readable
from dataclass_io._lib.dataclass_extensions import DataclassInstance
from dataclass_io._lib.dataclass_extensions import fieldnames
from dataclass_io._lib.dataclass_extensions import row_to_dataclass
from dataclass_io._lib.file import FileFormat
from dataclass_io._lib.file import FileHeader
from dataclass_io._lib.file import ReadableFileHandle
from dataclass_io._lib.file import get_header
Expand All @@ -24,62 +23,90 @@ class DataclassReader:

def __init__(
self,
filename: str | Path,
fin: ReadableFileHandle,
dataclass_type: type[DataclassInstance],
delimiter: str = "\t",
comment: str = "#",
comment_prefix: str = "#",
**kwds: Any,
) -> None:
"""
Args:
path: Path to the file to read.
fin: Open file handle for reading.
dataclass_type: Dataclass type.
delimiter: The input file delimiter.
comment_prefix: The prefix for any comment/preface rows preceding the header row.
dataclass_type: Dataclass type.

Raises:
FileNotFoundError: If the input file does not exist.
IsADirectoryError: If the input file path is a directory.
PermissionError: If the input file is not readable.
TypeError: If the provided type is not a dataclass.
"""

filepath: Path = filename if isinstance(filename, Path) else Path(filename)
file_format = FileFormat(
assert_dataclass_is_valid(dataclass_type)
assert_file_header_matches_dataclass(
file=fin,
dataclass_type=dataclass_type,
delimiter=delimiter,
comment=comment,
comment_prefix=comment_prefix,
)

assert_dataclass_is_valid(dataclass_type)
assert_file_is_readable(filepath)
assert_file_header_matches_dataclass(filepath, dataclass_type, file_format)

self._dataclass_type = dataclass_type
self._fin = filepath.open("r")
self._header = get_header(reader=self._fin, file_format=file_format)
self._fin = fin
self._header = get_header(
reader=self._fin, delimiter=delimiter, comment_prefix=comment_prefix
)
self._reader = DictReader(
f=self._fin,
fieldnames=fieldnames(dataclass_type),
delimiter=file_format.delimiter,
delimiter=delimiter,
)

def __enter__(self) -> "DataclassReader":
return self

def __exit__(
self,
exc_type: Type[BaseException],
exc_value: BaseException,
traceback: TracebackType,
) -> None:
self.close()

def close(self) -> None:
"""Close the reader."""
self._fin.close()

def __iter__(self) -> "DataclassReader":
return self

def __next__(self) -> DataclassInstance:
row = next(self._reader)

return row_to_dataclass(row, self._dataclass_type)

@classmethod
@contextmanager
def open(
cls,
filename: str | Path,
dataclass_type: type[DataclassInstance],
delimiter: str = "\t",
comment_prefix: str = "#",
) -> Iterator["DataclassReader"]:
"""
Open a new `DataclassReader` from a file path.

Args:
filename: The path to the file from which dataclass instances will be read.
dataclass_type: The dataclass type to read from file.
delimiter: The input file delimiter.
comment_prefix: The prefix for any comment/preface rows preceding the header row. These
rows will be ignored when reading the file.

Yields:
A `DataclassReader` instance.

Raises:
FileNotFoundError: If the input file does not exist.
IsADirectoryError: If the input file path is a directory.
PermissionError: If the input file is not readable.
"""
filepath: Path = Path(filename)

# NB: The `DataclassReader` constructor will validate that the provided type is a valid
# dataclass and that the file's header matches the fields of the provided dataclass type.
assert_file_is_readable(filepath)

fin = filepath.open("r")
try:
yield cls(
fin=fin,
dataclass_type=dataclass_type,
delimiter=delimiter,
comment_prefix=comment_prefix,
)
finally:
fin.close()
Loading