Skip to content

Commit ea591aa

Browse files
dfrankland and borchero authored
feat: Add time_zone to Datetime column (#33)
Co-authored-by: Oliver Borchert <[email protected]> Co-authored-by: Oliver Borchert <[email protected]>
1 parent 08972c6 commit ea591aa

File tree

4 files changed

+58
-7
lines changed

4 files changed

+58
-7
lines changed

dataframely/columns/datetime.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ def __init__(
299299
max: dt.datetime | None = None,
300300
max_exclusive: dt.datetime | None = None,
301301
resolution: str | None = None,
302+
time_zone: str | dt.tzinfo | None = None,
302303
check: (
303304
Callable[[pl.Expr], pl.Expr]
304305
| list[Callable[[pl.Expr], pl.Expr]]
@@ -326,6 +327,9 @@ def __init__(
326327
the formatting language used by :mod:`polars` datetime ``round`` method.
327328
For example, a value ``1h`` expects all datetimes to be full hours. Note
328329
that this setting does *not* affect the storage resolution.
330+
time_zone: The time zone that datetimes in the column must have. The time
331+
zone must use a valid IANA time zone name identifier, e.g. ``Etc/UTC`` or
332+
``America/New_York``.
329333
check: A custom rule or multiple rules to run for this column. This can be:
330334
- A single callable that returns a non-aggregated boolean expression.
331335
The name of the rule is derived from the callable name, or defaults to
@@ -368,10 +372,11 @@ def __init__(
368372
metadata=metadata,
369373
)
370374
self.resolution = resolution
375+
self.time_zone = time_zone
371376

372377
@property
373378
def dtype(self) -> pl.DataType:
374-
return pl.Datetime()
379+
return pl.Datetime(time_zone=self.time_zone)
375380

376381
def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
377382
result = super().validation_rules(expr)
@@ -380,16 +385,22 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
380385
return result
381386

382387
def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
388+
timezone_enabled = self.time_zone is not None
383389
match dialect.name:
384390
case "mssql":
385391
# sa.DateTime wrongly maps to DATETIME
386-
return sa_mssql.DATETIME2(6)
392+
return sa_mssql.DATETIME2(6, timezone=timezone_enabled)
387393
case _:
388-
return sa.DateTime()
394+
return sa.DateTime(timezone=timezone_enabled)
389395

390396
@property
391397
def pyarrow_dtype(self) -> pa.DataType:
392-
return pa.timestamp("us")
398+
time_zone = (
399+
self.time_zone.tzname(None)
400+
if isinstance(self.time_zone, dt.tzinfo)
401+
else self.time_zone
402+
)
403+
return pa.timestamp("us", time_zone)
393404

394405
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
395406
return generator.sample_datetime(
@@ -405,6 +416,7 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
405416
allow_null_response=True,
406417
),
407418
resolution=self.resolution,
419+
time_zone=self.time_zone,
408420
null_probability=self._null_probability,
409421
)
410422

dataframely/random.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ def sample_datetime(
293293
min: dt.datetime,
294294
max: dt.datetime | None,
295295
resolution: str | None = None,
296+
time_zone: str | dt.tzinfo | None = None,
296297
null_probability: float = 0.0,
297298
) -> pl.Series:
298299
"""Sample a list of datetimes in the provided range.
@@ -303,6 +304,9 @@ def sample_datetime(
303304
max: The maximum datetime to sample (exclusive). '10000-01-01' when ``None``.
304305
resolution: The resolution that datetimes in the column must have. This uses
305306
the formatting language used by :mod:`polars` datetime ``round`` method.
307+
time_zone: The time zone that datetimes in the column must have. The time
308+
zone must use a valid IANA time zone name identifier, e.g. ``Etc/UTC`` or
309+
``America/New_York``.
306310
null_probability: The probability of an element being ``null``.
307311
308312
Returns:
@@ -329,7 +333,7 @@ def sample_datetime(
329333
)
330334
# NOTE: polars tracks datetimes relative to epoch
331335
- _datetime_to_microseconds(EPOCH_DATETIME)
332-
).cast(pl.Datetime)
336+
).cast(pl.Datetime(time_zone=time_zone))
333337

334338
if resolution is not None:
335339
return result.dt.truncate(resolution)

tests/column_types/test_datetime.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
import datetime as dt
5+
import re
56
from typing import Any
67

78
import polars as pl
@@ -10,6 +11,7 @@
1011

1112
import dataframely as dy
1213
from dataframely.columns import Column
14+
from dataframely.exc import DtypeValidationError
1315
from dataframely.random import Generator
1416
from dataframely.testing import evaluate_rules, rules_from_exprs
1517
from dataframely.testing.factory import create_schema
@@ -392,11 +394,42 @@ def test_validate_resolution(
392394
[
393395
dy.Datetime(
394396
min=dt.datetime(2020, 1, 1), max=dt.datetime(2021, 1, 1), resolution="1h"
395-
)
397+
),
398+
dy.Datetime(time_zone="Etc/UTC"),
396399
],
397400
)
398-
def test_sample_resolution(column: dy.Column) -> None:
401+
def test_sample(column: dy.Column) -> None:
399402
generator = Generator(seed=42)
400403
samples = column.sample(generator, n=10_000)
401404
schema = create_schema("test", {"a": column})
402405
schema.validate(samples.to_frame("a"))
406+
407+
408+
@pytest.mark.parametrize(
409+
("dtype", "column", "error"),
410+
[
411+
(
412+
pl.Datetime(time_zone="America/New_York"),
413+
dy.Datetime(time_zone="Etc/UTC"),
414+
r"1 columns have an invalid dtype.*\n.*got dtype 'Datetime\(time_unit='us', time_zone='America/New_York'\)' but expected 'Datetime\(time_unit='us', time_zone='Etc/UTC'\)'",
415+
),
416+
(
417+
pl.Datetime(time_zone="Etc/UTC"),
418+
dy.Datetime(time_zone="Etc/UTC"),
419+
None,
420+
),
421+
],
422+
)
423+
def test_dtype_time_zone_validation(
424+
dtype: pl.DataType,
425+
column: dy.Column,
426+
error: str | None,
427+
) -> None:
428+
df = pl.DataFrame(schema={"a": dtype})
429+
schema = create_schema("test", {"a": column})
430+
if error is None:
431+
schema.validate(df)
432+
else:
433+
with pytest.raises(DtypeValidationError) as exc:
434+
schema.validate(df)
435+
assert re.match(error, str(exc.value))

tests/columns/test_sql_schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
(dy.Bool(), "BIT"),
1919
(dy.Date(), "DATE"),
2020
(dy.Datetime(), "DATETIME2(6)"),
21+
(dy.Datetime(time_zone="Etc/UTC"), "DATETIME2(6)"),
2122
(dy.Time(), "TIME(6)"),
2223
(dy.Duration(), "DATETIME2(6)"),
2324
(dy.Decimal(), "NUMERIC"),
@@ -62,6 +63,7 @@ def test_mssql_datatype(column: Column, datatype: str) -> None:
6263
(dy.Bool(), "BOOLEAN"),
6364
(dy.Date(), "DATE"),
6465
(dy.Datetime(), "TIMESTAMP WITHOUT TIME ZONE"),
66+
(dy.Datetime(time_zone="Etc/UTC"), "TIMESTAMP WITH TIME ZONE"),
6567
(dy.Time(), "TIME WITHOUT TIME ZONE"),
6668
(dy.Duration(), "INTERVAL"),
6769
(dy.Decimal(), "NUMERIC"),

0 commit comments

Comments
 (0)