diff --git a/dataclass_io/_lib/assertions.py b/dataclass_io/_lib/assertions.py index d18f05b..36dac4c 100644 --- a/dataclass_io/_lib/assertions.py +++ b/dataclass_io/_lib/assertions.py @@ -7,8 +7,8 @@ from dataclass_io._lib.dataclass_extensions import DataclassInstance from dataclass_io._lib.dataclass_extensions import fieldnames -from dataclass_io._lib.file import FileFormat from dataclass_io._lib.file import FileHeader +from dataclass_io._lib.file import ReadableFileHandle from dataclass_io._lib.file import get_header @@ -96,26 +96,34 @@ def assert_file_is_appendable( def assert_file_header_matches_dataclass( - path: Path, + file: Path | ReadableFileHandle, dataclass_type: type[DataclassInstance], - file_format: FileFormat, + delimiter: str, + comment_prefix: str, ) -> None: """ Check that the specified file has a header and its fields match those of the provided dataclass. """ - with path.open("r") as fin: - header: FileHeader = get_header(fin, file_format=file_format) + header: FileHeader | None + if isinstance(file, Path): + with file.open("r") as fin: + header = get_header(fin, delimiter=delimiter, comment_prefix=comment_prefix) + else: + pos = file.tell() + try: + header = get_header(file, delimiter=delimiter, comment_prefix=comment_prefix) + finally: + file.seek(pos) if header is None: - raise ValueError(f"Could not find a header in the provided file: {path}") + raise ValueError("Could not find a header in the provided file") if header.fieldnames != fieldnames(dataclass_type): raise ValueError( "The provided file does not have the same field names as the provided dataclass:\n" f"\tDataclass: {dataclass_type.__name__}\n" - f"\tFile: {path}\n" f"\tDataclass fields: {', '.join(fieldnames(dataclass_type))}\n" - f"\tFile: {', '.join(header.fieldnames)}\n" + f"\tFile fields: {', '.join(header.fieldnames)}\n" ) diff --git a/dataclass_io/_lib/file.py b/dataclass_io/_lib/file.py index f62f1e4..d657c84 100644 --- a/dataclass_io/_lib/file.py +++ 
b/dataclass_io/_lib/file.py @@ -2,13 +2,14 @@ from enum import Enum from enum import unique from io import TextIOWrapper -from typing import IO from typing import Optional -from typing import TextIO from typing import TypeAlias -ReadableFileHandle: TypeAlias = TextIOWrapper | IO | TextIO -WritableFileHandle: TypeAlias = TextIOWrapper | IO | TextIO +ReadableFileHandle: TypeAlias = TextIOWrapper +"""A file handle open for reading.""" + +WritableFileHandle: TypeAlias = TextIOWrapper +"""A file handle open for writing.""" @unique @@ -44,19 +45,6 @@ def __init__(self, _: str, abbreviation: str = None): """Append to an existing file.""" -@dataclass(kw_only=True) -class FileFormat: - """ - Parameters describing the format and configuration of the dataclass file. - - Most of these parameters, if specified, are passed through to `csv.DictReader`/`csv.DictWriter` - or `csv.reader`/`csv.writer`. - """ - - delimiter: str = "\t" - comment: str = "#" - - @dataclass(frozen=True, kw_only=True) class FileHeader: """ @@ -76,7 +64,8 @@ class FileHeader: def get_header( reader: ReadableFileHandle, - file_format: FileFormat, + delimiter: str, + comment_prefix: str, ) -> Optional[FileHeader]: """ Read the header from an open file. 
@@ -103,13 +92,13 @@ def get_header( preface: list[str] = [] for line in reader: - if line.startswith(file_format.comment) or line.strip() == "": + if line.startswith(comment_prefix) or line.strip() == "": preface.append(line.strip()) else: break else: return None - fieldnames = line.strip().split(file_format.delimiter) + fieldnames = line.strip().split(delimiter) return FileHeader(preface=preface, fieldnames=fieldnames) diff --git a/dataclass_io/reader.py b/dataclass_io/reader.py index 31c053b..44ed599 100644 --- a/dataclass_io/reader.py +++ b/dataclass_io/reader.py @@ -1,8 +1,8 @@ +from contextlib import contextmanager from csv import DictReader from pathlib import Path -from types import TracebackType from typing import Any -from typing import Type +from typing import Iterator from dataclass_io._lib.assertions import assert_dataclass_is_valid from dataclass_io._lib.assertions import assert_file_header_matches_dataclass @@ -10,7 +10,6 @@ from dataclass_io._lib.dataclass_extensions import DataclassInstance from dataclass_io._lib.dataclass_extensions import fieldnames from dataclass_io._lib.dataclass_extensions import row_to_dataclass -from dataclass_io._lib.file import FileFormat from dataclass_io._lib.file import FileHeader from dataclass_io._lib.file import ReadableFileHandle from dataclass_io._lib.file import get_header @@ -24,58 +23,42 @@ class DataclassReader: def __init__( self, - filename: str | Path, + fin: ReadableFileHandle, dataclass_type: type[DataclassInstance], delimiter: str = "\t", - comment: str = "#", + comment_prefix: str = "#", **kwds: Any, ) -> None: """ Args: - path: Path to the file to read. + fin: Open file handle for reading. dataclass_type: Dataclass type. + delimiter: The input file delimiter. + comment_prefix: The prefix for any comment/preface rows preceding the header row. + **kwds: Additional keyword arguments; accepted but not used by this constructor. Raises: - FileNotFoundError: If the input file does not exist. - IsADirectoryError: If the input file path is a directory. 
- PermissionError: If the input file is not readable. TypeError: If the provided type is not a dataclass. """ - - filepath: Path = filename if isinstance(filename, Path) else Path(filename) - file_format = FileFormat( + assert_dataclass_is_valid(dataclass_type) + assert_file_header_matches_dataclass( + file=fin, + dataclass_type=dataclass_type, delimiter=delimiter, - comment=comment, + comment_prefix=comment_prefix, ) - assert_dataclass_is_valid(dataclass_type) - assert_file_is_readable(filepath) - assert_file_header_matches_dataclass(filepath, dataclass_type, file_format) - self._dataclass_type = dataclass_type - self._fin = filepath.open("r") - self._header = get_header(reader=self._fin, file_format=file_format) + self._fin = fin + self._header = get_header( + reader=self._fin, delimiter=delimiter, comment_prefix=comment_prefix + ) self._reader = DictReader( f=self._fin, fieldnames=fieldnames(dataclass_type), - delimiter=file_format.delimiter, + delimiter=delimiter, ) - def __enter__(self) -> "DataclassReader": - return self - - def __exit__( - self, - exc_type: Type[BaseException], - exc_value: BaseException, - traceback: TracebackType, - ) -> None: - self.close() - - def close(self) -> None: - """Close the reader.""" - self._fin.close() - def __iter__(self) -> "DataclassReader": return self @@ -90,11 +73,22 @@ def open( cls, filename: str | Path, dataclass_type: type[DataclassInstance], - comment_prefix: str = DEFAULT_COMMENT_PREFIX, + delimiter: str = "\t", + comment_prefix: str = "#", ) -> Iterator["DataclassReader"]: """ Open a new `DataclassReader` from a file path. + Args: + filename: The path to the file from which dataclass instances will be read. + dataclass_type: The dataclass type to read from file. + delimiter: The input file delimiter. + comment_prefix: The prefix for any comment/preface rows preceding the header row. These + rows will be ignored when reading the file. + + Yields: + A `DataclassReader` instance. 
+ Raises: FileNotFoundError: If the input file does not exist. IsADirectoryError: If the input file path is a directory. @@ -102,9 +96,9 @@ def open( """ filepath: Path = Path(filename) - assert_dataclass_is_valid(dataclass_type) + # NB: The `DataclassReader` constructor will validate that the provided type is a valid + # dataclass and that the file's header matches the fields of the provided dataclass type. assert_file_is_readable(filepath) - assert_file_header_matches_dataclass(filepath, dataclass_type, comment_prefix=comment_prefix) fin = filepath.open("r") try: @@ -112,7 +106,7 @@ def open( fin=fin, dataclass_type=dataclass_type, delimiter=delimiter, - comment=comment, + comment_prefix=comment_prefix, ) finally: fin.close() diff --git a/dataclass_io/writer.py b/dataclass_io/writer.py index 7cefb05..9201e5c 100644 --- a/dataclass_io/writer.py +++ b/dataclass_io/writer.py @@ -1,10 +1,11 @@ +from contextlib import contextmanager from csv import DictWriter from dataclasses import asdict +from io import TextIOWrapper from pathlib import Path -from types import TracebackType from typing import Any from typing import Iterable -from typing import Type +from typing import Iterator from dataclass_io._lib.assertions import assert_dataclass_is_valid from dataclass_io._lib.assertions import assert_fieldnames_are_dataclass_attributes @@ -13,7 +14,6 @@ from dataclass_io._lib.assertions import assert_file_is_writable from dataclass_io._lib.dataclass_extensions import DataclassInstance from dataclass_io._lib.dataclass_extensions import fieldnames -from dataclass_io._lib.file import FileFormat from dataclass_io._lib.file import WritableFileHandle from dataclass_io._lib.file import WriteMode @@ -26,28 +26,18 @@ class DataclassWriter: def __init__( self, - filename: str | Path, + fout: WritableFileHandle, dataclass_type: type[DataclassInstance], - mode: str = "write", delimiter: str = "\t", - overwrite: bool = True, include_fields: list[str] | None = None, exclude_fields: 
list[str] | None = None, **kwds: Any, ) -> None: """ Args: - path: Path to the file to write. + fout: Open file handle for writing. dataclass_type: Dataclass type. - mode: Either `"write"` or `"append"`. - If `"write"`, the specified file `path` must not already exist unless - `overwrite=True` is specified. - If `"append"`, the specified file `path` must already exist and contain a header row - matching the specified dataclass and any specified `include_fields` or - `exclude_fields`. delimiter: The output file delimiter. - overwrite: If `True`, and `mode="write"`, the file specified at `path` will be - overwritten if it exists. include_fields: If specified, only the listed fieldnames will be included when writing records to file. Fields will be written in the order provided. May not be used together with `exclude_fields`. @@ -56,27 +46,10 @@ def __init__( May not be used together with `include_fields`. Raises: - FileNotFoundError: If the output file does not exist when trying to append. - IsADirectoryError: If the output file path is a directory. - PermissionError: If the output file is not writable (or readable when trying to append). TypeError: If the provided type is not a dataclass. + ValueError: If both `include_fields` and `exclude_fields` are specified. 
""" - - filepath: Path = filename if isinstance(filename, Path) else Path(filename) - - try: - write_mode = WriteMode(mode) - except ValueError: - raise ValueError(f"`mode` must be either 'write' or 'append': {mode}") from None - - file_format = FileFormat(delimiter=delimiter) - assert_dataclass_is_valid(dataclass_type) - if write_mode is WriteMode.WRITE: - assert_file_is_writable(filepath, overwrite=overwrite) - else: - assert_file_is_appendable(filepath, dataclass_type=dataclass_type) - assert_file_header_matches_dataclass(filepath, dataclass_type, file_format) self._dataclass_type = dataclass_type self._fieldnames = _validate_output_fieldnames( @@ -84,7 +57,7 @@ def __init__( include_fields=include_fields, exclude_fields=exclude_fields, ) - self._fout = filepath.open(write_mode.abbreviation) + self._fout = fout self._writer = DictWriter( f=self._fout, fieldnames=self._fieldnames, @@ -92,23 +65,7 @@ def __init__( ) # TODO: permit writing comment/preface rows before header - # If we aren't appending, write the header before any rows - if write_mode is WriteMode.WRITE: - self._writer.writeheader() - - def __enter__(self) -> "DataclassWriter": - return self - - def __exit__( - self, - exc_type: Type[BaseException], - exc_value: BaseException, - traceback: TracebackType, - ) -> None: - self.close() - - def close(self) -> None: - self._fout.close() + self._writer.writeheader() def write(self, dataclass_instance: DataclassInstance) -> None: """ @@ -121,6 +78,9 @@ def write(self, dataclass_instance: DataclassInstance) -> None: Args: dataclass_instance: An instance of the specified dataclass. + + Raises: + ValueError: If the provided instance is not an instance of the writer's dataclass. """ # TODO: consider permitting other dataclass types *if* they contain the required attributes @@ -144,10 +104,88 @@ def writeall(self, dataclass_instances: Iterable[DataclassInstance]) -> None: Args: dataclass_instances: A sequence of instances of the specified dataclass. 
+ + Raises: + ValueError: If any of the provided instances are not an instance of the writer's + dataclass. """ for dataclass_instance in dataclass_instances: self.write(dataclass_instance) + @classmethod + @contextmanager + def open( + cls, + filename: str | Path, + dataclass_type: type[DataclassInstance], + mode: str = "write", + overwrite: bool = True, + delimiter: str = "\t", + comment_prefix: str = "#", + **kwds: Any, + ) -> Iterator["DataclassWriter"]: + """ + Open a new `DataclassWriter` from a file path. + + Args: + filename: The path to the file to which dataclass instances will be written. + dataclass_type: The dataclass type to write to file. + mode: Either `"write"` or `"append"`. + - If `"write"`, the specified file `path` must not already exist unless + `overwrite=True` is specified. + - If `"append"`, the specified file `path` must already exist and contain a header + row matching the specified dataclass and any specified `include_fields` or + `exclude_fields`. + overwrite: If `True`, and `mode="write"`, the file specified at `path` will be + overwritten if it exists. + delimiter: The output file delimiter. + comment_prefix: The prefix for any comment/preface rows preceding the header row. + (This argument is ignored when `mode="write"`. It is used when `mode="append"` to + validate that the existing file's header matches the specified dataclass.) + **kwds: Additional keyword arguments to be passed to the `DataclassWriter` constructor. + + Yields: + A `DataclassWriter` instance. + + Raises: + TypeError: If the provided type is not a dataclass. + FileNotFoundError: If the output file does not exist when trying to append. + IsADirectoryError: If the output file path is a directory. + PermissionError: If the output file is not writable. + PermissionError: If `mode="append"` and the output file is not readable. (The output + file must be readable in order to validate that the existing file's header matches + the dataclass's fields.) 
+ """ + + filepath: Path = Path(filename) + + try: + write_mode = WriteMode(mode) + except ValueError: + raise ValueError(f"`mode` must be either 'write' or 'append': {mode}") from None + + assert_dataclass_is_valid(dataclass_type) + if write_mode is WriteMode.WRITE: + assert_file_is_writable(filepath, overwrite=overwrite) + else: + assert_file_is_appendable(filepath, dataclass_type=dataclass_type) + assert_file_header_matches_dataclass( + file=filepath, + dataclass_type=dataclass_type, + delimiter=delimiter, + comment_prefix=comment_prefix, + ) + + # `TextIOWrapper` must wrap a bytes buffer (a text-mode handle raises `TypeError` on flush), + # so open in binary mode; assumes `abbreviation` is a bare mode like "w"/"a" (TODO confirm). + fout = TextIOWrapper(filepath.open(f"{write_mode.abbreviation}b")) + try: + # NOTE(review): the constructor always writes a header row, so `mode="append"` also + # appends a second header to the existing file; that needs a constructor-level fix. + yield cls(fout=fout, dataclass_type=dataclass_type, delimiter=delimiter, **kwds) + finally: + fout.close() + def _validate_output_fieldnames( dataclass_type: type[DataclassInstance], diff --git a/tests/test_reader.py b/tests/test_reader.py index 8a9ea22..b604534 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -17,7 +17,7 @@ def test_reader(tmp_path: Path) -> None: f.write("foo\tbar\n") f.write("abc\t1\n") - with DataclassReader(filename=fpath, dataclass_type=FakeDataclass) as reader: + with DataclassReader.open(filename=fpath, dataclass_type=FakeDataclass) as reader: rows = [row for row in reader] assert rows[0] == FakeDataclass(foo="abc", bar=1) @@ -31,7 +31,7 @@ def test_reader_from_str(tmp_path: Path) -> None: f.write("foo\tbar\n") f.write("abc\t1\n") - with DataclassReader(filename=str(fpath), dataclass_type=FakeDataclass) as reader: + with DataclassReader.open(filename=str(fpath), dataclass_type=FakeDataclass) as reader: rows = [row for row in reader] assert rows[0] == FakeDataclass(foo="abc", bar=1) diff --git a/tests/test_writer.py b/tests/test_writer.py index 1e997d0..b2bd4d8 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -15,7 +15,7 @@ class FakeDataclass: def test_writer(tmp_path: Path) -> None: fpath = tmp_path / "test.txt" - with DataclassWriter(filename=fpath, mode="write", dataclass_type=FakeDataclass) as writer: + 
with DataclassWriter.open(filename=fpath, mode="write", dataclass_type=FakeDataclass) as writer: writer.write(FakeDataclass(foo="abc", bar=1)) writer.write(FakeDataclass(foo="def", bar=2)) @@ -31,7 +31,9 @@ def test_writer_from_str(tmp_path: Path) -> None: """Test that we can create a writer when `filename` is a `str`.""" fpath = tmp_path / "test.txt" - with DataclassWriter(filename=str(fpath), mode="write", dataclass_type=FakeDataclass) as writer: + with DataclassWriter.open( + filename=str(fpath), mode="write", dataclass_type=FakeDataclass + ) as writer: writer.write(FakeDataclass(foo="abc", bar=1)) with fpath.open("r") as f: @@ -48,7 +50,7 @@ def test_writer_writeall(tmp_path: Path) -> None: FakeDataclass(foo="abc", bar=1), FakeDataclass(foo="def", bar=2), ] - with DataclassWriter(filename=fpath, mode="write", dataclass_type=FakeDataclass) as writer: + with DataclassWriter.open(filename=fpath, mode="write", dataclass_type=FakeDataclass) as writer: writer.writeall(data) with fpath.open("r") as f: @@ -66,7 +68,11 @@ def test_writer_append(tmp_path: Path) -> None: with fpath.open("w") as fout: fout.write("foo\tbar\n") - with DataclassWriter(filename=fpath, mode="append", dataclass_type=FakeDataclass) as writer: + with DataclassWriter.open( + filename=fpath, + mode="append", + dataclass_type=FakeDataclass, + ) as writer: writer.write(FakeDataclass(foo="abc", bar=1)) writer.write(FakeDataclass(foo="def", bar=2)) @@ -84,7 +90,9 @@ def test_writer_append_raises_if_empty(tmp_path: Path) -> None: fpath.touch() with pytest.raises(ValueError, match="The specified output file is empty"): - with DataclassWriter(filename=fpath, mode="append", dataclass_type=FakeDataclass) as writer: + with DataclassWriter.open( + filename=fpath, mode="append", dataclass_type=FakeDataclass + ) as writer: writer.write(FakeDataclass(foo="abc", bar=1)) @@ -95,7 +103,9 @@ def test_writer_append_raises_if_no_header(tmp_path: Path) -> None: fout.write("abc\t1\n") with pytest.raises(ValueError, 
match="The provided file does not have the same field names"): - with DataclassWriter(filename=fpath, mode="append", dataclass_type=FakeDataclass) as writer: + with DataclassWriter.open( + filename=fpath, mode="append", dataclass_type=FakeDataclass + ) as writer: writer.write(FakeDataclass(foo="abc", bar=1)) @@ -110,7 +120,9 @@ def test_writer_append_raises_if_header_does_not_match(tmp_path: Path) -> None: fout.write("foo\tbar\tbaz\n") with pytest.raises(ValueError, match="The provided file does not have the same field names"): - with DataclassWriter(filename=fpath, mode="append", dataclass_type=FakeDataclass) as writer: + with DataclassWriter.open( + filename=fpath, mode="append", dataclass_type=FakeDataclass + ) as writer: writer.write(FakeDataclass(foo="abc", bar=1)) @@ -122,9 +134,8 @@ def test_writer_include_fields(tmp_path: Path) -> None: FakeDataclass(foo="abc", bar=1), FakeDataclass(foo="def", bar=2), ] - with DataclassWriter( + with DataclassWriter.open( filename=fpath, - mode="write", dataclass_type=FakeDataclass, include_fields=["foo"], ) as writer: @@ -146,9 +157,8 @@ def test_writer_include_fields_reorders(tmp_path: Path) -> None: FakeDataclass(foo="abc", bar=1), FakeDataclass(foo="def", bar=2), ] - with DataclassWriter( + with DataclassWriter.open( filename=fpath, - mode="write", dataclass_type=FakeDataclass, include_fields=["bar", "foo"], ) as writer: @@ -171,9 +181,8 @@ def test_writer_exclude_fields(tmp_path: Path) -> None: FakeDataclass(foo="abc", bar=1), FakeDataclass(foo="def", bar=2), ] - with DataclassWriter( + with DataclassWriter.open( filename=fpath, - mode="write", dataclass_type=FakeDataclass, exclude_fields=["bar"], ) as writer: