Skip to content

Commit

Permalink
refactor: accept file handle instead of path
Browse files: browse the repository at this point in the history
  • Loading branch information
msto committed Jul 16, 2024
1 parent 5a28324 commit c5e5058
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 133 deletions.
24 changes: 16 additions & 8 deletions dataclass_io/_lib/assertions.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,8 +7,8 @@

from dataclass_io._lib.dataclass_extensions import DataclassInstance
from dataclass_io._lib.dataclass_extensions import fieldnames
from dataclass_io._lib.file import FileFormat
from dataclass_io._lib.file import FileHeader
from dataclass_io._lib.file import ReadableFileHandle
from dataclass_io._lib.file import get_header


Expand Down Expand Up @@ -96,26 +96,34 @@ def assert_file_is_appendable(


def assert_file_header_matches_dataclass(
path: Path,
file: Path | ReadableFileHandle,
dataclass_type: type[DataclassInstance],
file_format: FileFormat,
delimiter: str,
comment_prefix: str,
) -> None:
"""
Check that the specified file has a header and its fields match those of the provided dataclass.
"""
with path.open("r") as fin:
header: FileHeader = get_header(fin, file_format=file_format)
header: FileHeader | None
if isinstance(file, Path):
with file.open("r") as fin:
header = get_header(fin, delimiter=delimiter, comment_prefix=comment_prefix)
else:
pos = file.tell()
try:
header = get_header(file, delimiter=delimiter, comment_prefix=comment_prefix)
finally:
file.seek(pos)

if header is None:
raise ValueError(f"Could not find a header in the provided file: {path}")
raise ValueError("Could not find a header in the provided file")

if header.fieldnames != fieldnames(dataclass_type):
raise ValueError(
"The provided file does not have the same field names as the provided dataclass:\n"
f"\tDataclass: {dataclass_type.__name__}\n"
f"\tFile: {path}\n"
f"\tDataclass fields: {', '.join(fieldnames(dataclass_type))}\n"
f"\tFile: {', '.join(header.fieldnames)}\n"
f"\tFile fields: {', '.join(header.fieldnames)}\n"
)


Expand Down
29 changes: 9 additions & 20 deletions dataclass_io/_lib/file.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,13 +2,14 @@
from enum import Enum
from enum import unique
from io import TextIOWrapper
from typing import IO
from typing import Optional
from typing import TextIO
from typing import TypeAlias

ReadableFileHandle: TypeAlias = TextIOWrapper | IO | TextIO
WritableFileHandle: TypeAlias = TextIOWrapper | IO | TextIO
ReadableFileHandle: TypeAlias = TextIOWrapper
"""A file handle open for reading."""

WritableFileHandle: TypeAlias = TextIOWrapper
"""A file handle open for writing."""


@unique
Expand Down Expand Up @@ -44,19 +45,6 @@ def __init__(self, _: str, abbreviation: str = None):
"""Append to an existing file."""


@dataclass(kw_only=True)
class FileFormat:
"""
Parameters describing the format and configuration of the dataclass file.
Most of these parameters, if specified, are passed through to `csv.DictReader`/`csv.DictWriter`
or `csv.reader`/`csv.writer`.
"""

delimiter: str = "\t"
comment: str = "#"


@dataclass(frozen=True, kw_only=True)
class FileHeader:
"""
Expand All @@ -76,7 +64,8 @@ class FileHeader:

def get_header(
reader: ReadableFileHandle,
file_format: FileFormat,
delimiter: str,
comment_prefix: str,
) -> Optional[FileHeader]:
"""
Read the header from an open file.
Expand All @@ -103,13 +92,13 @@ def get_header(
preface: list[str] = []

for line in reader:
if line.startswith(file_format.comment) or line.strip() == "":
if line.startswith(comment_prefix) or line.strip() == "":
preface.append(line.strip())
else:
break
else:
return None

fieldnames = line.strip().split(file_format.delimiter)
fieldnames = line.strip().split(delimiter)

return FileHeader(preface=preface, fieldnames=fieldnames)
72 changes: 33 additions & 39 deletions dataclass_io/reader.py
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,15 @@
from contextlib import contextmanager
from csv import DictReader
from pathlib import Path
from types import TracebackType
from typing import Any
from typing import Type
from typing import Iterator

from dataclass_io._lib.assertions import assert_dataclass_is_valid
from dataclass_io._lib.assertions import assert_file_header_matches_dataclass
from dataclass_io._lib.assertions import assert_file_is_readable
from dataclass_io._lib.dataclass_extensions import DataclassInstance
from dataclass_io._lib.dataclass_extensions import fieldnames
from dataclass_io._lib.dataclass_extensions import row_to_dataclass
from dataclass_io._lib.file import FileFormat
from dataclass_io._lib.file import FileHeader
from dataclass_io._lib.file import ReadableFileHandle
from dataclass_io._lib.file import get_header
Expand All @@ -24,58 +23,42 @@ class DataclassReader:

def __init__(
self,
filename: str | Path,
fin: ReadableFileHandle,
dataclass_type: type[DataclassInstance],
delimiter: str = "\t",
comment: str = "#",
comment_prefix: str = "#",
**kwds: Any,
) -> None:
"""
Args:
path: Path to the file to read.
fin: Open file handle for reading.
dataclass_type: Dataclass type.
delimiter: The input file delimiter.
comment_prefix: The prefix for any comment/preface rows preceding the header row.
dataclass_type: Dataclass type.
Raises:
FileNotFoundError: If the input file does not exist.
IsADirectoryError: If the input file path is a directory.
PermissionError: If the input file is not readable.
TypeError: If the provided type is not a dataclass.
"""

filepath: Path = filename if isinstance(filename, Path) else Path(filename)
file_format = FileFormat(
assert_dataclass_is_valid(dataclass_type)
assert_file_header_matches_dataclass(
file=fin,
dataclass_type=dataclass_type,
delimiter=delimiter,
comment=comment,
comment_prefix=comment_prefix,
)

assert_dataclass_is_valid(dataclass_type)
assert_file_is_readable(filepath)
assert_file_header_matches_dataclass(filepath, dataclass_type, file_format)

self._dataclass_type = dataclass_type
self._fin = filepath.open("r")
self._header = get_header(reader=self._fin, file_format=file_format)
self._fin = fin
self._header = get_header(
reader=self._fin, delimiter=delimiter, comment_prefix=comment_prefix
)
self._reader = DictReader(
f=self._fin,
fieldnames=fieldnames(dataclass_type),
delimiter=file_format.delimiter,
delimiter=delimiter,
)

def __enter__(self) -> "DataclassReader":
return self

def __exit__(
self,
exc_type: Type[BaseException],
exc_value: BaseException,
traceback: TracebackType,
) -> None:
self.close()

def close(self) -> None:
"""Close the reader."""
self._fin.close()

def __iter__(self) -> "DataclassReader":
return self

Expand All @@ -90,29 +73,40 @@ def open(
cls,
filename: str | Path,
dataclass_type: type[DataclassInstance],
comment_prefix: str = DEFAULT_COMMENT_PREFIX,
delimiter: str = "\t",
comment_prefix: str = "#",
) -> Iterator["DataclassReader"]:
"""
Open a new `DataclassReader` from a file path.
Args:
filename: The path to the file from which dataclass instances will be read.
dataclass_type: The dataclass type to read from file.
delimiter: The input file delimiter.
comment_prefix: The prefix for any comment/preface rows preceding the header row. These
rows will be ignored when reading the file.
Yields:
A `DataclassReader` instance.
Raises:
FileNotFoundError: If the input file does not exist.
IsADirectoryError: If the input file path is a directory.
PermissionError: If the input file is not readable.
"""
filepath: Path = Path(filename)

assert_dataclass_is_valid(dataclass_type)
# NB: The `DataclassReader` constructor will validate that the provided type is a valid
# dataclass and that the file's header matches the fields of the provided dataclass type.
assert_file_is_readable(filepath)
assert_file_header_matches_dataclass(filepath, dataclass_type, comment_prefix=comment_prefix)

fin = filepath.open("r")
try:
yield cls(
fin=fin,
dataclass_type=dataclass_type,
delimiter=delimiter,
comment=comment,
comment_prefix=comment_prefix,
)
finally:
fin.close()
Loading

0 comments on commit c5e5058

Please sign in to comment.