Skip to content

Commit 57fd115

Browse files
committed
Fix: treat all instances of macro variables as case-insensitive
1 parent 16a032f commit 57fd115

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

sqlmesh/core/macros.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,23 @@ def _macro_str_replace(text: str) -> str:
128128
return f"self.template({text}, locals())"
129129

130130

131+
class CaseInsensitiveMapping(dict):
132+
def __init__(self, data: t.Dict[str, t.Any]) -> None:
133+
super().__init__(data)
134+
135+
self._lower = {k.lower(): v for k, v in data.items()}
136+
137+
def __getitem__(self, key: str) -> t.Any:
138+
if key in self:
139+
return super().__getitem__(key)
140+
return self._lower[key.lower()]
141+
142+
def get(self, key: str, default: t.Any = None) -> t.Any:
143+
if key in self:
144+
return super().get(key, default)
145+
return self._lower.get(key.lower(), default)
146+
147+
131148
class MacroDialect(Python):
132149
class Generator(Python.Generator):
133150
TRANSFORMS = {
@@ -313,11 +330,11 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str:
313330
"""
314331
# We try to convert all variables into sqlglot expressions because they're going to be converted
315332
# into strings; in sql we don't convert strings because that would result in adding quotes
316-
mapping = {
333+
base_mapping = {
317334
k: convert_sql(v, self.dialect)
318335
for k, v in chain(self.variables.items(), self.locals.items(), local_variables.items())
319336
}
320-
return MacroStrTemplate(str(text)).safe_substitute(mapping)
337+
return MacroStrTemplate(str(text)).safe_substitute(CaseInsensitiveMapping(base_mapping))
321338

322339
def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None:
323340
if isinstance(node, MacroDef):

tests/core/test_macros.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,7 @@ def test_macro_with_spaces():
11131113
for sql, expected in (
11141114
("@x", '"a b"'),
11151115
("@{x}", '"a b"'),
1116+
("@{X}", '"a b"'),
11161117
("a_@x", '"a_a b"'),
11171118
("a.@x", 'a."a b"'),
11181119
("@y", "'a b'"),
@@ -1121,6 +1122,7 @@ def test_macro_with_spaces():
11211122
("a.@{y}", 'a."a b"'),
11221123
("@z", 'a."b c"'),
11231124
("d.@z", 'd.a."b c"'),
1125+
("@'test_@{X}_suffix'", "'test_a b_suffix'"),
11241126
):
11251127
assert evaluator.transform(parse_one(sql)).sql() == expected
11261128

0 commit comments

Comments
 (0)