Skip to content

Commit f82bbff

Browse files
committed
fix: handle ParseError gracefully in MERGE assignment validation and improve type annotations in parsing utilities
1 parent bf869d8 commit f82bbff

File tree

3 files changed

+38
-30
lines changed

3 files changed

+38
-30
lines changed

sqlspec/builder/_merge.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
parameter binding and validation.
55
"""
66

7+
import contextlib
78
from collections.abc import Mapping, Sequence
89
from datetime import datetime
910
from decimal import Decimal
@@ -12,6 +13,7 @@
1213

1314
from mypy_extensions import trait
1415
from sqlglot import exp
16+
from sqlglot.errors import ParseError
1517
from typing_extensions import Self
1618

1719
from sqlspec.builder._base import QueryBuilder
@@ -45,32 +47,34 @@ def _is_column_reference(self, value: str) -> bool:
4547
if not isinstance(value, str):
4648
return False
4749

48-
parsed: exp.Expression | None = exp.maybe_parse(value.strip())
49-
if parsed is None:
50-
return False
51-
52-
if isinstance(parsed, exp.Column):
53-
return parsed.table is not None and bool(parsed.table)
54-
55-
return isinstance(
56-
parsed,
57-
(
58-
exp.Dot,
59-
exp.Add,
60-
exp.Sub,
61-
exp.Mul,
62-
exp.Div,
63-
exp.Mod,
64-
exp.Func,
65-
exp.Anonymous,
66-
exp.Null,
67-
exp.CurrentTimestamp,
68-
exp.CurrentDate,
69-
exp.CurrentTime,
70-
exp.Paren,
71-
exp.Case,
72-
),
73-
)
50+
with contextlib.suppress(ParseError):
51+
parsed: exp.Expression | None = exp.maybe_parse(value.strip())
52+
if parsed is None:
53+
return False
54+
55+
if isinstance(parsed, exp.Column):
56+
return parsed.table is not None and bool(parsed.table)
57+
58+
return isinstance(
59+
parsed,
60+
(
61+
exp.Dot,
62+
exp.Add,
63+
exp.Sub,
64+
exp.Mul,
65+
exp.Div,
66+
exp.Mod,
67+
exp.Func,
68+
exp.Anonymous,
69+
exp.Null,
70+
exp.CurrentTimestamp,
71+
exp.CurrentDate,
72+
exp.CurrentTime,
73+
exp.Paren,
74+
exp.Case,
75+
),
76+
)
77+
return False
7478

7579
def _process_assignment(self, target_column: str, value: Any) -> exp.Expression:
7680
column_identifier = exp.column(target_column) if isinstance(target_column, str) else target_column
@@ -571,7 +575,9 @@ def when_not_matched_then_insert(
571575
if values is None:
572576
using_alias = None
573577
using_expr = current_expr.args.get("using")
574-
if using_expr is not None and (isinstance(using_expr, (exp.Subquery, exp.Table)) or hasattr(using_expr, "alias")):
578+
if using_expr is not None and (
579+
isinstance(using_expr, (exp.Subquery, exp.Table)) or hasattr(using_expr, "alias")
580+
):
575581
using_alias = using_expr.alias
576582
column_values = [f"{using_alias}.{col}" for col in column_names] if using_alias else column_names
577583
else:

sqlspec/builder/_parsing_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def extract_column_name(column: str | exp.Column) -> str:
2828
Column name as string for use as parameter name
2929
"""
3030
if isinstance(column, str):
31-
col_expr = exp.maybe_parse(column)
31+
col_expr: exp.Expression | None = exp.maybe_parse(column)
3232
if isinstance(col_expr, exp.Column):
3333
return col_expr.name
3434
return column.split(".")[-1] if "." in column else column
@@ -93,7 +93,7 @@ def parse_column_expression(column_input: str | exp.Expression | Any, builder: A
9393
def parse_table_expression(table_input: str, explicit_alias: str | None = None) -> exp.Expression:
9494
"""Parses a table string that can be a name, a name with an alias, or a subquery string."""
9595
with contextlib.suppress(Exception):
96-
parsed = exp.maybe_parse(f"SELECT * FROM {table_input}")
96+
parsed: exp.Expression | None = exp.maybe_parse(f"SELECT * FROM {table_input}")
9797
if isinstance(parsed, exp.Select) and parsed.args.get("from"):
9898
from_clause = cast("exp.From", parsed.args.get("from"))
9999
table_expr = from_clause.this

tests/unit/test_sql_factory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1520,7 +1520,9 @@ def test_merge_complete_example() -> None:
15201520
assert "WHEN NOT MATCHED THEN INSERT" in stmt.sql
15211521
assert "WHEN NOT MATCHED BY SOURCE THEN UPDATE" in stmt.sql
15221522
assert "NOW()" in stmt.sql
1523-
assert len(stmt.parameters) >= 6
1523+
assert "status" in stmt.parameters
1524+
assert stmt.parameters["status"] == "archived"
1525+
assert len(stmt.parameters) == 1
15241526

15251527

15261528
def test_querybuilder_parameter_style_handling_regression() -> None:

0 commit comments

Comments
 (0)