Skip to content

Commit 74a13f2

Browse files
feat(optimizer)!: Annotate type for snowflake DIV0 and DIVNULL functions (#6008)
* feat(optimizer): Annotate type for snowflake DIV0 and DIVNULL functions * fix(optimizer): Fixed logic for div0null, modified tests
1 parent ba7ad34 commit 74a13f2

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

sqlglot/dialects/snowflake.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,20 @@ def _build_if_from_div0(args: t.List) -> exp.If:
148148
return exp.If(this=cond, true=true, false=false)
149149

150150

151+
# https://docs.snowflake.com/en/sql-reference/functions/div0null
152+
def _build_if_from_div0null(args: t.List) -> exp.If:
153+
lhs = exp._wrap(seq_get(args, 0), exp.Binary)
154+
rhs = exp._wrap(seq_get(args, 1), exp.Binary)
155+
156+
# Returns 0 when divisor is 0 OR NULL
157+
cond = exp.EQ(this=rhs, expression=exp.Literal.number(0)).or_(
158+
exp.Is(this=rhs, expression=exp.null())
159+
)
160+
true = exp.Literal.number(0)
161+
false = exp.Div(this=lhs, expression=rhs)
162+
return exp.If(this=cond, true=true, false=false)
163+
164+
151165
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
152166
def _build_if_from_zeroifnull(args: t.List) -> exp.If:
153167
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
@@ -746,6 +760,7 @@ class Parser(parser.Parser):
746760
"DATEDIFF": _build_datediff,
747761
"DAYOFWEEKISO": exp.DayOfWeekIso.from_arg_list,
748762
"DIV0": _build_if_from_div0,
763+
"DIV0NULL": _build_if_from_div0null,
749764
"EDITDISTANCE": lambda args: exp.Levenshtein(
750765
this=seq_get(args, 0), expression=seq_get(args, 1), max_dist=seq_get(args, 2)
751766
),

tests/dialects/test_snowflake.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,28 @@ def test_snowflake(self):
841841
"duckdb": "CASE WHEN (c - d) = 0 AND NOT (a - b) IS NULL THEN 0 ELSE (a - b) / (c - d) END",
842842
},
843843
)
844+
self.validate_all(
845+
"DIV0NULL(foo, bar)",
846+
write={
847+
"snowflake": "IFF(bar = 0 OR bar IS NULL, 0, foo / bar)",
848+
"sqlite": "IIF(bar = 0 OR bar IS NULL, 0, CAST(foo AS REAL) / bar)",
849+
"presto": "IF(bar = 0 OR bar IS NULL, 0, CAST(foo AS DOUBLE) / bar)",
850+
"spark": "IF(bar = 0 OR bar IS NULL, 0, foo / bar)",
851+
"hive": "IF(bar = 0 OR bar IS NULL, 0, foo / bar)",
852+
"duckdb": "CASE WHEN bar = 0 OR bar IS NULL THEN 0 ELSE foo / bar END",
853+
},
854+
)
855+
self.validate_all(
856+
"DIV0NULL(a - b, c - d)",
857+
write={
858+
"snowflake": "IFF((c - d) = 0 OR (c - d) IS NULL, 0, (a - b) / (c - d))",
859+
"sqlite": "IIF((c - d) = 0 OR (c - d) IS NULL, 0, CAST((a - b) AS REAL) / (c - d))",
860+
"presto": "IF((c - d) = 0 OR (c - d) IS NULL, 0, CAST((a - b) AS DOUBLE) / (c - d))",
861+
"spark": "IF((c - d) = 0 OR (c - d) IS NULL, 0, (a - b) / (c - d))",
862+
"hive": "IF((c - d) = 0 OR (c - d) IS NULL, 0, (a - b) / (c - d))",
863+
"duckdb": "CASE WHEN (c - d) = 0 OR (c - d) IS NULL THEN 0 ELSE (a - b) / (c - d) END",
864+
},
865+
)
844866
self.validate_all(
845867
"ZEROIFNULL(foo)",
846868
write={

tests/fixtures/optimizer/annotate_functions.sql

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1639,6 +1639,22 @@ BINARY;
16391639
DECOMPRESS_STRING('compressed_data', 'ZSTD');
16401640
VARCHAR;
16411641

1642+
# dialect: snowflake
1643+
DIV0(10, 0);
1644+
DOUBLE;
1645+
1646+
# dialect: snowflake
1647+
DIV0(tbl.double_col, tbl.double_col);
1648+
DOUBLE;
1649+
1650+
# dialect: snowflake
1651+
DIV0NULL(10, 0);
1652+
DOUBLE;
1653+
1654+
# dialect: snowflake
1655+
DIV0NULL(tbl.double_col, tbl.double_col);
1656+
DOUBLE;
1657+
16421658
# dialect: snowflake
16431659
LPAD('Hello', 10, '*');
16441660
VARCHAR;

0 commit comments

Comments
 (0)