Skip to content

Commit

Permalink
fix: append and textio typing
Browse files Browse the repository at this point in the history
  • Loading branch information
msto committed Jul 16, 2024
1 parent c5e5058 commit ca3edd9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
4 changes: 3 additions & 1 deletion dataclass_io/_lib/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from enum import Enum
from enum import unique
from io import TextIOWrapper
from typing import IO
from typing import Any
from typing import Optional
from typing import TypeAlias

ReadableFileHandle: TypeAlias = TextIOWrapper
"""A file handle open for reading."""

WritableFileHandle: TypeAlias = TextIOWrapper
WritableFileHandle: TypeAlias = TextIOWrapper | IO[Any]
"""A file handle open for writing."""


Expand Down
11 changes: 8 additions & 3 deletions dataclass_io/writer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from contextlib import contextmanager
from csv import DictWriter
from dataclasses import asdict
from io import TextIOWrapper
from pathlib import Path
from typing import Any
from typing import Iterable
Expand Down Expand Up @@ -31,6 +30,7 @@ def __init__(
delimiter: str = "\t",
include_fields: list[str] | None = None,
exclude_fields: list[str] | None = None,
write_header: bool = True,
**kwds: Any,
) -> None:
"""
Expand All @@ -44,6 +44,9 @@ def __init__(
exclude_fields: If specified, any listed fieldnames will be excluded when writing
records to file.
May not be used together with `include_fields`.
write_header: If True, a header row consisting of the dataclass's fieldnames will be
written before any records are written (including or excluding any fields specified
by `include_fields` or `exclude_fields`).
Raises:
TypeError: If the provided type is not a dataclass.
Expand All @@ -65,7 +68,8 @@ def __init__(
)

# TODO: permit writing comment/preface rows before header
self._writer.writeheader()
if write_header:
self._writer.writeheader()

def write(self, dataclass_instance: DataclassInstance) -> None:
"""
Expand Down Expand Up @@ -176,11 +180,12 @@ def open(
comment_prefix=comment_prefix,
)

fout = TextIOWrapper(filepath.open(write_mode.abbreviation))
fout = filepath.open(write_mode.abbreviation)
try:
yield cls(
fout=fout,
dataclass_type=dataclass_type,
write_header=(write_mode is WriteMode.WRITE), # Skip header when appending
**kwds,
)
finally:
Expand Down

0 comments on commit ca3edd9

Please sign in to comment.