Skip to content

Commit 78da73c

Browse files
committed
Map INTERVAL types to Python types
1 parent 505989a commit 78da73c

File tree

2 files changed

+56
-4
lines changed

2 files changed

+56
-4
lines changed

tests/integration/test_types_integration.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from datetime import date, datetime, time, timedelta, timezone, tzinfo
55
from decimal import Decimal
66

7+
from dateutil.relativedelta import relativedelta
8+
79
try:
810
from zoneinfo import ZoneInfo
911
except ModuleNotFoundError:
@@ -737,13 +739,30 @@ def create_timezone(timezone_str: str) -> tzinfo:
737739
return ZoneInfo(timezone_str)
738740

739741

740-
def test_interval(trino_connection):
742+
def test_interval_year_to_month(trino_connection):
741743
SqlTest(trino_connection) \
742744
.add_field(sql="CAST(null AS INTERVAL YEAR TO MONTH)", python=None) \
745+
.add_field(sql="INTERVAL '10' YEAR", python=relativedelta(years=10)) \
746+
.add_field(sql="INTERVAL '-5' YEAR", python=relativedelta(years=-5)) \
747+
.add_field(sql="INTERVAL '3' MONTH", python=relativedelta(months=3)) \
748+
.add_field(sql="INTERVAL '-18' MONTH", python=relativedelta(years=-1, months=-6)) \
749+
.add_field(sql="INTERVAL '30' MONTH", python=relativedelta(years=2, months=6)) \
750+
.add_field(sql="INTERVAL '124-30' YEAR TO MONTH", python=relativedelta(years=126, months=6)) \
751+
.execute()
752+
753+
754+
def test_interval_day_to_second(trino_connection):
755+
SqlTest(trino_connection) \
743756
.add_field(sql="CAST(null AS INTERVAL DAY TO SECOND)", python=None) \
744-
.add_field(sql="INTERVAL '3' MONTH", python='0-3') \
745-
.add_field(sql="INTERVAL '2' DAY", python='2 00:00:00.000') \
746-
.add_field(sql="INTERVAL '-2' DAY", python='-2 00:00:00.000') \
757+
.add_field(sql="INTERVAL '2' DAY", python=timedelta(days=2)) \
758+
.add_field(sql="INTERVAL '-2' DAY", python=timedelta(days=-2)) \
759+
.add_field(sql="INTERVAL '-2' SECOND", python=timedelta(seconds=-2)) \
760+
.add_field(sql="INTERVAL '1 11:11:11.116555' DAY TO SECOND",
761+
python=timedelta(days=1, seconds=40271, microseconds=116000)) \
762+
.add_field(sql="INTERVAL '-5 23:59:57.000' DAY TO SECOND", python=timedelta(days=-6, seconds=3)) \
763+
.add_field(sql="INTERVAL '12 10:45' DAY TO MINUTE", python=timedelta(days=12, seconds=38700)) \
764+
.add_field(sql="INTERVAL '45:32.123' MINUTE TO SECOND", python=timedelta(seconds=2732, microseconds=123000)) \
765+
.add_field(sql="INTERVAL '32.123' SECOND", python=timedelta(seconds=32, microseconds=123000)) \
747766
.execute()
748767

749768

trino/mapper.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from decimal import Decimal
99
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
1010

11+
from dateutil.relativedelta import relativedelta
12+
1113
if sys.version_info >= (3, 9):
1214
from zoneinfo import ZoneInfo
1315
else:
@@ -172,6 +174,33 @@ def _fraction_to_decimal(fractional_str: str) -> Decimal:
172174
return Decimal(fractional_str or 0) / POWERS_OF_TEN[len(fractional_str)]
173175

174176

177+
class IntervalYearToMonthMapper(ValueMapper[relativedelta]):
178+
def map(self, value: Any) -> Optional[relativedelta]:
179+
if value is None:
180+
return None
181+
is_negative = value[0] == "-"
182+
years, months = (value[1:] if is_negative else value).split('-')
183+
years, months = int(years), int(months)
184+
if is_negative:
185+
years, months = -years, -months
186+
return relativedelta(years=years, months=months)
187+
188+
189+
class IntervalDayToSecondMapper(ValueMapper[timedelta]):
190+
def map(self, value: Any) -> Optional[timedelta]:
191+
if value is None:
192+
return None
193+
is_negative = value[0] == "-"
194+
days, time = (value[1:] if is_negative else value).split(' ')
195+
hours, minutes, seconds_milliseconds = time.split(':')
196+
seconds, milliseconds = seconds_milliseconds.split('.')
197+
days, hours, minutes, seconds, milliseconds = (int(days), int(hours), int(minutes), int(seconds),
198+
int(milliseconds))
199+
if is_negative:
200+
days, hours, minutes, seconds, milliseconds = -days, -hours, -minutes, -seconds, -milliseconds
201+
return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds, milliseconds=milliseconds)
202+
203+
175204
class ArrayValueMapper(ValueMapper[List[Optional[Any]]]):
176205
def __init__(self, mapper: ValueMapper[Any]):
177206
self.mapper = mapper
@@ -276,6 +305,10 @@ def _create_value_mapper(self, column: Dict[str, Any]) -> ValueMapper[Any]:
276305
return TimestampValueMapper(self._get_precision(column))
277306
if col_type == 'timestamp with time zone':
278307
return TimestampWithTimeZoneValueMapper(self._get_precision(column))
308+
if col_type == 'interval year to month':
309+
return IntervalYearToMonthMapper()
310+
if col_type == 'interval day to second':
311+
return IntervalDayToSecondMapper()
279312

280313
# structural types
281314
if col_type == 'array':

0 commit comments

Comments
 (0)