Skip to content
20 changes: 16 additions & 4 deletions dataframely/columns/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def __init__(
max: dt.datetime | None = None,
max_exclusive: dt.datetime | None = None,
resolution: str | None = None,
time_zone: str | dt.tzinfo | None = None,
check: (
Callable[[pl.Expr], pl.Expr]
| list[Callable[[pl.Expr], pl.Expr]]
Expand Down Expand Up @@ -326,6 +327,9 @@ def __init__(
the formatting language used by :mod:`polars` datetime ``round`` method.
For example, a value ``1h`` expects all datetimes to be full hours. Note
that this setting does *not* affect the storage resolution.
time_zone: The time zone that datetimes in the column must have. The time
zone must use a valid IANA time zone name identifier e.x. ``Etc/UTC`` or
``America/New_York``.
check: A custom rule or multiple rules to run for this column. This can be:
- A single callable that returns a non-aggregated boolean expression.
The name of the rule is derived from the callable name, or defaults to
Expand Down Expand Up @@ -368,10 +372,11 @@ def __init__(
metadata=metadata,
)
self.resolution = resolution
self.time_zone = time_zone

@property
def dtype(self) -> pl.DataType:
return pl.Datetime()
return pl.Datetime(time_zone=self.time_zone)

def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
result = super().validation_rules(expr)
Expand All @@ -380,16 +385,22 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
return result

def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
timezone_enabled = self.time_zone is not None
match dialect.name:
case "mssql":
# sa.DateTime wrongly maps to DATETIME
return sa_mssql.DATETIME2(6)
return sa_mssql.DATETIME2(6, timezone=timezone_enabled)
case _:
return sa.DateTime()
return sa.DateTime(timezone=timezone_enabled)

@property
def pyarrow_dtype(self) -> pa.DataType:
return pa.timestamp("us")
time_zone = (
self.time_zone.tzname(None)
if isinstance(self.time_zone, dt.tzinfo)
else self.time_zone
)
return pa.timestamp("us", time_zone)

def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
return generator.sample_datetime(
Expand All @@ -405,6 +416,7 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
allow_null_response=True,
),
resolution=self.resolution,
time_zone=self.time_zone,
null_probability=self._null_probability,
)

Expand Down
6 changes: 5 additions & 1 deletion dataframely/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def sample_datetime(
min: dt.datetime,
max: dt.datetime | None,
resolution: str | None = None,
time_zone: str | dt.tzinfo | None = None,
null_probability: float = 0.0,
) -> pl.Series:
"""Sample a list of datetimes in the provided range.
Expand All @@ -303,6 +304,9 @@ def sample_datetime(
max: The maximum datetime to sample (exclusive). '10000-01-01' when ``None``.
resolution: The resolution that datetimes in the column must have. This uses
the formatting language used by :mod:`polars` datetime ``round`` method.
time_zone: The time zone that datetimes in the column must have. The time
zone must use a valid IANA time zone name identifier e.x. ``Etc/UTC`` or
``America/New_York``.
null_probability: The probability of an element being ``null``.

Returns:
Expand All @@ -329,7 +333,7 @@ def sample_datetime(
)
# NOTE: polars tracks datetimes relative to epoch
- _datetime_to_microseconds(EPOCH_DATETIME)
).cast(pl.Datetime)
).cast(pl.Datetime(time_zone=time_zone))

if resolution is not None:
return result.dt.truncate(resolution)
Expand Down
37 changes: 35 additions & 2 deletions tests/column_types/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import datetime as dt
import re
from typing import Any

import polars as pl
Expand All @@ -10,6 +11,7 @@

import dataframely as dy
from dataframely.columns import Column
from dataframely.exc import DtypeValidationError
from dataframely.random import Generator
from dataframely.testing import evaluate_rules, rules_from_exprs
from dataframely.testing.factory import create_schema
Expand Down Expand Up @@ -392,11 +394,42 @@ def test_validate_resolution(
[
dy.Datetime(
min=dt.datetime(2020, 1, 1), max=dt.datetime(2021, 1, 1), resolution="1h"
)
),
dy.Datetime(time_zone="Etc/UTC"),
],
)
def test_sample_resolution(column: dy.Column) -> None:
def test_sample(column: dy.Column) -> None:
generator = Generator(seed=42)
samples = column.sample(generator, n=10_000)
schema = create_schema("test", {"a": column})
schema.validate(samples.to_frame("a"))


@pytest.mark.parametrize(
("dtype", "column", "error"),
[
(
pl.Datetime(time_zone="America/New_York"),
dy.Datetime(time_zone="Etc/UTC"),
r"1 columns have an invalid dtype.*\n.*got dtype 'Datetime\(time_unit='us', time_zone='America/New_York'\)' but expected 'Datetime\(time_unit='us', time_zone='Etc/UTC'\)'",
),
(
pl.Datetime(time_zone="Etc/UTC"),
dy.Datetime(time_zone="Etc/UTC"),
None,
),
],
)
def test_dtype_time_zone_validation(
dtype: pl.DataType,
column: dy.Column,
error: str | None,
) -> None:
df = pl.DataFrame(schema={"a": dtype})
schema = create_schema("test", {"a": column})
if error is None:
schema.validate(df)
else:
with pytest.raises(DtypeValidationError) as exc:
schema.validate(df)
assert re.match(error, str(exc.value))
2 changes: 2 additions & 0 deletions tests/columns/test_sql_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
(dy.Bool(), "BIT"),
(dy.Date(), "DATE"),
(dy.Datetime(), "DATETIME2(6)"),
(dy.Datetime(time_zone="Etc/UTC"), "DATETIME2(6)"),
(dy.Time(), "TIME(6)"),
(dy.Duration(), "DATETIME2(6)"),
(dy.Decimal(), "NUMERIC"),
Expand Down Expand Up @@ -62,6 +63,7 @@ def test_mssql_datatype(column: Column, datatype: str) -> None:
(dy.Bool(), "BOOLEAN"),
(dy.Date(), "DATE"),
(dy.Datetime(), "TIMESTAMP WITHOUT TIME ZONE"),
(dy.Datetime(time_zone="Etc/UTC"), "TIMESTAMP WITH TIME ZONE"),
(dy.Time(), "TIME WITHOUT TIME ZONE"),
(dy.Duration(), "INTERVAL"),
(dy.Decimal(), "NUMERIC"),
Expand Down
Loading