Skip to content

Commit

Permalink
feat: permit str or path filename (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
msto authored Apr 14, 2024
1 parent e48e16b commit fa25463
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 26 deletions.
9 changes: 5 additions & 4 deletions dataclass_io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class DataclassReader:

def __init__(
self,
path: Path,
filename: str | Path,
dataclass_type: type[DataclassInstance],
delimiter: str = "\t",
comment: str = "#",
Expand All @@ -42,17 +42,18 @@ def __init__(
TypeError: If the provided type is not a dataclass.
"""

filepath: Path = filename if isinstance(filename, Path) else Path(filename)
file_format = FileFormat(
delimiter=delimiter,
comment=comment,
)

assert_dataclass_is_valid(dataclass_type)
assert_file_is_readable(path)
assert_file_header_matches_dataclass(path, dataclass_type, file_format)
assert_file_is_readable(filepath)
assert_file_header_matches_dataclass(filepath, dataclass_type, file_format)

self._dataclass_type = dataclass_type
self._fin = path.open("r")
self._fin = filepath.open("r")
self._header = get_header(reader=self._fin, file_format=file_format)
self._reader = DictReader(
f=self._fin,
Expand Down
12 changes: 7 additions & 5 deletions dataclass_io/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class DataclassWriter:

def __init__(
self,
path: Path,
filename: str | Path,
dataclass_type: type[DataclassInstance],
mode: str = "write",
delimiter: str = "\t",
Expand Down Expand Up @@ -62,6 +62,8 @@ def __init__(
TypeError: If the provided type is not a dataclass.
"""

filepath: Path = filename if isinstance(filename, Path) else Path(filename)

try:
write_mode = WriteMode(mode)
except ValueError:
Expand All @@ -71,18 +73,18 @@ def __init__(

assert_dataclass_is_valid(dataclass_type)
if write_mode is WriteMode.WRITE:
assert_file_is_writable(path, overwrite=overwrite)
assert_file_is_writable(filepath, overwrite=overwrite)
else:
assert_file_is_appendable(path, dataclass_type=dataclass_type)
assert_file_header_matches_dataclass(path, dataclass_type, file_format)
assert_file_is_appendable(filepath, dataclass_type=dataclass_type)
assert_file_header_matches_dataclass(filepath, dataclass_type, file_format)

self._dataclass_type = dataclass_type
self._fieldnames = _validate_output_fieldnames(
dataclass_type=dataclass_type,
include_fields=include_fields,
exclude_fields=exclude_fields,
)
self._fout = path.open(write_mode.abbreviation)
self._fout = filepath.open(write_mode.abbreviation)
self._writer = DictWriter(
f=self._fout,
fieldnames=self._fieldnames,
Expand Down
18 changes: 16 additions & 2 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,25 @@ class FakeDataclass:
def test_reader(tmp_path: Path) -> None:
fpath = tmp_path / "test.txt"

with open(fpath, "w") as f:
with fpath.open("w") as f:
f.write("foo\tbar\n")
f.write("abc\t1\n")

with DataclassReader(path=fpath, dataclass_type=FakeDataclass) as reader:
with DataclassReader(filename=fpath, dataclass_type=FakeDataclass) as reader:
rows = [row for row in reader]

assert rows[0] == FakeDataclass(foo="abc", bar=1)


def test_reader_from_str(tmp_path: Path) -> None:
"""Test that we can create a reader when `filename` is a `str`."""
fpath = tmp_path / "test.txt"

with fpath.open("w") as f:
f.write("foo\tbar\n")
f.write("abc\t1\n")

with DataclassReader(filename=str(fpath), dataclass_type=FakeDataclass) as reader:
rows = [row for row in reader]

assert rows[0] == FakeDataclass(foo="abc", bar=1)
44 changes: 29 additions & 15 deletions tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,43 @@ class FakeDataclass:
def test_writer(tmp_path: Path) -> None:
fpath = tmp_path / "test.txt"

with DataclassWriter(path=fpath, mode="write", dataclass_type=FakeDataclass) as writer:
with DataclassWriter(filename=fpath, mode="write", dataclass_type=FakeDataclass) as writer:
writer.write(FakeDataclass(foo="abc", bar=1))
writer.write(FakeDataclass(foo="def", bar=2))

with open(fpath, "r") as f:
with fpath.open("r") as f:
assert next(f) == "foo\tbar\n"
assert next(f) == "abc\t1\n"
assert next(f) == "def\t2\n"
with pytest.raises(StopIteration):
next(f)


def test_writer_from_str(tmp_path: Path) -> None:
"""Test that we can create a writer when `filename` is a `str`."""
fpath = tmp_path / "test.txt"

with DataclassWriter(filename=str(fpath), mode="write", dataclass_type=FakeDataclass) as writer:
writer.write(FakeDataclass(foo="abc", bar=1))

with fpath.open("r") as f:
assert next(f) == "foo\tbar\n"
assert next(f) == "abc\t1\n"
with pytest.raises(StopIteration):
next(f)


def test_writer_writeall(tmp_path: Path) -> None:
fpath = tmp_path / "test.txt"

data = [
FakeDataclass(foo="abc", bar=1),
FakeDataclass(foo="def", bar=2),
]
with DataclassWriter(path=fpath, mode="write", dataclass_type=FakeDataclass) as writer:
with DataclassWriter(filename=fpath, mode="write", dataclass_type=FakeDataclass) as writer:
writer.writeall(data)

with open(fpath, "r") as f:
with fpath.open("r") as f:
assert next(f) == "foo\tbar\n"
assert next(f) == "abc\t1\n"
assert next(f) == "def\t2\n"
Expand All @@ -52,11 +66,11 @@ def test_writer_append(tmp_path: Path) -> None:
with fpath.open("w") as fout:
fout.write("foo\tbar\n")

with DataclassWriter(path=fpath, mode="append", dataclass_type=FakeDataclass) as writer:
with DataclassWriter(filename=fpath, mode="append", dataclass_type=FakeDataclass) as writer:
writer.write(FakeDataclass(foo="abc", bar=1))
writer.write(FakeDataclass(foo="def", bar=2))

with open(fpath, "r") as f:
with fpath.open("r") as f:
assert next(f) == "foo\tbar\n"
assert next(f) == "abc\t1\n"
assert next(f) == "def\t2\n"
Expand All @@ -70,7 +84,7 @@ def test_writer_append_raises_if_empty(tmp_path: Path) -> None:
fpath.touch()

with pytest.raises(ValueError, match="The specified output file is empty"):
with DataclassWriter(path=fpath, mode="append", dataclass_type=FakeDataclass) as writer:
with DataclassWriter(filename=fpath, mode="append", dataclass_type=FakeDataclass) as writer:
writer.write(FakeDataclass(foo="abc", bar=1))


Expand All @@ -81,7 +95,7 @@ def test_writer_append_raises_if_no_header(tmp_path: Path) -> None:
fout.write("abc\t1\n")

with pytest.raises(ValueError, match="The provided file does not have the same field names"):
with DataclassWriter(path=fpath, mode="append", dataclass_type=FakeDataclass) as writer:
with DataclassWriter(filename=fpath, mode="append", dataclass_type=FakeDataclass) as writer:
writer.write(FakeDataclass(foo="abc", bar=1))


Expand All @@ -96,7 +110,7 @@ def test_writer_append_raises_if_header_does_not_match(tmp_path: Path) -> None:
fout.write("foo\tbar\tbaz\n")

with pytest.raises(ValueError, match="The provided file does not have the same field names"):
with DataclassWriter(path=fpath, mode="append", dataclass_type=FakeDataclass) as writer:
with DataclassWriter(filename=fpath, mode="append", dataclass_type=FakeDataclass) as writer:
writer.write(FakeDataclass(foo="abc", bar=1))


Expand All @@ -109,14 +123,14 @@ def test_writer_include_fields(tmp_path: Path) -> None:
FakeDataclass(foo="def", bar=2),
]
with DataclassWriter(
path=fpath,
filename=fpath,
mode="write",
dataclass_type=FakeDataclass,
include_fields=["foo"],
) as writer:
writer.writeall(data)

with open(fpath, "r") as f:
with fpath.open("r") as f:
assert next(f) == "foo\n"
assert next(f) == "abc\n"
assert next(f) == "def\n"
Expand All @@ -133,14 +147,14 @@ def test_writer_include_fields_reorders(tmp_path: Path) -> None:
FakeDataclass(foo="def", bar=2),
]
with DataclassWriter(
path=fpath,
filename=fpath,
mode="write",
dataclass_type=FakeDataclass,
include_fields=["bar", "foo"],
) as writer:
writer.writeall(data)

with open(fpath, "r") as f:
with fpath.open("r") as f:
assert next(f) == "bar\tfoo\n"
assert next(f) == "1\tabc\n"
assert next(f) == "2\tdef\n"
Expand All @@ -158,14 +172,14 @@ def test_writer_exclude_fields(tmp_path: Path) -> None:
FakeDataclass(foo="def", bar=2),
]
with DataclassWriter(
path=fpath,
filename=fpath,
mode="write",
dataclass_type=FakeDataclass,
exclude_fields=["bar"],
) as writer:
writer.writeall(data)

with open(fpath, "r") as f:
with fpath.open("r") as f:
assert next(f) == "foo\n"
assert next(f) == "abc\n"
assert next(f) == "def\n"
Expand Down

0 comments on commit fa25463

Please sign in to comment.