Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sanitize timestamps in arrow tables #3076

Merged
merged 6 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,31 @@ def to_list_if_array(val):
return df


def sanitize_arrow_table(pa_table):
    """Build a JSON-serializable version of an arrow table.

    Columns are visited in schema order. Timestamp columns are turned into
    strings with ``pyarrow.compute.strftime`` (called with its default
    format); all other accepted columns are carried over untouched.

    Parameters
    ----------
    pa_table
        Arrow table to sanitize. It must expose ``schema`` and support
        column lookup by name.

    Returns
    -------
    pyarrow.Table
        A new table assembled from the sanitized columns, keeping the
        original column names and order.

    Raises
    ------
    ValueError
        If a column's type string starts with ``"duration"``.
    """
    import pyarrow as pa
    import pyarrow.compute as pc

    table_schema = pa_table.schema
    column_names = table_schema.names

    # Message template used when a duration column is encountered.
    unsupported_template = (
        'Field "{col_name}" has type "{dtype}" which is '
        "not supported by Altair. Please convert to "
        "either a timestamp or a numerical value."
        ""
    )

    sanitized_columns = []
    for column_name in column_names:
        column = pa_table[column_name]
        column_type = table_schema.field(column_name).type
        # The textual form of the type (e.g. "duration[ns]") drives the
        # dispatch below.
        type_label = str(column_type)

        if type_label.startswith("timestamp"):
            column = pc.strftime(column)
        elif type_label.startswith("duration"):
            raise ValueError(
                unsupported_template.format(col_name=column_name, dtype=column_type)
            )

        sanitized_columns.append(column)

    return pa.Table.from_arrays(sanitized_columns, names=column_names)


def parse_shorthand(
shorthand,
data=None,
Expand Down
6 changes: 2 additions & 4 deletions altair/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from toolz import curried
from typing import Callable

from .core import sanitize_dataframe
from .core import sanitize_dataframe, sanitize_arrow_table
from .core import sanitize_geo_interface
from .deprecation import AltairDeprecationWarning
from .plugin_registry import PluginRegistry
Expand Down Expand Up @@ -166,7 +166,7 @@ def to_values(data):
elif hasattr(data, "__dataframe__"):
# experimental interchange dataframe support
pi = import_pyarrow_interchange()
pa_table = pi.from_dataframe(data)
pa_table = sanitize_arrow_table(pi.from_dataframe(data))
return {"values": pa_table.to_pylist()}


Expand All @@ -185,8 +185,6 @@ def check_data_type(data):
# ==============================================================================
# Private utilities
# ==============================================================================


def _compute_data_hash(data_str):
return hashlib.md5(data_str.encode()).hexdigest()

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ dev = [
"mypy",
"pandas-stubs",
"types-jsonschema",
"types-setuptools"
"types-setuptools",
"pyarrow>=11"
]
doc = [
"sphinx",
Expand Down
57 changes: 57 additions & 0 deletions tests/utils/test_dataframe_interchange.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from datetime import datetime
import pyarrow as pa
import pandas as pd
import pytest
import sys
import os

from altair.utils.data import to_values


def windows_has_tzdata():
    """
    Report whether the tz database folder exists at its default location.

    Taken from PyArrow: python/pyarrow/tests/util.py

    The folder checked is the default place where tz.cpp looks for the
    database (until that location becomes configurable at run-time).
    """
    default_location = r"%USERPROFILE%\Downloads\tzdata"
    # Substitute environment variables (such as the user profile directory)
    # into the template before probing the filesystem.
    resolved_location = os.path.expandvars(default_location)
    return os.path.exists(resolved_location)


# On Windows this test is skipped unless the tz database has been set up.
# See https://github.com/altair-viz/altair/issues/3050.
@pytest.mark.skipif(
    sys.platform == "win32" and not windows_has_tzdata(),
    reason="Timezone database is not installed on Windows",
)
def test_arrow_timestamp_conversion():
    """Arrow timestamp values must come out of ``to_values`` as ISO-8601 strings"""
    dates = [datetime(2004, 8, 1), datetime(2004, 9, 1), None]
    amounts = [102, 129, 139]
    arrow_table = pa.table({"date": dates, "value": amounts})

    # Each timestamp is expected as a string; the missing entry stays None.
    expected_dates = [
        "2004-08-01T00:00:00.000000",
        "2004-09-01T00:00:00.000000",
        None,
    ]
    expected_rows = [
        {"date": date_text, "value": amount}
        for date_text, amount in zip(expected_dates, amounts)
    ]

    converted = to_values(arrow_table)
    assert converted == {"values": expected_rows}


def test_duration_raises():
    """Passing an arrow table with a duration column must raise ValueError"""
    # Three hourly timedeltas, exposed as a column next to an integer id.
    deltas = pd.timedelta_range(0, periods=3, freq="h")
    frame = pd.DataFrame(deltas).reset_index()
    frame.columns = ["id", "timedelta"]
    arrow_table = pa.table(frame)

    with pytest.raises(ValueError) as excinfo:
        to_values(arrow_table)

    # The pandas timedelta column becomes duration[ns] in arrow, and the
    # error message is expected to name that type.
    message = excinfo.value.args[0]
    assert "duration[ns]" in message