diff --git a/great_tables/_data_color/base.py b/great_tables/_data_color/base.py index cb59c9fe7..c48cda974 100644 --- a/great_tables/_data_color/base.py +++ b/great_tables/_data_color/base.py @@ -5,7 +5,14 @@ from typing_extensions import TypeAlias from great_tables._locations import RowSelectExpr, resolve_cols_c, resolve_rows_i -from great_tables._tbl_data import DataFrameLike, SelectExpr, get_column_names, is_na +from great_tables._tbl_data import ( + DataFrameLike, + SelectExpr, + get_column_names, + get_rows, + is_na, + to_list, +) from great_tables.loc import body from great_tables.style import fill, text @@ -227,7 +234,7 @@ def data_color( # For each column targeted, get the data values as a new list object for col in columns_resolved: # This line handles both pandas and polars dataframes - column_vals = data_table[col][row_pos].to_list() + column_vals = to_list(get_rows(data_table[col], indexes=row_pos)) # Filter out NA values from `column_vals` filtered_column_vals = [x for x in column_vals if not is_na(data_table, x)] diff --git a/great_tables/_tbl_data.py b/great_tables/_tbl_data.py index 0d3418209..81f564fcd 100644 --- a/great_tables/_tbl_data.py +++ b/great_tables/_tbl_data.py @@ -29,10 +29,10 @@ PlSelectExpr = _selector_proxy_ PlExpr = pl.Expr - PdSeries = pd.Series + PdSeries = pd.Series[Any] PlSeries = pl.Series - PyArrowArray = pa.Array - PyArrowChunkedArray = pa.ChunkedArray + PyArrowArray = pa.Array[Any] + PyArrowChunkedArray = pa.ChunkedArray[Any] PdNA = pd.NA PlNull = pl.Null @@ -763,7 +763,7 @@ def _(df: PyArrowTable, x: Any) -> bool: import pyarrow as pa arr = pa.array([x]) - return arr.is_null().to_pylist()[0] or arr.is_nan().to_pylist()[0] + return arr.is_null(nan_is_null=True).to_pylist()[0] @singledispatch @@ -936,3 +936,25 @@ def _(df: PyArrowTable, expr: Callable[[PyArrowTable], PyArrowTable]) -> dict[st ) return {col: res.column(col)[0].as_py() for col in res.column_names} + + +@singledispatch +def get_rows(ser: SeriesLike, indexes: 
list[int]) -> SeriesLike: + """Returns values of the series at `indexes` position.""" + raise NotImplementedError(f"Unsupported type: {type(ser)}") + + +@get_rows.register +def _(ser: PdSeries, indexes: list[int]) -> PdSeries: + return ser.iloc[indexes] + + +@get_rows.register +def _(ser: PlSeries, indexes: list[int]) -> PlSeries: + return ser[indexes] + + +@get_rows.register(PyArrowArray) +@get_rows.register(PyArrowChunkedArray) +def _(ser: Any, indexes: list[int]) -> PyArrowArray | PyArrowChunkedArray: + return ser.take(indexes) diff --git a/tests/data_color/__snapshots__/test_data_color.ambr b/tests/data_color/__snapshots__/test_data_color.ambr index 022e471c5..f8abd32ab 100644 --- a/tests/data_color/__snapshots__/test_data_color.ambr +++ b/tests/data_color/__snapshots__/test_data_color.ambr @@ -123,6 +123,32 @@ ''' # --- +# name: test_data_color_autocolor_text_false[pyarrow] + ''' +
+