Skip to content

Commit

Permalink
Fix handling of timestamps when ingesting from CSV
Browse files — browse the repository at this point in the history
  • Loading branch information
daniel-thom committed Nov 12, 2024
1 parent ac69a9e commit 2260b95
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 128 deletions.
17 changes: 15 additions & 2 deletions src/chronify/csv_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,28 @@
from duckdb import DuckDBPyRelation

from chronify.models import CsvTableSchema, get_duckdb_type_from_sqlalchemy
from chronify.time_configs import DatetimeRange


def read_csv(path: Path | str, schema: CsvTableSchema, **kwargs: Any) -> DuckDBPyRelation:
"""Read a CSV file into a DuckDB relation."""
if schema.column_dtypes:
dtypes = {x.name: get_duckdb_type_from_sqlalchemy(x.dtype) for x in schema.column_dtypes}
dtypes = {
x.name: get_duckdb_type_from_sqlalchemy(x.dtype).id for x in schema.column_dtypes
}
rel = duckdb.read_csv(str(path), dtype=dtypes, **kwargs)
else:
rel = duckdb.read_csv(str(path), **kwargs)

expr = ",".join(rel.columns)
time_config = schema.time_config
exprs = []
for i, column in enumerate(rel.columns):
expr = column
if isinstance(time_config, DatetimeRange) and column == time_config.time_column:
time_type = rel.types[i]
if time_type == duckdb.typing.TIMESTAMP and time_config.start.tzinfo is not None: # type: ignore
expr = f"timezone('{time_config.start.tzinfo.key}', {column}) AS {column}" # type: ignore
exprs.append(expr)

expr = ",".join(exprs)
return duckdb.sql(f"SELECT {expr} FROM rel")
60 changes: 39 additions & 21 deletions src/chronify/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import re
from typing import Any, Optional, Type
from typing import Any, Optional

import duckdb.typing
from duckdb.typing import DuckDBPyType
from pydantic import Field, field_validator, model_validator
from sqlalchemy import BigInteger, Boolean, DateTime, Double, Integer, String
from typing_extensions import Annotated
Expand Down Expand Up @@ -51,7 +52,8 @@ class TableSchema(TableSchemaBase):
"""Defines the schema for a time series table stored in the database."""

name: Annotated[
str, Field(description="Name of the table or view in the database.", frozen=True)
str,
Field(description="Name of the table or view in the database.", frozen=True),
]
value_column: Annotated[str, Field(description="Column in the table that contain values.")]

Expand Down Expand Up @@ -89,35 +91,51 @@ def list_columns(self) -> list[str]:
duckdb.typing.BOOLEAN.id: Boolean, # type: ignore
duckdb.typing.DOUBLE.id: Double, # type: ignore
duckdb.typing.INTEGER.id: Integer, # type: ignore
duckdb.typing.TIMESTAMP.id: DateTime, # type: ignore
duckdb.typing.VARCHAR.id: String, # type: ignore
}

_SQLALCHEMY_TYPES_TO_DUCKDB_TYPES: dict[Any, str] = {
v: k for k, v in _DUCKDB_TYPES_TO_SQLALCHEMY_TYPES.items()
# Note: timestamp requires special handling because of timezone in sqlalchemy.
}


def get_sqlalchemy_type_from_duckdb(duckdb_type: duckdb.typing.DuckDBPyType) -> Type: # type: ignore
def get_sqlalchemy_type_from_duckdb(duckdb_type: DuckDBPyType) -> Any:
"""Return the sqlalchemy type for a duckdb type."""
if duckdb_type == duckdb.typing.TIMESTAMP_TZ: # type: ignore
msg = "TIMESTAMP_TZ is not handled yet"
raise NotImplementedError(msg)
match duckdb_type:
case duckdb.typing.TIMESTAMP_TZ: # type: ignore
sqlalchemy_type = DateTime(timezone=True)
case duckdb.typing.TIMESTAMP: # type: ignore
sqlalchemy_type = DateTime(timezone=False)
case _:
cls = _DUCKDB_TYPES_TO_SQLALCHEMY_TYPES.get(duckdb_type.id)
if cls is None:
msg = f"There is no sqlalchemy mapping for {duckdb_type=}"
raise InvalidParameter(msg)
sqlalchemy_type = cls()

