diff --git a/dataclass_io/writer.py b/dataclass_io/writer.py
new file mode 100644
index 0000000..f608875
--- /dev/null
+++ b/dataclass_io/writer.py
@@ -0,0 +1,124 @@
+from csv import DictWriter
+from dataclasses import asdict
+from enum import Enum
+from enum import unique
+from io import TextIOWrapper
+from pathlib import Path
+from types import TracebackType
+from typing import IO
+from typing import Any
+from typing import Iterable
+from typing import Optional
+from typing import TextIO
+from typing import Type
+from typing import TypeAlias
+
+from dataclass_io._lib.assertions import assert_dataclass_is_valid
+from dataclass_io._lib.assertions import assert_file_is_appendable
+from dataclass_io._lib.assertions import assert_file_is_writable
+from dataclass_io._lib.dataclass_extensions import DataclassInstance
+from dataclass_io._lib.dataclass_extensions import fieldnames
+from dataclass_io._lib.file import FileHeader
+
+WritableFileHandle: TypeAlias = TextIOWrapper | IO | TextIO
+
+
+@unique
+class WriteMode(Enum):
+    """The modes in which a `DataclassWriter` may open its output file."""
+
+    APPEND = "a"
+    WRITE = "w"
+
+
+class DataclassWriter:
+    """Write dataclass instances to a delimited text file, one row per instance."""
+
+    def __init__(
+        self,
+        path: Path,
+        dataclass_type: type[DataclassInstance],
+        mode: str = "w",
+        delimiter: str = "\t",
+        overwrite: bool = True,
+        header: Optional[FileHeader] = None,
+        **kwds: Any,
+    ) -> None:
+        """
+        Open `path` and, when writing (not appending), emit the header row.
+
+        Before the file is opened, the path is checked by `assert_file_is_writable()` or
+        `assert_file_is_appendable()` and the type by `assert_dataclass_is_valid()`; those
+        helpers raise their own exceptions on failure.
+
+        Args:
+            path: Path to the file to write.
+            dataclass_type: Dataclass type.
+            mode: Either "w" (write) or "a" (append). In write mode, a header row of the
+                dataclass's field names is written before any data rows. In append mode, no
+                header is written.
+            delimiter: The field delimiter used for the header and for every row.
+            overwrite: Forwarded to `assert_file_is_writable()` when `mode` is "w".
+            header: Currently unused.
+            **kwds: Currently unused.
+
+        Raises:
+            ValueError: If `mode` is neither "a" nor "w".
+            TypeError: If the provided type is not a dataclass.
+        """
+
+        try:
+            write_mode = WriteMode(mode)
+        except ValueError:
+            raise ValueError(f"`mode` must be either 'a' (append) or 'w' (write): {mode}") from None
+
+        if write_mode is WriteMode.WRITE:
+            assert_file_is_writable(path, overwrite=overwrite)
+        else:
+            assert_file_is_appendable(path, dataclass_type=dataclass_type)
+
+        assert_dataclass_is_valid(dataclass_type)
+
+        self.dataclass_type = dataclass_type
+        self.delimiter = delimiter
+
+        # Per the `csv` docs, open with `newline=""` so the writer alone controls line endings.
+        self._fout = path.open(mode, newline="")
+
+        # Terminate rows with "\n" rather than the `csv` default of "\r\n".
+        self._writer = DictWriter(
+            self._fout,
+            fieldnames=fieldnames(dataclass_type),
+            delimiter=self.delimiter,
+            lineterminator="\n",
+        )
+
+        # TODO: optionally add preface
+        # If we aren't appending, write the header before any rows. Going through the
+        # `DictWriter` gives the header the same quoting and line ending as the data rows.
+        if write_mode is WriteMode.WRITE:
+            self._writer.writeheader()
+
+    def __enter__(self) -> "DataclassWriter":
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        self.close()
+
+    def close(self) -> None:
+        """Close the underlying file handle."""
+        self._fout.close()
+
+    def write(self, dataclass_instance: DataclassInstance) -> None:
+        """Write a single dataclass instance as one row."""
+        self._writer.writerow(asdict(dataclass_instance))
+
+    def writeall(self, dataclass_instances: Iterable[DataclassInstance]) -> None:
+        """Write each dataclass instance in the iterable as one row, in order."""
+        for dataclass_instance in dataclass_instances:
+            self.write(dataclass_instance)
diff --git a/tests/test_writer.py b/tests/test_writer.py
new file mode 100644
index 0000000..2c8a76c
--- /dev/null
+++ b/tests/test_writer.py
@@ -0,0 +1,58 @@
+from dataclasses import dataclass
+from pathlib import Path
+
+import pytest
+
+from dataclass_io.writer import DataclassWriter
+
+
+@dataclass
+class FakeDataclass:
+    foo: str
+    bar: int
+
+
+def test_writer(tmp_path: Path) -> None:
+    """Rows written one at a time follow a header row of field names."""
+    fpath = tmp_path / "test.txt"
+
+    with DataclassWriter(path=fpath, mode="w", dataclass_type=FakeDataclass) as writer:
+        writer.write(FakeDataclass(foo="abc", bar=1))
+        writer.write(FakeDataclass(foo="def", bar=2))
+
+    with open(fpath, "r") as f:
+        assert next(f) == "foo\tbar\n"
+        assert next(f) == "abc\t1\n"
+        assert next(f) == "def\t2\n"
+        with pytest.raises(StopIteration):
+            next(f)
+
+
+def test_writer_writeall(tmp_path: Path) -> None:
+    """`writeall()` writes every instance, in order, after the header row."""
+    fpath = tmp_path / "test.txt"
+
+    data = [
+        FakeDataclass(foo="abc", bar=1),
+        FakeDataclass(foo="def", bar=2),
+    ]
+    with DataclassWriter(path=fpath, mode="w", dataclass_type=FakeDataclass) as writer:
+        writer.writeall(data)
+
+    with open(fpath, "r") as f:
+        assert next(f) == "foo\tbar\n"
+        assert next(f) == "abc\t1\n"
+        assert next(f) == "def\t2\n"
+        with pytest.raises(StopIteration):
+            next(f)
+
+
+def test_writer_uses_consistent_line_endings(tmp_path: Path) -> None:
+    """The header and the data rows are all terminated by a bare newline."""
+    fpath = tmp_path / "test.txt"
+
+    with DataclassWriter(path=fpath, mode="w", dataclass_type=FakeDataclass) as writer:
+        writer.write(FakeDataclass(foo="abc", bar=1))
+
+    # Read bytes so that universal-newline translation can't hide a stray "\r"
+    assert fpath.read_bytes() == b"foo\tbar\nabc\t1\n"