From 2260b9592d23c18951c79a94426c8cfa97b03088 Mon Sep 17 00:00:00 2001
From: Daniel Thom
Date: Tue, 12 Nov 2024 07:55:30 -0700
Subject: [PATCH] Fix handling of timestamps when ingesting from CSV

---
 src/chronify/csv_io.py              | 17 ++++-
 src/chronify/models.py              | 60 +++++++++++-------
 src/chronify/store.py               | 27 +++++---
 src/chronify/time_configs.py        |  8 +--
 src/chronify/time_series_checker.py | 90 +++++++++++++-------------
 tests/test_models.py                |  7 ++-
 tests/test_store.py                 | 97 +++++++++++++++++------------
 tests/test_time_series_checker.py   | 16 ++---
 8 files changed, 194 insertions(+), 128 deletions(-)

diff --git a/src/chronify/csv_io.py b/src/chronify/csv_io.py
index c0605d2..cc9245f 100644
--- a/src/chronify/csv_io.py
+++ b/src/chronify/csv_io.py
@@ -5,15 +5,28 @@
 from duckdb import DuckDBPyRelation
 
 from chronify.models import CsvTableSchema, get_duckdb_type_from_sqlalchemy
+from chronify.time_configs import DatetimeRange
 
 
 def read_csv(path: Path | str, schema: CsvTableSchema, **kwargs: Any) -> DuckDBPyRelation:
     """Read a CSV file into a DuckDB relation."""
     if schema.column_dtypes:
-        dtypes = {x.name: get_duckdb_type_from_sqlalchemy(x.dtype) for x in schema.column_dtypes}
+        dtypes = {
+            x.name: get_duckdb_type_from_sqlalchemy(x.dtype).id for x in schema.column_dtypes
+        }
         rel = duckdb.read_csv(str(path), dtype=dtypes, **kwargs)
     else:
         rel = duckdb.read_csv(str(path), **kwargs)
 
-    expr = ",".join(rel.columns)
+    time_config = schema.time_config
+    exprs = []
+    for i, column in enumerate(rel.columns):
+        expr = column
+        if isinstance(time_config, DatetimeRange) and column == time_config.time_column:
+            time_type = rel.types[i]
+            if time_type == duckdb.typing.TIMESTAMP and time_config.start.tzinfo is not None:  # type: ignore
+                expr = f"timezone('{time_config.start.tzinfo.key}', {column}) AS {column}"  # type: ignore
+        exprs.append(expr)
+
+    expr = ",".join(exprs)
     return duckdb.sql(f"SELECT {expr} FROM rel")
diff --git a/src/chronify/models.py b/src/chronify/models.py
index 20404cc..4209343 100644
--- a/src/chronify/models.py
+++ b/src/chronify/models.py
@@ -1,7 +1,8 @@
 import re
-from typing import Any, Optional, Type
+from typing import Any, Optional
 
 import duckdb.typing
+from duckdb.typing import DuckDBPyType
 from pydantic import Field, field_validator, model_validator
 from sqlalchemy import BigInteger, Boolean, DateTime, Double, Integer, String
 from typing_extensions import Annotated
@@ -51,7 +52,8 @@ class TableSchema(TableSchemaBase):
     """Defines the schema for a time series table stored in the database."""
 
     name: Annotated[
-        str, Field(description="Name of the table or view in the database.", frozen=True)
+        str,
+        Field(description="Name of the table or view in the database.", frozen=True),
     ]
     value_column: Annotated[str, Field(description="Column in the table that contain values.")]
 
@@ -89,35 +91,51 @@ def list_columns(self) -> list[str]:
     duckdb.typing.BOOLEAN.id: Boolean,  # type: ignore
     duckdb.typing.DOUBLE.id: Double,  # type: ignore
     duckdb.typing.INTEGER.id: Integer,  # type: ignore
-    duckdb.typing.TIMESTAMP.id: DateTime,  # type: ignore
     duckdb.typing.VARCHAR.id: String,  # type: ignore
-}
-
-_SQLALCHEMY_TYPES_TO_DUCKDB_TYPES: dict[Any, str] = {
-    v: k for k, v in _DUCKDB_TYPES_TO_SQLALCHEMY_TYPES.items()
+    # Note: timestamp requires special handling because of timezone in sqlalchemy.
 }
 
 
-def get_sqlalchemy_type_from_duckdb(duckdb_type: duckdb.typing.DuckDBPyType) -> Type:  # type: ignore
+def get_sqlalchemy_type_from_duckdb(duckdb_type: DuckDBPyType) -> Any:
     """Return the sqlalchemy type for a duckdb type."""
-    if duckdb_type == duckdb.typing.TIMESTAMP_TZ:  # type: ignore
-        msg = "TIMESTAMP_TZ is not handled yet"
-        raise NotImplementedError(msg)
+    match duckdb_type:
+        case duckdb.typing.TIMESTAMP_TZ:  # type: ignore
+            sqlalchemy_type = DateTime(timezone=True)
+        case duckdb.typing.TIMESTAMP:  # type: ignore
+            sqlalchemy_type = DateTime(timezone=False)
+        case _:
+            cls = _DUCKDB_TYPES_TO_SQLALCHEMY_TYPES.get(duckdb_type.id)
+            if cls is None:
+                msg = f"There is no sqlalchemy mapping for {duckdb_type=}"
+                raise InvalidParameter(msg)
+            sqlalchemy_type = cls()
 
-    sqlalchemy_type = _DUCKDB_TYPES_TO_SQLALCHEMY_TYPES.get(duckdb_type.id)
-    if sqlalchemy_type is None:
-        msg = f"There is no sqlalchemy mapping for {duckdb_type=}"
-        raise InvalidParameter(msg)
     return sqlalchemy_type
 
 
-def get_duckdb_type_from_sqlalchemy(sqlalchemy_type: Any) -> str:
+def get_duckdb_type_from_sqlalchemy(sqlalchemy_type: Any) -> DuckDBPyType:
     """Return the duckdb type for a sqlalchemy type."""
-    duckdb_type = _SQLALCHEMY_TYPES_TO_DUCKDB_TYPES.get(sqlalchemy_type)
-    if duckdb_type is None:
+    if isinstance(sqlalchemy_type, DateTime):
+        duckdb_type = (
+            duckdb.typing.TIMESTAMP_TZ  # type: ignore
+            if sqlalchemy_type.timezone
+            else duckdb.typing.TIMESTAMP  # type: ignore
+        )
+    elif isinstance(sqlalchemy_type, BigInteger):
+        duckdb_type = duckdb.typing.BIGINT  # type: ignore
+    elif isinstance(sqlalchemy_type, Boolean):
+        duckdb_type = duckdb.typing.BOOLEAN  # type: ignore
+    elif isinstance(sqlalchemy_type, Double):
+        duckdb_type = duckdb.typing.DOUBLE  # type: ignore
+    elif isinstance(sqlalchemy_type, Integer):
+        duckdb_type = duckdb.typing.INTEGER  # type: ignore
+    elif isinstance(sqlalchemy_type, String):
+        duckdb_type = duckdb.typing.VARCHAR  # type: ignore
+    else:
         msg = f"There is no duckdb mapping for {sqlalchemy_type=}"
         raise InvalidParameter(msg)
-    return duckdb_type.upper()
+
+    return duckdb_type  # type: ignore
 
 
 class ColumnDType(ChronifyBaseModel):
@@ -130,7 +148,7 @@ class ColumnDType(ChronifyBaseModel):
     @classmethod
     def fix_data_type(cls, data: dict[str, Any]) -> dict[str, Any]:
         dtype = data.get("dtype")
-        if dtype is None or dtype in _DB_TYPES:
+        if dtype is None or any(map(lambda x: isinstance(dtype, x), _DB_TYPES)):
             return data
 
         if isinstance(dtype, str):
@@ -139,7 +157,7 @@ def fix_data_type(cls, data: dict[str, Any]) -> dict[str, Any]:
                 options = sorted(_COLUMN_TYPES.keys()) + list(_DB_TYPES)
                 msg = f"{dtype=} must be one of {options}"
                 raise ValueError(msg)
-            data["dtype"] = val
+            data["dtype"] = val()
         else:
             msg = f"dtype is an unsupported type: {type(dtype)}. It must be a str or type."
             raise ValueError(msg)
diff --git a/src/chronify/store.py b/src/chronify/store.py
index ba26136..f9e3b99 100644
--- a/src/chronify/store.py
+++ b/src/chronify/store.py
@@ -17,7 +17,7 @@
 )
 from chronify.sqlalchemy.functions import read_database, write_database
 from chronify.time_configs import DatetimeRange, IndexTimeRange
-from chronify.time_series_checker import TimeSeriesChecker
+from chronify.time_series_checker import check_timestamps
 from chronify.utils.sql import make_temp_view_name
 from chronify.utils.sqlalchemy_view import create_view
 
@@ -132,16 +132,28 @@ def ingest_from_csv(
             msg = f"IndexTimeRange cannot be converted to {cls_name}"
             raise NotImplementedError(msg)
 
-        if self.has_table(dst_schema.name):
+        table_exists = self.has_table(dst_schema.name)
+        if table_exists:
             table = Table(dst_schema.name, self._metadata)
         else:
             dtypes = [get_sqlalchemy_type_from_duckdb(x) for x in rel.dtypes]
-            columns = [Column(x, y) for x, y in zip(rel.columns, dtypes)]
-            table = Table(dst_schema.name, self._metadata, *columns)
+            table = Table(
+                dst_schema.name,
+                self._metadata,
+                *[Column(x, y) for x, y in zip(rel.columns, dtypes)],
+            )
             table.create(self._engine)
 
         with self._engine.begin() as conn:
             write_database(rel.to_df(), conn, dst_schema)
+            try:
+                check_timestamps(conn, table, dst_schema)
+            except Exception:
+                conn.rollback()
+                if not table_exists:
+                    table.drop(self._engine)
+                    self.update_table_schema()
+                raise
             conn.commit()
 
         self.update_table_schema()
@@ -152,7 +164,7 @@ def read_query(self, query: Selectable | str, schema: TableSchema) -> pd.DataFra
 
     def read_table(self, schema: TableSchema) -> pd.DataFrame:
         """Return the table as a pandas DataFrame."""
-        return self.read_query(f"select * from {schema.name}", schema)
+        return self.read_query(f"SELECT * FROM {schema.name}", schema)
 
     def write_query_to_parquet(self, stmt: Selectable, file_path: Path | str) -> None:
         """Write the result of a query to a Parquet file."""
@@ -212,8 +224,9 @@ def _check_table_schema(self, schema: TableSchema) -> None:
         check_columns(columns, schema.list_columns())
 
     def _check_timestamps(self, schema: TableSchema) -> None:
-        checker = TimeSeriesChecker(self._engine, self._metadata)
-        checker.check_timestamps(schema)
+        with self._engine.connect() as conn:
+            table = Table(schema.name, self._metadata)
+            check_timestamps(conn, table, schema)
 
 
 def check_columns(table_columns: Iterable[str], schema_columns: Iterable[str]) -> None:
diff --git a/src/chronify/time_configs.py b/src/chronify/time_configs.py
index 0d1c4ab..3d16350 100644
--- a/src/chronify/time_configs.py
+++ b/src/chronify/time_configs.py
@@ -145,8 +145,8 @@ def needs_utc_conversion(self, engine_name: str) -> bool:
         return False
 
     @abc.abstractmethod
-    def list_timestamps_from_dataframe(self, df: pd.DataFrame) -> list[Any]:
-        """Return a list of timestamps present in DataFrame.
+    def list_distinct_timestamps_from_dataframe(self, df: pd.DataFrame) -> list[Any]:
+        """Return a list of distinct timestamps present in DataFrame.
         Type of the timestamps depends on the class.
 
         Returns
@@ -183,8 +183,8 @@ def is_time_zone_naive(self) -> bool:
         """Return True if the timestamps in the range do not have time zones."""
         return self.start.tzinfo is None
 
-    def list_timestamps_from_dataframe(self, df: pd.DataFrame) -> list[datetime]:
-        return df[self.time_column].to_list()
+    def list_distinct_timestamps_from_dataframe(self, df: pd.DataFrame) -> list[datetime]:
+        return sorted(df[self.time_column].unique())
 
     def list_time_columns(self) -> list[str]:
         return [self.time_column]
diff --git a/src/chronify/time_series_checker.py b/src/chronify/time_series_checker.py
index df0fad3..5b00a02 100644
--- a/src/chronify/time_series_checker.py
+++ b/src/chronify/time_series_checker.py
@@ -1,4 +1,4 @@
-from sqlalchemy import Connection, Engine, MetaData, Table, select, text
+from sqlalchemy import Connection, Table, select, text
 
 from chronify.exceptions import InvalidTable
 from chronify.models import TableSchema
@@ -6,64 +6,64 @@
 from chronify.utils.sql import make_temp_view_name
 
 
+def check_timestamps(conn: Connection, table: Table, schema: TableSchema) -> None:
+    """Performs checks on time series arrays in a table."""
+    TimeSeriesChecker(conn, table, schema).check_timestamps()
+
+
 class TimeSeriesChecker:
     """Performs checks on time series arrays in a table."""
 
-    def __init__(self, engine: Engine, metadata: MetaData) -> None:
-        self._engine = engine
-        self._metadata = metadata
+    def __init__(self, conn: Connection, table: Table, schema: TableSchema) -> None:
+        self._conn = conn
+        self._schema = schema
+        self._table = table
 
-    def check_timestamps(self, schema: TableSchema) -> None:
-        self._check_expected_timestamps(schema)
-        self._check_expected_timestamps_by_time_array(schema)
+    def check_timestamps(self) -> None:
+        self._check_expected_timestamps()
+        self._check_expected_timestamps_by_time_array()
 
-    def _check_expected_timestamps(self, schema: TableSchema) -> None:
-        expected = schema.time_config.list_timestamps()
-        with self._engine.connect() as conn:
-            table = Table(schema.name, self._metadata)
-            time_columns = schema.time_config.list_time_columns()
-            stmt = select(*(table.c[x] for x in time_columns)).distinct()
-            for col in time_columns:
-                stmt = stmt.where(table.c[col].is_not(None))
-            df = read_database(stmt, conn, schema)
-            actual = set(schema.time_config.list_timestamps_from_dataframe(df))
-            match = sorted(actual) == expected
-            # TODO: This check doesn't work and I'm not sure why.
-            # diff = actual.symmetric_difference(expected)
-            # if diff:
-            #     msg = f"Actual timestamps do not match expected timestamps: {diff}"
-            #     # TODO: list diff on each side.
-            #     raise InvalidTable(msg)
-            if not match:
-                msg = "Actual timestamps do not match expected timestamps"
-                # TODO: list diff on each side.
-                raise InvalidTable(msg)
+    def _check_expected_timestamps(self) -> None:
+        expected = self._schema.time_config.list_timestamps()
+        time_columns = self._schema.time_config.list_time_columns()
+        stmt = select(*(self._table.c[x] for x in time_columns)).distinct()
+        for col in time_columns:
+            stmt = stmt.where(self._table.c[col].is_not(None))
+        df = read_database(stmt, self._conn, self._schema)
+        actual = self._schema.time_config.list_distinct_timestamps_from_dataframe(df)
+        match = actual == expected
+        # TODO: This check doesn't work and I'm not sure why.
+        # diff = actual.symmetric_difference(expected)
+        # if diff:
+        #     msg = f"Actual timestamps do not match expected timestamps: {diff}"
+        #     # TODO: list diff on each side.
+        #     raise InvalidTable(msg)
+        if not match:
+            msg = "Actual timestamps do not match expected timestamps"
+            # TODO: list diff on each side.
+            raise InvalidTable(msg)
 
-    def _check_expected_timestamps_by_time_array(self, schema: TableSchema) -> None:
-        with self._engine.connect() as conn:
-            tmp_name = make_temp_view_name()
-            self._run_timestamp_checks_on_tmp_table(schema, conn, tmp_name)
-            conn.execute(text(f"DROP TABLE IF EXISTS {tmp_name}"))
+    def _check_expected_timestamps_by_time_array(self) -> None:
+        tmp_name = make_temp_view_name()
+        self._run_timestamp_checks_on_tmp_table(tmp_name)
+        self._conn.execute(text(f"DROP TABLE IF EXISTS {tmp_name}"))
 
-    @staticmethod
-    def _run_timestamp_checks_on_tmp_table(
-        schema: TableSchema, conn: Connection, table_name: str
-    ) -> None:
-        id_cols = ",".join(schema.time_array_id_columns)
-        filters = [f"{x} IS NOT NULL" for x in schema.time_config.list_time_columns()]
+    def _run_timestamp_checks_on_tmp_table(self, table_name: str) -> None:
+        id_cols = ",".join(self._schema.time_array_id_columns)
+        filters = [f"{x} IS NOT NULL" for x in self._schema.time_config.list_time_columns()]
         where_clause = "AND ".join(filters)
         query = f"""
             CREATE TEMP TABLE {table_name} AS
             SELECT
                 {id_cols}
                 ,COUNT(*) AS count_by_ta
-            FROM {schema.name}
+            FROM {self._schema.name}
             WHERE {where_clause}
            GROUP BY {id_cols}
        """
-        conn.execute(text(query))
+        self._conn.execute(text(query))
 
         query2 = f"SELECT COUNT(DISTINCT count_by_ta) AS counts FROM {table_name}"
-        result2 = conn.execute(text(query2)).fetchone()
+        result2 = self._conn.execute(text(query2)).fetchone()
         assert result2 is not None
 
         if result2[0] != 1:
@@ -71,9 +71,9 @@ def _run_timestamp_checks_on_tmp_table(
             raise InvalidTable(msg)
 
         query3 = f"SELECT DISTINCT count_by_ta AS counts FROM {table_name}"
-        result3 = conn.execute(text(query3)).fetchone()
+        result3 = self._conn.execute(text(query3)).fetchone()
         assert result3 is not None
         actual_count = result3[0]
-        if actual_count != schema.time_config.length:
-            msg = f"Time arrays must have length={schema.time_config.length}. Actual = {actual_count}"
+        if actual_count != self._schema.time_config.length:
+            msg = f"Time arrays must have length={self._schema.time_config.length}. Actual = {actual_count}"
Actual = {actual_count}" raise InvalidTable(msg) diff --git a/tests/test_models.py b/tests/test_models.py index 5a4675a..927c023 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,11 +1,14 @@ import pytest -from sqlalchemy import Integer +from sqlalchemy import BigInteger, Boolean, DateTime, Double, Integer, String from chronify.models import ColumnDType, _check_name def test_column_dtypes(): - ColumnDType(name="col1", dtype=Integer) + ColumnDType(name="col1", dtype=Integer()) + for dtype in (BigInteger, Boolean, DateTime, Double, String): + ColumnDType(name="col1", dtype=dtype()) + for string_type in ("int", "bigint", "bool", "datetime", "float", "str"): ColumnDType(name="col1", dtype=string_type) diff --git a/tests/test_store.py b/tests/test_store.py index 1a36046..433fbf0 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -1,12 +1,13 @@ import fileinput import shutil - from datetime import datetime, timedelta +from pathlib import Path from zoneinfo import ZoneInfo import duckdb import pandas as pd import pytest +import sqlalchemy from sqlalchemy import DateTime, Double, Engine, Table, create_engine, select from chronify.csv_io import read_csv from chronify.duckdb.functions import unpivot @@ -33,10 +34,10 @@ def generators_schema(): src_schema = CsvTableSchema( time_config=time_config, column_dtypes=[ - ColumnDType(name="timestamp", dtype=DateTime), - ColumnDType(name="gen1", dtype=Double), - ColumnDType(name="gen2", dtype=Double), - ColumnDType(name="gen3", dtype=Double), + ColumnDType(name="timestamp", dtype=DateTime(timezone=False)), + ColumnDType(name="gen1", dtype=Double()), + ColumnDType(name="gen2", dtype=Double()), + ColumnDType(name="gen3", dtype=Double()), ], value_columns=["gen1", "gen2", "gen3"], pivoted_dimension_name="generator", @@ -48,19 +49,32 @@ def generators_schema(): time_array_id_columns=["generator"], value_column="value", ) - yield src_schema, dst_schema + yield Path(GENERATOR_TIME_SERIES_FILE), src_schema, dst_schema -def test_ingest_csv(iter_engines: Engine, tmp_path, generators_schema): +@pytest.mark.parametrize("use_time_zone", [True, False]) +def test_ingest_csv(iter_engines: Engine, tmp_path, generators_schema, use_time_zone): engine = iter_engines - src_schema, dst_schema = generators_schema + src_file, src_schema, dst_schema = generators_schema + src_schema.column_dtypes[0] = ColumnDType( + name="timestamp", dtype=DateTime(timezone=use_time_zone) + ) store = Store(engine=engine) - store.ingest_from_csv(GENERATOR_TIME_SERIES_FILE, src_schema, dst_schema) + if use_time_zone: + new_src_file = tmp_path / "gen_tz.csv" + duckdb.sql( + f""" + SELECT timezone('EST', timestamp) as timestamp, gen1, gen2, gen3 + FROM read_csv('{src_file}') + """ + ).to_df().to_csv(new_src_file, index=False) + src_file = new_src_file + store.ingest_from_csv(src_file, src_schema, dst_schema) df = store.read_table(dst_schema) assert len(df) == 8784 * 3 new_file = tmp_path / "gen2.csv" - shutil.copyfile(GENERATOR_TIME_SERIES_FILE, new_file) + shutil.copyfile(src_file, new_file) with fileinput.input([new_file], inplace=True) as f: for line in f: new_line = line.replace("gen1", "g1b").replace("gen2", "g2b").replace("gen3", "g3b") @@ -70,10 +84,10 @@ def test_ingest_csv(iter_engines: Engine, tmp_path, generators_schema): src_schema2 = CsvTableSchema( time_config=src_schema.time_config, column_dtypes=[ - ColumnDType(name="timestamp", dtype=DateTime), - ColumnDType(name="g1b", dtype=Double), - ColumnDType(name="g2b", dtype=Double), - ColumnDType(name="g3b", 
+            ColumnDType(name="timestamp", dtype=DateTime(timezone=use_time_zone)),
+            ColumnDType(name="g1b", dtype=Double()),
+            ColumnDType(name="g2b", dtype=Double()),
+            ColumnDType(name="g3b", dtype=Double()),
         ],
         value_columns=["g1b", "g2b", "g3b"],
         pivoted_dimension_name="generator",
@@ -84,28 +98,38 @@ def test_ingest_csv(iter_engines: Engine, tmp_path, generators_schema):
     assert len(df) == 8784 * 3 * 2
     all(df.timestamp.unique() == dst_schema.time_config.list_timestamps())
+    # Adding the same rows should fail.
+    with pytest.raises(InvalidTable):
+        store.ingest_from_csv(new_file, src_schema2, dst_schema)
+    df = store.read_table(dst_schema)
+    assert len(df) == 8784 * 3 * 2
+    all(df.timestamp.unique() == dst_schema.time_config.list_timestamps())
 
 
-def test_ingest_csv_with_time_zones(iter_engines: Engine, tmp_path, generators_schema):
-    csv_file = tmp_path / "gen.csv"
-    df = duckdb.read_csv(GENERATOR_TIME_SERIES_FILE).to_df()
-    df["timestamp"] = df["timestamp"].dt.tz_localize("EST")
-    df.to_csv(csv_file, index=False)
+
+def test_ingest_invalid_csv(iter_engines: Engine, tmp_path, generators_schema):
     engine = iter_engines
-    src_schema, dst_schema = generators_schema
+    src_file, src_schema, dst_schema = generators_schema
+    lines = src_file.read_text().splitlines()[:-10]
+    new_file = tmp_path / "data.csv"
+    with open(new_file, "w", encoding="utf-8") as f:
+        for line in lines:
+            f.write(line)
+            f.write("\n")
     store = Store(engine=engine)
-    store.ingest_from_csv(csv_file, src_schema, dst_schema)
-    df = store.read_table(dst_schema)
-    assert len(df) == 8784 * 3
-    all(df.timestamp.unique() == dst_schema.time_config.list_timestamps())
+    with pytest.raises(InvalidTable):
+        store.ingest_from_csv(new_file, src_schema, dst_schema)
+    with pytest.raises((sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError)):
+        store.read_table(dst_schema)
 
 
 def test_invalid_schema(iter_engines: Engine, generators_schema):
     engine = iter_engines
-    src_schema, dst_schema = generators_schema
+    src_file, src_schema, dst_schema = generators_schema
     src_schema.value_columns = ["g1", "g2", "g3"]
     store = Store(engine=engine)
     with pytest.raises(InvalidTable):
-        store.ingest_from_csv(GENERATOR_TIME_SERIES_FILE, src_schema, dst_schema)
+        store.ingest_from_csv(src_file, src_schema, dst_schema)
 
 
 def test_load_parquet(tmp_path):
@@ -120,10 +144,10 @@ def test_load_parquet(tmp_path):
     src_schema = CsvTableSchema(
         time_config=time_config,
         column_dtypes=[
-            ColumnDType(name="timestamp", dtype=DateTime),
-            ColumnDType(name="gen1", dtype=Double),
-            ColumnDType(name="gen2", dtype=Double),
-            ColumnDType(name="gen3", dtype=Double),
+            ColumnDType(name="timestamp", dtype=DateTime(timezone=False)),
+            ColumnDType(name="gen1", dtype=Double()),
+            ColumnDType(name="gen2", dtype=Double()),
+            ColumnDType(name="gen3", dtype=Double()),
         ],
         value_columns=["gen1", "gen2", "gen3"],
         pivoted_dimension_name="generator",
@@ -137,15 +161,8 @@ def test_load_parquet(tmp_path):
     )
     rel = read_csv(GENERATOR_TIME_SERIES_FILE, src_schema)
     rel2 = unpivot(rel, ("gen1", "gen2", "gen3"), "generator", "value")  # noqa: F841
-    rel3 = duckdb.sql(
-        """
-        SELECT timezone('EST', timestamp) AS timestamp
-        ,generator
-        ,value from rel2
-        """
-    )
     out_file = tmp_path / "gen2.parquet"
-    rel3.to_parquet(str(out_file))
+    rel2.to_parquet(str(out_file))
     store = Store()
     store.load_table(out_file, dst_schema)
     df = store.read_table(dst_schema)
@@ -154,9 +171,9 @@ def test_load_parquet(tmp_path):
 
 
 def test_to_parquet(tmp_path, generators_schema):
-    src_schema, dst_schema = generators_schema
+    src_file, src_schema, dst_schema = generators_schema
     store = Store()
-    store.ingest_from_csv(GENERATOR_TIME_SERIES_FILE, src_schema, dst_schema)
+    store.ingest_from_csv(src_file, src_schema, dst_schema)
     filename = tmp_path / "data.parquet"
     table = Table(dst_schema.name, store.metadata)
     stmt = select(table).where(table.c.generator == "gen2")
diff --git a/tests/test_time_series_checker.py b/tests/test_time_series_checker.py
index dd6e242..92993a3 100644
--- a/tests/test_time_series_checker.py
+++ b/tests/test_time_series_checker.py
@@ -7,13 +7,14 @@ from sqlalchemy import (
     Engine,
     MetaData,
+    Table,
 )
 
 from chronify.exceptions import InvalidTable
 from chronify.models import TableSchema
 from chronify.sqlalchemy.functions import write_database
 from chronify.time import TimeIntervalType
 from chronify.time_configs import DatetimeRange
-from chronify.time_series_checker import TimeSeriesChecker
+from chronify.time_series_checker import check_timestamps
 
 
 def test_valid_datetimes_with_tz(iter_engines: Engine):
@@ -66,12 +67,13 @@ def _run_test(
         conn.commit()
     metadata.reflect(engine)
 
-    checker = TimeSeriesChecker(engine, metadata)
-    if message is None:
-        checker.check_timestamps(schema)
-    else:
-        with pytest.raises(InvalidTable, match=message):
-            checker.check_timestamps(schema)
+    with engine.connect() as conn:
+        table = Table(schema.name, metadata)
+        if message is None:
+            check_timestamps(conn, table, schema)
+        else:
+            with pytest.raises(InvalidTable, match=message):
+                check_timestamps(conn, table, schema)
 
 
 def _get_inputs_for_valid_datetimes_with_tz() -> tuple[pd.DataFrame, ZoneInfo, int, None]:
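
Usage sketch (not part of the diff): the snippet below illustrates the behavior this patch enables — ingesting a CSV whose timestamp column carries a time zone. It mirrors the fixtures in tests/test_store.py above; the DatetimeRange keyword arguments (start, resolution, length, time_column) and the input file name gen.csv are assumptions inferred from those tests, not a confirmed public API.

    from datetime import datetime, timedelta
    from zoneinfo import ZoneInfo

    from sqlalchemy import DateTime, Double

    from chronify.models import ColumnDType, CsvTableSchema, TableSchema
    from chronify.store import Store
    from chronify.time_configs import DatetimeRange

    # A tz-aware start: read_csv() now wraps a naive DuckDB TIMESTAMP column in
    # timezone('<tz>', ...) when the configured start has a tzinfo attached.
    time_config = DatetimeRange(
        start=datetime(2020, 1, 1, tzinfo=ZoneInfo("EST")),  # assumed signature
        resolution=timedelta(hours=1),
        length=8784,  # hourly values for the 2020 leap year, as in the tests
        time_column="timestamp",
    )
    src_schema = CsvTableSchema(
        time_config=time_config,
        column_dtypes=[
            # DateTime(timezone=True) now maps to DuckDB TIMESTAMP_TZ (models.py).
            ColumnDType(name="timestamp", dtype=DateTime(timezone=True)),
            ColumnDType(name="gen1", dtype=Double()),
        ],
        value_columns=["gen1"],
        pivoted_dimension_name="generator",
    )
    dst_schema = TableSchema(
        name="generators",
        time_config=time_config,
        time_array_id_columns=["generator"],
        value_column="value",
    )

    store = Store()
    # check_timestamps() now runs inside ingest_from_csv(): an invalid file
    # raises InvalidTable, and a table created by this call is rolled back
    # and dropped instead of being left half-populated.
    store.ingest_from_csv("gen.csv", src_schema, dst_schema)  # hypothetical file

With DateTime(timezone=True), get_duckdb_type_from_sqlalchemy() selects TIMESTAMP_TZ, and the checker validates both the distinct timestamps and the per-array row counts before the ingest transaction commits.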