sqlalchemy_type = _DUCKDB_TYPES_TO_SQLALCHEMY_TYPES.get(duckdb_type.id)
if sqlalchemy_type is None:
msg = f"There is no sqlalchemy mapping for {duckdb_type=}"
raise InvalidParameter(msg)
return sqlalchemy_type


def get_duckdb_type_from_sqlalchemy(sqlalchemy_type: Any) -> str:
def get_duckdb_type_from_sqlalchemy(sqlalchemy_type: Any) -> DuckDBPyType:
"""Return the duckdb type for a sqlalchemy type."""
duckdb_type = _SQLALCHEMY_TYPES_TO_DUCKDB_TYPES.get(sqlalchemy_type)
if duckdb_type is None:
if isinstance(sqlalchemy_type, DateTime):
duckdb_type = (
duckdb.typing.TIMESTAMP_TZ # type: ignore
if sqlalchemy_type.timezone
else duckdb.typing.TIMESTAMP # type: ignore
)
elif isinstance(sqlalchemy_type, BigInteger):
duckdb_type = duckdb.typing.BIGINT # type: ignore
elif isinstance(sqlalchemy_type, Boolean):
duckdb_type = duckdb.typing.BOOLEAN # type: ignore
elif isinstance(sqlalchemy_type, Double):
duckdb_type = duckdb.typing.DOUBLE # type: ignore
elif isinstance(sqlalchemy_type, Integer):
duckdb_type = duckdb.typing.INTEGER # type: ignore
elif isinstance(sqlalchemy_type, String):
duckdb_type = duckdb.typing.VARCHAR # type: ignore
else:
msg = f"There is no duckdb mapping for {sqlalchemy_type=}"
raise InvalidParameter(msg)
return duckdb_type.upper()

return duckdb_type # type: ignore


class ColumnDType(ChronifyBaseModel):
Expand All @@ -130,7 +148,7 @@ class ColumnDType(ChronifyBaseModel):
@classmethod
def fix_data_type(cls, data: dict[str, Any]) -> dict[str, Any]:
dtype = data.get("dtype")
if dtype is None or dtype in _DB_TYPES:
if dtype is None or any(map(lambda x: isinstance(dtype, x), _DB_TYPES)):
return data

if isinstance(dtype, str):
Expand All @@ -139,7 +157,7 @@ def fix_data_type(cls, data: dict[str, Any]) -> dict[str, Any]:
options = sorted(_COLUMN_TYPES.keys()) + list(_DB_TYPES)
msg = f"{dtype=} must be one of {options}"
raise ValueError(msg)
data["dtype"] = val
data["dtype"] = val()
else:
msg = f"dtype is an unsupported type: {type(dtype)}. It must be a str or type."
raise ValueError(msg)
Expand Down
27 changes: 20 additions & 7 deletions src/chronify/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from chronify.sqlalchemy.functions import read_database, write_database
from chronify.time_configs import DatetimeRange, IndexTimeRange
from chronify.time_series_checker import TimeSeriesChecker
from chronify.time_series_checker import check_timestamps
from chronify.utils.sql import make_temp_view_name
from chronify.utils.sqlalchemy_view import create_view

