diff --git a/README.md b/README.md
index 31955ddb..a9b7924b 100644
--- a/README.md
+++ b/README.md
@@ -67,9 +67,10 @@ board.pin_write(mtcars.head(), "mtcars", type="csv")
 
 Above, we saved the data as a CSV, but depending on what you’re saving
 and who else you want to read it, you might use the `type` argument to
-instead save it as a `joblib`, `parquet`, or `json` file.
+instead save it as a `joblib`, `parquet`, or `json` file. If you're using
+a `polars.DataFrame`, you can save to `parquet`.
 
-You can later retrieve the pinned data with `.pin_read()`:
+You can later retrieve the pinned data as a `pandas.DataFrame` with `.pin_read()`:
 
 ``` python
 board.pin_read("mtcars")
diff --git a/README.qmd b/README.qmd
index 98af5826..a1982e99 100644
--- a/README.qmd
+++ b/README.qmd
@@ -66,8 +66,9 @@ board.pin_write(mtcars.head(), "mtcars", type="csv")
 Above, we saved the data as a CSV, but depending on what you’re saving
 and who else you want to read it, you might use the `type` argument to
 instead save it as a `joblib`, `parquet`, or `json` file.
+If you're using a `polars.DataFrame`, you can save to `parquet`.
 
-You can later retrieve the pinned data with `.pin_read()`:
+You can later retrieve the pinned data as a `pandas.DataFrame` with `.pin_read()`:
 
 ```{python}
 board.pin_read("mtcars")
diff --git a/pins/drivers.py b/pins/drivers.py
index 5aa3e186..57cafdf1 100644
--- a/pins/drivers.py
+++ b/pins/drivers.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Sequence
+from typing import Literal, Sequence, TypeAlias
 
 from .config import PINS_ENV_INSECURE_READ, get_allow_pickle_read
 from .errors import PinsInsecureReadError
@@ -11,15 +11,7 @@
 UNSAFE_TYPES = frozenset(["joblib"])
 
 REQUIRES_SINGLE_FILE = frozenset(["csv", "joblib", "file"])
 
-
-def _assert_is_pandas_df(x, file_type: str) -> None:
-    import pandas as pd
-
-    if not isinstance(x, pd.DataFrame):
-        raise NotImplementedError(
-            f"Currently only pandas.DataFrame can be saved as type {file_type!r}."
-        )
+_DFLib: TypeAlias = Literal["pandas", "polars"]
 
 
 def load_path(meta, path_to_version):
@@ -152,28 +144,31 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
         final_name = f"{fname}{suffix}"
 
     if type == "csv":
-        _assert_is_pandas_df(obj, file_type=type)
-
+        _choose_df_lib(obj, supported_libs=["pandas"], file_type=type)
         obj.to_csv(final_name, index=False)
 
     elif type == "arrow":
         # NOTE: R pins accepts the type arrow, and saves it as feather.
         # we allow reading this type, but raise an error for writing.
-        _assert_is_pandas_df(obj, file_type=type)
-
+        _choose_df_lib(obj, supported_libs=["pandas"], file_type=type)
         obj.to_feather(final_name)
 
     elif type == "feather":
-        _assert_is_pandas_df(obj, file_type=type)
+        _choose_df_lib(obj, supported_libs=["pandas"], file_type=type)
 
         raise NotImplementedError(
             'Saving data as type "feather" no longer supported. Use type "arrow" instead.'
         )
 
     elif type == "parquet":
-        _assert_is_pandas_df(obj, file_type=type)
+        df_lib = _choose_df_lib(obj, supported_libs=["pandas", "polars"], file_type=type)
 
-        obj.to_parquet(final_name)
+        if df_lib == "pandas":
+            obj.to_parquet(final_name)
+        elif df_lib == "polars":
+            obj.write_parquet(final_name)
+        else:
+            raise NotImplementedError
 
     elif type == "joblib":
         import joblib
@@ -200,13 +195,94 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
 
 
 def default_title(obj, name):
+    try:
+        df_lib = _choose_df_lib(obj)
+    except NotImplementedError:
+        obj_name = type(obj).__qualname__
+        return f"{name}: a pinned {obj_name} object"
+
+    _df_lib_to_objname: dict[_DFLib, str] = {
+        "polars": "DataFrame",
+        "pandas": "DataFrame",
+    }
+
+    # TODO(compat): title says CSV rather than data.frame
+    # see https://github.com/machow/pins-python/issues/5
+    shape_str = " x ".join(map(str, obj.shape))
+    return f"{name}: a pinned {shape_str} {_df_lib_to_objname[df_lib]}"
+
+
+def _choose_df_lib(
+    df,
+    *,
+    supported_libs: list[_DFLib] | None = None,
+    file_type: str | None = None,
+) -> _DFLib:
+    """Return the library associated with a DataFrame, e.g. "pandas".
+
+    The arguments `supported_libs` and `file_type` must be specified together, and are
+    meant to be used when saving an object, to choose the appropriate library.
+
+    Args:
+        df:
+            The object to check - might not be a DataFrame necessarily.
+        supported_libs:
+            The DataFrame libraries to accept for this df.
+        file_type:
+            The file type we're trying to save to - used to give more specific error
+            messages.
+
+    Raises:
+        NotImplementedError: If the DataFrame type is not recognized, or not supported.
+    """
+    if (supported_libs is None) + (file_type is None) == 1:
+        raise ValueError("Must provide both or neither of supported_libs and file_type")
+
+    df_libs: list[_DFLib] = []
+
+    # pandas
     import pandas as pd
 
-    if isinstance(obj, pd.DataFrame):
-        # TODO(compat): title says CSV rather than data.frame
-        # see https://github.com/machow/pins-python/issues/5
-        shape_str = " x ".join(map(str, obj.shape))
-        return f"{name}: a pinned {shape_str} DataFrame"
+    if isinstance(df, pd.DataFrame):
+        df_libs.append("pandas")
+
+    # polars
+    try:
+        import polars as pl
+    except ModuleNotFoundError:
+        pass
     else:
-        obj_name = type(obj).__qualname__
-        return f"{name}: a pinned {obj_name} object"
+        if isinstance(df, pl.DataFrame):
+            df_libs.append("polars")
+
+    # Make sure there's only one library associated with the dataframe
+    if len(df_libs) == 1:
+        (df_lib,) = df_libs
+    elif len(df_libs) > 1:
+        msg = (
+            f"Hybrid DataFrames are not supported: "
+            f"should only be one of {supported_libs!r}, "
+            f"but got an object from multiple libraries {df_libs!r}."
+        )
+        raise NotImplementedError(msg)
+    else:
+        raise NotImplementedError(f"Unrecognized DataFrame type: {type(df)}")
+
+    # Raise if the library is not supported
+    if supported_libs is not None and df_lib not in supported_libs:
+        ftype_clause = f"for type {file_type!r}"
+
+        if len(supported_libs) == 1:
+            msg = (
+                f"Currently only {supported_libs[0]} DataFrames can be saved "
+                f"{ftype_clause}. DataFrames from {df_lib} are not yet supported."
+            )
+        else:
+            msg = (
+                f"Currently only DataFrames from the following libraries can be saved "
+                f"{ftype_clause}: {supported_libs!r}."
+            )
+
+        raise NotImplementedError(msg)
+
+    return df_lib
diff --git a/pins/tests/test_drivers.py b/pins/tests/test_drivers.py
index 230f0e80..5f8d92b3 100644
--- a/pins/tests/test_drivers.py
+++ b/pins/tests/test_drivers.py
@@ -2,10 +2,11 @@
 
 import fsspec
 import pandas as pd
+import polars as pl
 import pytest
 
 from pins.config import PINS_ENV_INSECURE_READ
-from pins.drivers import default_title, load_data, save_data
+from pins.drivers import _choose_df_lib, default_title, load_data, save_data
 from pins.errors import PinsInsecureReadError
 from pins.meta import MetaRaw
 from pins.tests.helpers import rm_env
@@ -34,6 +35,7 @@ class D:
     [
         (pd.DataFrame({"x": [1, 2]}), "somename: a pinned 2 x 1 DataFrame"),
         (pd.DataFrame({"x": [1], "y": [2]}), "somename: a pinned 1 x 2 DataFrame"),
+        (pl.DataFrame({"x": [1, 2]}), "somename: a pinned 2 x 1 DataFrame"),
         (ExC(), "somename: a pinned ExC object"),
         (ExC().D(), "somename: a pinned ExC.D object"),
         ([1, 2, 3], "somename: a pinned list object"),
@@ -76,6 +78,36 @@ def test_driver_roundtrip(tmp_path: Path, type_):
     assert df.equals(obj)
 
 
+@pytest.mark.parametrize(
+    "type_",
+    [
+        "parquet",
+    ],
+)
+def test_driver_polars_roundtrip(tmp_path, type_):
+    import polars as pl
+
+    df = pl.DataFrame({"x": [1, 2, 3]})
+
+    fname = "some_df"
+    full_file = f"{fname}.{type_}"
+
+    p_obj = tmp_path / fname
+    res_fname = save_data(df, p_obj, type_)
+
+    assert Path(res_fname).name == full_file
+
+    meta = MetaRaw(full_file, type_, "my_pin")
+    pandas_df = load_data(
+        meta, fsspec.filesystem("file"), tmp_path, allow_pickle_read=True
+    )
+
+    # Convert from pandas to polars
+    obj = pl.DataFrame(pandas_df)
+
+    assert df.equals(obj)
+
+
 @pytest.mark.parametrize(
     "type_",
     [
@@ -159,3 +191,57 @@ def test_driver_apply_suffix_false(tmp_path: Path):
     res_fname = save_data(df, p_obj, type_, apply_suffix=False)
 
     assert Path(res_fname).name == "some_df"
+
+
+class TestChooseDFLib:
+    def test_pandas(self):
+        assert _choose_df_lib(pd.DataFrame({"x": [1]})) == "pandas"
+
+    def test_polars(self):
+        assert _choose_df_lib(pl.DataFrame({"x": [1]})) == "polars"
+
+    def test_list_raises(self):
+        with pytest.raises(
+            NotImplementedError, match="Unrecognized DataFrame type: "
+        ):
+            _choose_df_lib([])
+
+    def test_pandas_subclass(self):
+        class MyDataFrame(pd.DataFrame):
+            pass
+
+        assert _choose_df_lib(MyDataFrame({"x": [1]})) == "pandas"
+
+    def test_ftype_compatible(self):
+        assert (
+            _choose_df_lib(
+                pd.DataFrame({"x": [1]}), supported_libs=["pandas"], file_type="csv"
+            )
+            == "pandas"
+        )
+
+    def test_ftype_incompatible(self):
+        with pytest.raises(
+            NotImplementedError,
+            match=(
+                "Currently only pandas DataFrames can be saved for type 'csv'. "
+                "DataFrames from polars are not yet supported."
+            ),
+        ):
+            _choose_df_lib(
+                pl.DataFrame({"x": [1]}), supported_libs=["pandas"], file_type="csv"
+            )
+
+    def test_supported_alone_raises(self):
+        with pytest.raises(
+            ValueError,
+            match="Must provide both or neither of supported_libs and file_type",
+        ):
+            _choose_df_lib(..., supported_libs=["pandas"])
+
+    def test_file_type_alone_raises(self):
+        with pytest.raises(
+            ValueError,
+            match="Must provide both or neither of supported_libs and file_type",
+        ):
+            _choose_df_lib(..., file_type="csv")
diff --git a/pyproject.toml b/pyproject.toml
index 3497c51a..215bf07f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -60,6 +60,7 @@ test = [
     "pytest-dotenv",
     "pytest-parallel",
     "s3fs",
+    "polars>=1.0.0",
 ]
 
 [build-system]
diff --git a/requirements/dev.txt b/requirements/dev.txt
index 2b2a43a0..401b48eb 100644
--- a/requirements/dev.txt
+++ b/requirements/dev.txt
@@ -266,6 +266,8 @@ pluggy==1.5.0
     # via pytest
 plum-dispatch==2.5.1.post1
     # via quartodoc
+polars==1.2.1
+    # via pins (setup.cfg)
 portalocker==2.10.1
     # via msal-extensions
 pre-commit==3.7.1
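
For context beyond the patch itself, here is a minimal usage sketch of the workflow this diff enables. It is not part of the diff: `pins.board_temp()` is an existing pins board constructor used purely for illustration, and the pin name `"my_df"` is made up.

```python
import pins
import polars as pl

# Throwaway board for the example; any pins board should behave the same way.
board = pins.board_temp()

df = pl.DataFrame({"x": [1, 2, 3]})

# With this change, a polars DataFrame can be pinned as parquet:
# save_data() dispatches on the DataFrame library via _choose_df_lib().
board.pin_write(df, "my_df", type="parquet")

# Reading still returns a pandas.DataFrame (as the README update notes),
# so convert back to polars explicitly, mirroring the new roundtrip test.
pandas_df = board.pin_read("my_df")
polars_df = pl.DataFrame(pandas_df)
```

Other file types (`csv`, `arrow`) still accept only pandas DataFrames; passing a polars DataFrame to them raises the new `NotImplementedError` message from `_choose_df_lib`.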