diff --git a/dataclass_io/reader.py b/dataclass_io/reader.py index 9b171c1..3768ed4 100644 --- a/dataclass_io/reader.py +++ b/dataclass_io/reader.py @@ -24,7 +24,7 @@ class DataclassReader: def __init__( self, - path: Path, + filename: str | Path, dataclass_type: type[DataclassInstance], delimiter: str = "\t", comment: str = "#", @@ -42,17 +42,18 @@ def __init__( TypeError: If the provided type is not a dataclass. """ + filepath: Path = filename if isinstance(filename, Path) else Path(filename) file_format = FileFormat( delimiter=delimiter, comment=comment, ) assert_dataclass_is_valid(dataclass_type) - assert_file_is_readable(path) - assert_file_header_matches_dataclass(path, dataclass_type, file_format) + assert_file_is_readable(filepath) + assert_file_header_matches_dataclass(filepath, dataclass_type, file_format) self._dataclass_type = dataclass_type - self._fin = path.open("r") + self._fin = filepath.open("r") self._header = get_header(reader=self._fin, file_format=file_format) self._reader = DictReader( f=self._fin, diff --git a/dataclass_io/writer.py b/dataclass_io/writer.py index c439568..7cefb05 100644 --- a/dataclass_io/writer.py +++ b/dataclass_io/writer.py @@ -26,7 +26,7 @@ class DataclassWriter: def __init__( self, - path: Path, + filename: str | Path, dataclass_type: type[DataclassInstance], mode: str = "write", delimiter: str = "\t", @@ -62,6 +62,8 @@ def __init__( TypeError: If the provided type is not a dataclass. """ + filepath: Path = filename if isinstance(filename, Path) else Path(filename) + try: write_mode = WriteMode(mode) except ValueError: @@ -71,10 +73,10 @@ def __init__( assert_dataclass_is_valid(dataclass_type) if write_mode is WriteMode.WRITE: - assert_file_is_writable(path, overwrite=overwrite) + assert_file_is_writable(filepath, overwrite=overwrite) else: - assert_file_is_appendable(path, dataclass_type=dataclass_type) - assert_file_header_matches_dataclass(path, dataclass_type, file_format) + assert_file_is_appendable(filepath, dataclass_type=dataclass_type) + assert_file_header_matches_dataclass(filepath, dataclass_type, file_format) self._dataclass_type = dataclass_type self._fieldnames = _validate_output_fieldnames( @@ -82,7 +84,7 @@ def __init__( include_fields=include_fields, exclude_fields=exclude_fields, ) - self._fout = path.open(write_mode.abbreviation) + self._fout = filepath.open(write_mode.abbreviation) self._writer = DictWriter( f=self._fout, fieldnames=self._fieldnames, diff --git a/tests/test_reader.py b/tests/test_reader.py index e0c5602..8a9ea22 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -13,11 +13,25 @@ class FakeDataclass: def test_reader(tmp_path: Path) -> None: fpath = tmp_path / "test.txt" - with open(fpath, "w") as f: + with fpath.open("w") as f: f.write("foo\tbar\n") f.write("abc\t1\n") - with DataclassReader(path=fpath, dataclass_type=FakeDataclass) as reader: + with DataclassReader(filename=fpath, dataclass_type=FakeDataclass) as reader: + rows = [row for row in reader] + + assert rows[0] == FakeDataclass(foo="abc", bar=1) + + +def test_reader_from_str(tmp_path: Path) -> None: + """Test that we can create a reader when `filename` is a `str`.""" + fpath = tmp_path / "test.txt" + + with fpath.open("w") as f: + f.write("foo\tbar\n") + f.write("abc\t1\n") + + with DataclassReader(filename=str(fpath), dataclass_type=FakeDataclass) as reader: rows = [row for row in reader] assert rows[0] == FakeDataclass(foo="abc", bar=1) diff --git a/tests/test_writer.py b/tests/test_writer.py index 7f5d265..1e997d0 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -15,11 +15,11 @@ class FakeDataclass: def test_writer(tmp_path: Path) -> None: fpath = tmp_path / "test.txt" - with DataclassWriter(path=fpath, mode="write", dataclass_type=FakeDataclass) as writer: + with DataclassWriter(filename=fpath, mode="write", dataclass_type=FakeDataclass) as writer: writer.write(FakeDataclass(foo="abc", bar=1)) writer.write(FakeDataclass(foo="def", bar=2)) - with open(fpath, "r") as f: + with fpath.open("r") as f: assert next(f) == "foo\tbar\n" assert next(f) == "abc\t1\n" assert next(f) == "def\t2\n" @@ -27,6 +27,20 @@ def test_writer(tmp_path: Path) -> None: next(f) +def test_writer_from_str(tmp_path: Path) -> None: + """Test that we can create a writer when `filename` is a `str`.""" + fpath = tmp_path / "test.txt" + + with DataclassWriter(filename=str(fpath), mode="write", dataclass_type=FakeDataclass) as writer: + writer.write(FakeDataclass(foo="abc", bar=1)) + + with fpath.open("r") as f: + assert next(f) == "foo\tbar\n" + assert next(f) == "abc\t1\n" + with pytest.raises(StopIteration): + next(f) + + def test_writer_writeall(tmp_path: Path) -> None: fpath = tmp_path / "test.txt" @@ -34,10 +48,10 @@ def test_writer_writeall(tmp_path: Path) -> None: FakeDataclass(foo="abc", bar=1), FakeDataclass(foo="def", bar=2), ] - with DataclassWriter(path=fpath, mode="write", dataclass_type=FakeDataclass) as writer: + with DataclassWriter(filename=fpath, mode="write", dataclass_type=FakeDataclass) as writer: writer.writeall(data) - with open(fpath, "r") as f: + with fpath.open("r") as f: assert next(f) == "foo\tbar\n" assert next(f) == "abc\t1\n" assert next(f) == "def\t2\n" @@ -52,11 +66,11 @@ def test_writer_append(tmp_path: Path) -> None: with fpath.open("w") as fout: fout.write("foo\tbar\n") - with DataclassWriter(path=fpath, mode="append", dataclass_type=FakeDataclass) as writer: + with DataclassWriter(filename=fpath, mode="append", dataclass_type=FakeDataclass) as writer: writer.write(FakeDataclass(foo="abc", bar=1)) writer.write(FakeDataclass(foo="def", bar=2)) - with open(fpath, "r") as f: + with fpath.open("r") as f: assert next(f) == "foo\tbar\n" assert next(f) == "abc\t1\n" assert next(f) == "def\t2\n" @@ -70,7 +84,7 @@ def test_writer_append_raises_if_empty(tmp_path: Path) -> None: fpath.touch() with pytest.raises(ValueError, match="The specified output file is empty"): - with DataclassWriter(path=fpath, mode="append", dataclass_type=FakeDataclass) as writer: + with DataclassWriter(filename=fpath, mode="append", dataclass_type=FakeDataclass) as writer: writer.write(FakeDataclass(foo="abc", bar=1)) @@ -81,7 +95,7 @@ def test_writer_append_raises_if_no_header(tmp_path: Path) -> None: fout.write("abc\t1\n") with pytest.raises(ValueError, match="The provided file does not have the same field names"): - with DataclassWriter(path=fpath, mode="append", dataclass_type=FakeDataclass) as writer: + with DataclassWriter(filename=fpath, mode="append", dataclass_type=FakeDataclass) as writer: writer.write(FakeDataclass(foo="abc", bar=1)) @@ -96,7 +110,7 @@ def test_writer_append_raises_if_header_does_not_match(tmp_path: Path) -> None: fout.write("foo\tbar\tbaz\n") with pytest.raises(ValueError, match="The provided file does not have the same field names"): - with DataclassWriter(path=fpath, mode="append", dataclass_type=FakeDataclass) as writer: + with DataclassWriter(filename=fpath, mode="append", dataclass_type=FakeDataclass) as writer: writer.write(FakeDataclass(foo="abc", bar=1)) @@ -109,14 +123,14 @@ def test_writer_include_fields(tmp_path: Path) -> None: FakeDataclass(foo="def", bar=2), ] with DataclassWriter( - path=fpath, + filename=fpath, mode="write", dataclass_type=FakeDataclass, include_fields=["foo"], ) as writer: writer.writeall(data) - with open(fpath, "r") as f: + with fpath.open("r") as f: assert next(f) == "foo\n" assert next(f) == "abc\n" assert next(f) == "def\n" @@ -133,14 +147,14 @@ def test_writer_include_fields_reorders(tmp_path: Path) -> None: FakeDataclass(foo="def", bar=2), ] with DataclassWriter( - path=fpath, + filename=fpath, mode="write", dataclass_type=FakeDataclass, include_fields=["bar", "foo"], ) as writer: writer.writeall(data) - with open(fpath, "r") as f: + with fpath.open("r") as f: assert next(f) == "bar\tfoo\n" assert next(f) == "1\tabc\n" assert next(f) == "2\tdef\n" @@ -158,14 +172,14 @@ def test_writer_exclude_fields(tmp_path: Path) -> None: FakeDataclass(foo="def", bar=2), ] with DataclassWriter( - path=fpath, + filename=fpath, mode="write", dataclass_type=FakeDataclass, exclude_fields=["bar"], ) as writer: writer.writeall(data) - with open(fpath, "r") as f: + with fpath.open("r") as f: assert next(f) == "foo\n" assert next(f) == "abc\n" assert next(f) == "def\n"