Expand Down Expand Up @@ -132,16 +132,28 @@ def ingest_from_csv(
msg = f"IndexTimeRange cannot be converted to {cls_name}"
raise NotImplementedError(msg)

if self.has_table(dst_schema.name):
table_exists = self.has_table(dst_schema.name)
if table_exists:
table = Table(dst_schema.name, self._metadata)
else:
dtypes = [get_sqlalchemy_type_from_duckdb(x) for x in rel.dtypes]
columns = [Column(x, y) for x, y in zip(rel.columns, dtypes)]
table = Table(dst_schema.name, self._metadata, *columns)
table = Table(
dst_schema.name,
self._metadata,
*[Column(x, y) for x, y in zip(rel.columns, dtypes)],
)
table.create(self._engine)

with self._engine.begin() as conn:
write_database(rel.to_df(), conn, dst_schema)
try:
check_timestamps(conn, table, dst_schema)
except Exception:
conn.rollback()
if not table_exists:
table.drop(self._engine)
self.update_table_schema()
raise
conn.commit()
self.update_table_schema()

Expand All @@ -152,7 +164,7 @@ def read_query(self, query: Selectable | str, schema: TableSchema) -> pd.DataFra

def read_table(self, schema: TableSchema) -> pd.DataFrame:
"""Return the table as a pandas DataFrame."""
return self.read_query(f"select * from {schema.name}", schema)
return self.read_query(f"SELECT * FROM {schema.name}", schema)

def write_query_to_parquet(self, stmt: Selectable, file_path: Path | str) -> None:
"""Write the result of a query to a Parquet file."""
Expand Down Expand Up @@ -212,8 +224,9 @@ def _check_table_schema(self, schema: TableSchema) -> None:
check_columns(columns, schema.list_columns())

def _check_timestamps(self, schema: TableSchema) -> None:
checker = TimeSeriesChecker(self._engine, self._metadata)
checker.check_timestamps(schema)
with self._engine.connect() as conn:
table = Table(schema.name, self._metadata)
check_timestamps(conn, table, schema)


def check_columns(table_columns: Iterable[str], schema_columns: Iterable[str]) -> None:
Expand Down
8 changes: 4 additions & 4 deletions src/chronify/time_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def needs_utc_conversion(self, engine_name: str) -> bool:
return False

@abc.abstractmethod
def list_timestamps_from_dataframe(self, df: pd.DataFrame) -> list[Any]:
"""Return a list of timestamps present in DataFrame.
def list_distinct_timestamps_from_dataframe(self, df: pd.DataFrame) -> list[Any]:
"""Return a list of distinct timestamps present in DataFrame.
Type of the timestamps depends on the class.
Returns
Expand Down Expand Up @@ -183,8 +183,8 @@ def is_time_zone_naive(self) -> bool:
"""Return True if the timestamps in the range do not have time zones."""
return self.start.tzinfo is None

def list_timestamps_from_dataframe(self, df: pd.DataFrame) -> list[datetime]:
return df[self.time_column].to_list()
def list_distinct_timestamps_from_dataframe(self, df: pd.DataFrame) -> list[datetime]:
return sorted(df[self.time_column].unique())

def list_time_columns(self) -> list[str]:
return [self.time_column]
Expand Down
90 changes: 45 additions & 45 deletions src/chronify/time_series_checker.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,79 @@
from sqlalchemy import Connection, Engine, MetaData, Table, select, text
from sqlalchemy import Connection, Table, select, text

from chronify.exceptions import InvalidTable
from chronify.models import TableSchema
from chronify.sqlalchemy.functions import read_database
from chronify.utils.sql import make_temp_view_name


def check_timestamps(conn: Connection, table: Table, schema: TableSchema) -> None:
"""Performs checks on time series arrays in a table."""
TimeSeriesChecker(conn, table, schema).check_timestamps()


class TimeSeriesChecker:
"""Performs checks on time series arrays in a table."""

def __init__(self, engine: Engine, metadata: MetaData) -> None:
self._engine = engine
self._metadata = metadata
def __init__(self, conn: Connection, table: Table, schema: TableSchema) -> None:
self._conn = conn
self._schema = schema
self._table = table

def check_timestamps(self, schema: TableSchema) -> None:
self._check_expected_timestamps(schema)
self._check_expected_timestamps_by_time_array(schema)
def check_timestamps(self) -> None:
self._check_expected_timestamps()
self._check_expected_timestamps_by_time_array()

def _check_expected_timestamps(self, schema: TableSchema) -> None:
expected = schema.time_config.list_timestamps()
with self._engine.connect() as conn:
table = Table(schema.name, self._metadata)
time_columns = schema.time_config.list_time_columns()
stmt = select(*(table.c[x] for x in time_columns)).distinct()
for col in time_columns:
stmt = stmt.where(table.c[col].is_not(None))
df = read_database(stmt, conn, schema)
actual = set(schema.time_config.list_timestamps_from_dataframe(df))
match = sorted(actual) == expected
# TODO: This check doesn't work and I'm not sure why.
# diff = actual.symmetric_difference(expected)
# if diff:
# msg = f"Actual timestamps do not match expected timestamps: {diff}"
# # TODO: list diff on each side.
# raise InvalidTable(msg)
if not match:
msg = "Actual timestamps do not match expected timestamps"
# TODO: list diff on each side.
raise InvalidTable(msg)
def _check_expected_timestamps(self) -> None:
expected = self._schema.time_config.list_timestamps()
time_columns = self._schema.time_config.list_time_columns()
stmt = select(*(self._table.c[x] for x in time_columns)).distinct()
for col in time_columns:
stmt = stmt.where(self._table.c[col].is_not(None))
df = read_database(stmt, self._conn, self._schema)
actual = self._schema.time_config.list_distinct_timestamps_from_dataframe(df)
match = actual == expected
# TODO: This check doesn't work and I'm not sure why.
# diff = actual.symmetric_difference(expected)
# if diff:
# msg = f"Actual timestamps do not match expected timestamps: {diff}"
# # TODO: list diff on each side.
# raise InvalidTable(msg)
if not match:
msg = "Actual timestamps do not match expected timestamps"
# TODO: list diff on each side.
raise InvalidTable(msg)

def _check_expected_timestamps_by_time_array(self, schema: TableSchema) -> None:
with self._engine.connect() as conn:
tmp_name = make_temp_view_name()
self._run_timestamp_checks_on_tmp_table(schema, conn, tmp_name)
conn.execute(text(f"DROP TABLE IF EXISTS {tmp_name}"))
def _check_expected_timestamps_by_time_array(self) -> None:
tmp_name = make_temp_view_name()
self._run_timestamp_checks_on_tmp_table(tmp_name)
self._conn.execute(text(f"DROP TABLE IF EXISTS {tmp_name}"))

@staticmethod
def _run_timestamp_checks_on_tmp_table(
schema: TableSchema, conn: Connection, table_name: str
) -> None:
id_cols = ",".join(schema.time_array_id_columns)
filters = [f"{x} IS NOT NULL" for x in schema.time_config.list_time_columns()]
def _run_timestamp_checks_on_tmp_table(self, table_name: str) -> None:
id_cols = ",".join(self._schema.time_array_id_columns)
filters = [f"{x} IS NOT NULL" for x in self._schema.time_config.list_time_columns()]
where_clause = "AND ".join(filters)
query = f"""
CREATE TEMP TABLE {table_name} AS
SELECT
{id_cols}
,COUNT(*) AS count_by_ta
FROM {schema.name}
FROM {self._schema.name}
WHERE {where_clause}
GROUP BY {id_cols}
"""
conn.execute(text(query))
self._conn.execute(text(query))
query2 = f"SELECT COUNT(DISTINCT count_by_ta) AS counts FROM {table_name}"
result2 = conn.execute(text(query2)).fetchone()
result2 = self._conn.execute(text(query2)).fetchone()
assert result2 is not None

if result2[0] != 1:
msg = f"All time arrays must have the same length. There are {result2[0]} different lengths"
raise InvalidTable(msg)

query3 = f"SELECT DISTINCT count_by_ta AS counts FROM {table_name}"
result3 = conn.execute(text(query3)).fetchone()
result3 = self._conn.execute(text(query3)).fetchone()
assert result3 is not None
actual_count = result3[0]
if actual_count != schema.time_config.length:
msg = f"Time arrays must have length={schema.time_config.length}. Actual = {actual_count}"
if actual_count != self._schema.time_config.length:
msg = f"Time arrays must have length={self._schema.time_config.length}. Actual = {actual_count}"
raise InvalidTable(msg)
7 changes: 5 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import pytest
from sqlalchemy import Integer
from sqlalchemy import BigInteger, Boolean, DateTime, Double, Integer, String

from chronify.models import ColumnDType, _check_name


def test_column_dtypes():
ColumnDType(name="col1", dtype=Integer)
ColumnDType(name="col1", dtype=Integer())
for dtype in (BigInteger, Boolean, DateTime, Double, String):
ColumnDType(name="col1", dtype=dtype())

for string_type in ("int", "bigint", "bool", "datetime", "float", "str"):
ColumnDType(name="col1", dtype=string_type)

Expand Down
Loading

0 comments on commit 2260b95

Please sign in to comment.