diff --git a/great_tables/_tbl_data.py b/great_tables/_tbl_data.py index 0d3418209..d04be5852 100644 --- a/great_tables/_tbl_data.py +++ b/great_tables/_tbl_data.py @@ -220,6 +220,12 @@ def _get_cell(data: DataFrameLike, row: int, column: str) -> Any: @_get_cell.register(PlDataFrame) def _(data: Any, row: int, column: str) -> Any: + import polars as pl + + # if container dtype, convert pl.Series to list + if isinstance(data[column].dtype, (pl.List, pl.Array)): + return data[column][row].to_list() + return data[column][row] diff --git a/tests/test_tbl_data.py b/tests/test_tbl_data.py index d58307240..5d0314aa4 100644 --- a/tests/test_tbl_data.py +++ b/tests/test_tbl_data.py @@ -1,11 +1,12 @@ import math + import pandas as pd import polars as pl -import pyarrow as pa import polars.testing +import pyarrow as pa import pytest + from great_tables import GT -from great_tables._utils_render_html import create_body_component_h from great_tables._tbl_data import ( DataFrameLike, SeriesLike, @@ -14,6 +15,7 @@ _set_cell, _validate_selector_list, cast_frame_to_string, + copy_frame, create_empty_frame, eval_aggregate, eval_select, @@ -24,8 +26,8 @@ to_frame, to_list, validate_frame, - copy_frame, ) +from great_tables._utils_render_html import create_body_component_h params_frames = [ pytest.param(pd.DataFrame, id="pandas"), @@ -38,6 +40,10 @@ pytest.param(pa.array, id="arrow"), pytest.param(lambda a: pa.chunked_array([a]), id="arrow-chunked"), ] +params_pl_container_dtypes = [ + pytest.param(pl.List, id="list"), + pytest.param(pl.Array, id="array"), +] @pytest.fixture(params=params_frames, scope="function") @@ -45,6 +51,26 @@ def df(request) -> pd.DataFrame: return request.param({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]}) +@pytest.fixture(params=params_pl_container_dtypes, scope="function") +def df_container_dtypes(request): + dtype_constructor = request.param + + if dtype_constructor == pl.List: + return pl.DataFrame( + {"col1": [1, 2, 3], "col2": [[1, 2, 3], [4, 5], [6, 7, 8]], "col3": ["a", "b", "c"]} + ) + # return a pl df with pl.Array columns + else: + col2_as_array = pl.col("col2").cast(pl.Array(pl.Int32, shape=(3,))) + return pl.DataFrame( + { + "col1": [1, 2, 3], + "col2": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "col3": ["a", "b", "c"], + } + ).with_columns(col2_as_array) + + @pytest.fixture(params=params_series, scope="function") def ser(request) -> SeriesLike: return request.param([1.0, 2.0, None]) @@ -75,6 +101,11 @@ def test_get_cell(df: DataFrameLike): assert _get_cell(df, 1, "col2") == "b" +def test_get_cell_container_dtypes(df_container_dtypes: pl.DataFrame): + "Checks that container dtype entries in polars dfs are returned as lists" + assert isinstance(_get_cell(df_container_dtypes, 1, "col2"), list) + + def test_set_cell(df: DataFrameLike): expected_data = {"col1": [1, 2, 3], "col2": ["a", "x", "c"], "col3": [4.0, 5.0, 6.0]} if isinstance(df, pa.Table):