From ca3edd9ae538b0d5c959ed09386690c59b4c9e3f Mon Sep 17 00:00:00 2001 From: Matt Stone Date: Mon, 15 Jul 2024 21:09:31 -0400 Subject: [PATCH] fix: append and textio typing --- dataclass_io/_lib/file.py | 4 +++- dataclass_io/writer.py | 11 ++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/dataclass_io/_lib/file.py b/dataclass_io/_lib/file.py index d657c84..37d644c 100644 --- a/dataclass_io/_lib/file.py +++ b/dataclass_io/_lib/file.py @@ -2,13 +2,15 @@ from enum import Enum from enum import unique from io import TextIOWrapper +from typing import IO +from typing import Any from typing import Optional from typing import TypeAlias ReadableFileHandle: TypeAlias = TextIOWrapper """A file handle open for reading.""" -WritableFileHandle: TypeAlias = TextIOWrapper +WritableFileHandle: TypeAlias = TextIOWrapper | IO[Any] """A file handle open for writing.""" diff --git a/dataclass_io/writer.py b/dataclass_io/writer.py index 9201e5c..1d71b9a 100644 --- a/dataclass_io/writer.py +++ b/dataclass_io/writer.py @@ -1,7 +1,6 @@ from contextlib import contextmanager from csv import DictWriter from dataclasses import asdict -from io import TextIOWrapper from pathlib import Path from typing import Any from typing import Iterable @@ -31,6 +30,7 @@ def __init__( delimiter: str = "\t", include_fields: list[str] | None = None, exclude_fields: list[str] | None = None, + write_header: bool = True, **kwds: Any, ) -> None: """ @@ -44,6 +44,9 @@ def __init__( exclude_fields: If specified, any listed fieldnames will be excluded when writing records to file. May not be used together with `include_fields`. + write_header: If True, a header row consisting of the dataclass's fieldnames will be + written before any records are written (including or excluding any fields specified + by `include_fields` or `exclude_fields`). Raises: TypeError: If the provided type is not a dataclass. @@ -65,7 +68,8 @@ def __init__( ) # TODO: permit writing comment/preface rows before header - self._writer.writeheader() + if write_header: + self._writer.writeheader() def write(self, dataclass_instance: DataclassInstance) -> None: """ @@ -176,11 +180,12 @@ def open( comment_prefix=comment_prefix, ) - fout = TextIOWrapper(filepath.open(write_mode.abbreviation)) + fout = filepath.open(write_mode.abbreviation) try: yield cls( fout=fout, dataclass_type=dataclass_type, + write_header=(write_mode is WriteMode.WRITE), # Skip header when appending **kwds, ) finally: