Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Use csv.DictReader to parse header fields #21

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions dataclass_io/_lib/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from os import access
from os import stat
from pathlib import Path
from typing import Any

from dataclass_io._lib.dataclass_extensions import DataclassInstance
from dataclass_io._lib.dataclass_extensions import fieldnames
Expand Down Expand Up @@ -98,20 +99,20 @@ def assert_file_is_appendable(
def assert_file_header_matches_dataclass(
file: Path | ReadableFileHandle,
dataclass_type: type[DataclassInstance],
delimiter: str,
comment_prefix: str,
**kwargs: Any,
) -> None:
"""
Check that the specified file has a header and its fields match those of the provided dataclass.
"""
header: FileHeader | None
if isinstance(file, Path):
with file.open("r") as fin:
header = get_header(fin, delimiter=delimiter, comment_prefix=comment_prefix)
header = get_header(reader=fin, comment_prefix=comment_prefix, **kwargs)
else:
pos = file.tell()
try:
header = get_header(file, delimiter=delimiter, comment_prefix=comment_prefix)
header = get_header(reader=file, comment_prefix=comment_prefix, **kwargs)
finally:
file.seek(pos)

Expand Down
9 changes: 7 additions & 2 deletions dataclass_io/_lib/file.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from csv import DictReader
from dataclasses import dataclass
from enum import Enum
from enum import unique
Expand Down Expand Up @@ -66,8 +67,8 @@ class FileHeader:

def get_header(
reader: ReadableFileHandle,
delimiter: str,
comment_prefix: str,
**kwargs: Any,
) -> Optional[FileHeader]:
"""
Read the header from an open file.
Expand All @@ -85,6 +86,7 @@ def get_header(
Args:
reader: An open, readable file handle.
comment_char: The character which indicates the start of a comment line.
**kwargs: Additional keyword arguments to pass to `csv.DictReader`.

Returns:
A `FileHeader` containing the field names and any preceding lines.
Expand All @@ -103,6 +105,9 @@ def get_header(
else:
return None

fieldnames = line.strip().split(delimiter)
# msto#19 Read header fields
# Use csv.DictReader because RFC4180 is tricky to implement correctly
header_reader = DictReader([line], **kwargs)
fieldnames = list(header_reader.fieldnames)

return FileHeader(preface=preface, fieldnames=fieldnames)
12 changes: 9 additions & 3 deletions dataclass_io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,17 @@ def __init__(
self,
fin: ReadableFileHandle,
dataclass_type: type[DataclassInstance],
delimiter: str = "\t",
comment_prefix: str = "#",
**kwds: Any,
delimiter: str = "\t",
**kwargs: Any,
) -> None:
"""
Args:
fin: Open file handle for reading.
dataclass_type: Dataclass type.
delimiter: The input file delimiter.
comment_prefix: The prefix for any comment/preface rows preceding the header row.
quoting: Quoting style (enum value from Python csv package).
dataclass_type: Dataclass type.

Raises:
Expand All @@ -46,17 +47,22 @@ def __init__(
dataclass_type=dataclass_type,
delimiter=delimiter,
comment_prefix=comment_prefix,
**kwargs,
)

self._dataclass_type = dataclass_type
self._fin = fin
self._header = get_header(
reader=self._fin, delimiter=delimiter, comment_prefix=comment_prefix
reader=self._fin,
delimiter=delimiter,
comment_prefix=comment_prefix,
**kwargs,
)
self._reader = DictReader(
f=self._fin,
fieldnames=fieldnames(dataclass_type),
delimiter=delimiter,
**kwargs,
)

def __iter__(self) -> "DataclassReader":
Expand Down
14 changes: 8 additions & 6 deletions dataclass_io/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
include_fields: list[str] | None = None,
exclude_fields: list[str] | None = None,
write_header: bool = True,
**kwds: Any,
**kwargs: Any,
) -> None:
"""
Args:
Expand Down Expand Up @@ -65,6 +65,7 @@ def __init__(
f=self._fout,
fieldnames=self._fieldnames,
delimiter=delimiter,
**kwargs,
)

# TODO: permit writing comment/preface rows before header
Expand Down Expand Up @@ -124,9 +125,9 @@ def open(
dataclass_type: type[DataclassInstance],
mode: str = "write",
overwrite: bool = True,
delimiter: str = "\t",
comment_prefix: str = "#",
**kwds: Any,
delimiter: str = "\t",
**kwargs: Any,
) -> Iterator["DataclassWriter"]:
"""
Open a new `DataclassWriter` from a file path.
Expand All @@ -142,11 +143,11 @@ def open(
`exclude_fields`.
overwrite: If `True`, and `mode="write"`, the file specified at `path` will be
overwritten if it exists.
delimiter: The output file delimiter.
comment_prefix: The prefix for any comment/preface rows preceding the header row.
(This argument is ignored when `mode="write"`. It is used when `mode="append"` to
validate that the existing file's header matches the specified dataclass.)
**kwds: Additional keyword arguments to be passed to the `DataclassWriter` constructor.
delimiter: The output file delimiter.
**kwds: Additional keyword arguments to be passed to `csv.DictWriter`.

Yields:
A `DataclassWriter` instance.
Expand Down Expand Up @@ -178,6 +179,7 @@ def open(
dataclass_type=dataclass_type,
delimiter=delimiter,
comment_prefix=comment_prefix,
**kwargs,
)

fout = filepath.open(write_mode.abbreviation)
Expand All @@ -186,7 +188,7 @@ def open(
fout=fout,
dataclass_type=dataclass_type,
write_header=(write_mode is WriteMode.WRITE), # Skip header when appending
**kwds,
**kwargs,
)
finally:
fout.close()
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from pathlib import Path

import pytest


@pytest.fixture(scope="session")
def datadir() -> Path:
"""Path to the test data directory."""

return Path(__file__).parent / "data"
3 changes: 3 additions & 0 deletions tests/data/reader_should_parse_quotes.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"id" "title"
"fake" "A fake object"
"also_fake" "Another fake object"
20 changes: 20 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,23 @@ class FakeDataclass:
assert isinstance(rows[0], FakeDataclass)
assert rows[0].foo == "abc"
assert rows[0].bar == 1


def test_reader_should_parse_quotes(datadir: Path) -> None:
"""
Test that having quotes around column names in header row doesn't break anything
https://github.com/msto/dataclass_io/issues/19
"""
fpath = datadir / "reader_should_parse_quotes.tsv"

@dataclass
class FakeDataclass:
id: str
title: str

# Parse CSV using DataclassReader
with DataclassReader.open(fpath, FakeDataclass) as reader:
records = [record for record in reader]

assert records[0] == FakeDataclass(id="fake", title="A fake object")
assert records[1] == FakeDataclass(id="also_fake", title="Another fake object")
Loading