Skip to content

Commit

Permalink
Sanitize timestamps in arrow tables (#3076)
Browse files Browse the repository at this point in the history
* Convert timestamps in arrow tables to ISO-8601 strings before JSON serialization

* Move sanitize_arrow_table to core

* Add null time entry to test

* Raise for arrow duration columns

* Skip test_arrow_timestamp_conversion on windows when timezone database is not available
  • Loading branch information
jonmmease authored Jun 7, 2023
1 parent 676d72a commit 8bee779
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 5 deletions.
25 changes: 25 additions & 0 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,31 @@ def to_list_if_array(val):
return df


def sanitize_arrow_table(pa_table):
"""Sanitize arrow table for JSON serialization"""
import pyarrow as pa
import pyarrow.compute as pc

arrays = []
schema = pa_table.schema
for name in schema.names:
array = pa_table[name]
dtype = schema.field(name).type
if str(dtype).startswith("timestamp"):
arrays.append(pc.strftime(array))
elif str(dtype).startswith("duration"):
raise ValueError(
'Field "{col_name}" has type "{dtype}" which is '
"not supported by Altair. Please convert to "
"either a timestamp or a numerical value."
"".format(col_name=name, dtype=dtype)
)
else:
arrays.append(array)

return pa.Table.from_arrays(arrays, names=schema.names)


def parse_shorthand(
shorthand,
data=None,
Expand Down
6 changes: 2 additions & 4 deletions altair/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from toolz import curried
from typing import Callable

from .core import sanitize_dataframe
from .core import sanitize_dataframe, sanitize_arrow_table
from .core import sanitize_geo_interface
from .deprecation import AltairDeprecationWarning
from .plugin_registry import PluginRegistry
Expand Down Expand Up @@ -166,7 +166,7 @@ def to_values(data):
elif hasattr(data, "__dataframe__"):
# experimental interchange dataframe support
pi = import_pyarrow_interchange()
pa_table = pi.from_dataframe(data)
pa_table = sanitize_arrow_table(pi.from_dataframe(data))
return {"values": pa_table.to_pylist()}


Expand All @@ -185,8 +185,6 @@ def check_data_type(data):
# ==============================================================================
# Private utilities
# ==============================================================================


def _compute_data_hash(data_str):
return hashlib.md5(data_str.encode()).hexdigest()

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ dev = [
"mypy",
"pandas-stubs",
"types-jsonschema",
"types-setuptools"
"types-setuptools",
"pyarrow>=11"
]
doc = [
"sphinx",
Expand Down
57 changes: 57 additions & 0 deletions tests/utils/test_dataframe_interchange.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from datetime import datetime
import pyarrow as pa
import pandas as pd
import pytest
import sys
import os

from altair.utils.data import to_values


def windows_has_tzdata():
"""
From PyArrow: python/pyarrow/tests/util.py
This is the default location where tz.cpp will look for (until we make
this configurable at run-time)
"""
tzdata_path = os.path.expandvars(r"%USERPROFILE%\Downloads\tzdata")
return os.path.exists(tzdata_path)


# Skip test on Windows when the tz database is not configured.
# See https://github.com/altair-viz/altair/issues/3050.
@pytest.mark.skipif(
sys.platform == "win32" and not windows_has_tzdata(),
reason="Timezone database is not installed on Windows",
)
def test_arrow_timestamp_conversion():
"""Test that arrow timestamp values are converted to ISO-8601 strings"""
data = {
"date": [datetime(2004, 8, 1), datetime(2004, 9, 1), None],
"value": [102, 129, 139],
}
pa_table = pa.table(data)

values = to_values(pa_table)
expected_values = {
"values": [
{"date": "2004-08-01T00:00:00.000000", "value": 102},
{"date": "2004-09-01T00:00:00.000000", "value": 129},
{"date": None, "value": 139},
]
}
assert values == expected_values


def test_duration_raises():
td = pd.timedelta_range(0, periods=3, freq="h")
df = pd.DataFrame(td).reset_index()
df.columns = ["id", "timedelta"]
pa_table = pa.table(df)
with pytest.raises(ValueError) as e:
to_values(pa_table)

# Check that exception mentions the duration[ns] type,
# which is what the pandas timedelta is converted into
assert "duration[ns]" in e.value.args[0]

0 comments on commit 8bee779

Please sign in to comment.