Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions great_tables/_tbl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@ def _get_cell(data: DataFrameLike, row: int, column: str) -> Any:

@_get_cell.register(PlDataFrame)
def _(data: Any, row: int, column: str) -> Any:
    """Return the value stored at (row, column) of a polars DataFrame.

    Cells of container-dtype columns (``pl.List`` / ``pl.Array``) are handed
    back by polars as a nested ``pl.Series``; those are converted to a plain
    Python list. Null cells and cells of every other dtype are returned as-is.

    Args:
        data: The polars DataFrame to read from.
        row: Positional index of the row.
        column: Name of the column.

    Returns:
        The cell value, as a ``list`` when the column has a container dtype
        and the cell is not null.
    """
    import polars as pl

    series = data[column]
    cell = series[row]

    # A null entry in a List/Array column is returned as None rather than a
    # pl.Series, so calling .to_list() on it unconditionally would raise
    # AttributeError.
    if cell is not None and isinstance(series.dtype, (pl.List, pl.Array)):
        return cell.to_list()

    return cell


Expand Down
37 changes: 34 additions & 3 deletions tests/test_tbl_data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import math

import pandas as pd
import polars as pl
import pyarrow as pa
import polars.testing
import pyarrow as pa
import pytest

from great_tables import GT
from great_tables._utils_render_html import create_body_component_h
from great_tables._tbl_data import (
DataFrameLike,
SeriesLike,
Expand All @@ -14,6 +15,7 @@
_set_cell,
_validate_selector_list,
cast_frame_to_string,
copy_frame,
create_empty_frame,
eval_aggregate,
eval_select,
Expand All @@ -24,8 +26,8 @@
to_frame,
to_list,
validate_frame,
copy_frame,
)
from great_tables._utils_render_html import create_body_component_h

params_frames = [
pytest.param(pd.DataFrame, id="pandas"),
Expand All @@ -38,13 +40,37 @@
pytest.param(pa.array, id="arrow"),
pytest.param(lambda a: pa.chunked_array([a]), id="arrow-chunked"),
]
# polars container dtypes; each entry parametrizes the df_container_dtypes
# fixture below, which builds a frame whose "col2" column has that dtype.
params_pl_container_dtypes = [
    pytest.param(pl.List, id="list"),
    pytest.param(pl.Array, id="array"),
]


@pytest.fixture(params=params_frames, scope="function")
def df(request) -> DataFrameLike:
    """Build the same three-column frame with each constructor in ``params_frames``."""
    # request.param is the frame constructor selected by the parametrization,
    # so the result is not necessarily a pandas DataFrame.
    return request.param({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]})


@pytest.fixture(params=params_pl_container_dtypes, scope="function")
def df_container_dtypes(request):
    """Build a polars frame whose ``col2`` column has the requested container dtype.

    For ``pl.List`` the rows of ``col2`` have differing lengths; for
    ``pl.Array`` every row has three elements and the column is cast to a
    three-wide ``Int32`` array.
    """
    container_dtype = request.param

    if container_dtype == pl.List:
        list_columns = {
            "col1": [1, 2, 3],
            "col2": [[1, 2, 3], [4, 5], [6, 7, 8]],
            "col3": ["a", "b", "c"],
        }
        return pl.DataFrame(list_columns)

    # Otherwise build the frame from nested lists, then cast col2 to pl.Array.
    array_columns = {
        "col1": [1, 2, 3],
        "col2": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        "col3": ["a", "b", "c"],
    }
    cast_col2 = pl.col("col2").cast(pl.Array(pl.Int32, shape=(3,)))
    frame = pl.DataFrame(array_columns)
    return frame.with_columns(cast_col2)


@pytest.fixture(params=params_series, scope="function")
def ser(request) -> SeriesLike:
    """Build a three-element series ending in a missing value, once per series constructor."""
    make_series = request.param
    values = [1.0, 2.0, None]
    return make_series(values)
Expand Down Expand Up @@ -75,6 +101,11 @@ def test_get_cell(df: DataFrameLike):
assert _get_cell(df, 1, "col2") == "b"


def test_get_cell_container_dtypes(df_container_dtypes: pl.DataFrame):
    """Container-dtype cells of a polars frame must be returned as plain lists."""
    cell = _get_cell(df_container_dtypes, 1, "col2")

    assert isinstance(cell, list)


def test_set_cell(df: DataFrameLike):
expected_data = {"col1": [1, 2, 3], "col2": ["a", "x", "c"], "col3": [4.0, 5.0, 6.0]}
if isinstance(df, pa.Table):
Expand Down