Skip to content

Commit f585166

Browse files
committed
Merge remote-tracking branch 'upstream/main' into lint-fixes
2 parents 6a225f5 + fb79079 commit f585166

File tree

7 files changed

+162
-25
lines changed

7 files changed

+162
-25
lines changed

src/finch/autoschedule/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .optimize import (
1919
lift_subqueries,
2020
optimize,
21+
propagate_fields,
2122
propagate_map_queries,
2223
)
2324

@@ -37,6 +38,7 @@
3738
"Subquery",
3839
"Table",
3940
"optimize",
41+
"propagate_fields",
4042
"propagate_map_queries",
4143
"lift_subqueries",
4244
"PostOrderDFS",

src/finch/autoschedule/compiler.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from textwrap import dedent
2-
from typing import Any, TypeVar
2+
from typing import TypeVar
33

44
from ..finch_logic import (
55
Alias,
@@ -14,15 +14,17 @@
1414
Reorder,
1515
Subquery,
1616
Table,
17+
WithFields,
1718
)
1819
from ..symbolic import Term
1920

2021
T = TypeVar("T", bound="LogicNode")
2122

2223

23-
def get_or_insert(dictionary: dict[str, LogicNode], key: str, default: Any) -> Any:
24-
if key in dictionary:
25-
return dictionary[key]
24+
def get_or_insert(dictionary: dict[str, T], key: str, default: T) -> T:
25+
val = dictionary.get(key)
26+
if val is not None:
27+
return val
2628
dictionary[key] = default
2729
return default
2830

@@ -36,11 +38,14 @@ def get_structure(
3638
case Alias(name):
3739
return get_or_insert(aliases, name, Immediate(len(fields) + len(aliases)))
3840
case Subquery(Alias(name) as lhs, arg):
39-
if name in aliases:
40-
return aliases[name]
41-
return Subquery(
42-
get_structure(lhs, fields, aliases), get_structure(arg, fields, aliases)
43-
)
41+
alias = aliases.get(name)
42+
if alias is not None:
43+
return alias
44+
in_arg = get_structure(arg, fields, aliases)
45+
in_lhs = get_structure(lhs, fields, aliases)
46+
assert isinstance(in_arg, WithFields)
47+
assert isinstance(in_lhs, WithFields)
48+
return Subquery(in_lhs, in_arg)
4449
case Table(tns, idxs):
4550
assert all(isinstance(idx, Field) for idx in idxs)
4651
return Table(

src/finch/autoschedule/optimize.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
from collections.abc import Iterable
2+
13
from ..finch_logic import (
24
Aggregate,
35
Alias,
6+
Field,
47
LogicNode,
58
MapJoin,
69
Plan,
710
Produces,
811
Query,
12+
Relabel,
913
Subquery,
14+
WithFields,
1015
)
1116
from ..symbolic import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite, Term
1217
from .compiler import LogicCompiler
@@ -89,6 +94,31 @@ def rule_2(ex):
8994
return Rewrite(PostWalk(rule_2))(root)
9095

9196

97+
def _propagate_fields(
98+
root: LogicNode, fields: dict[LogicNode, Iterable[Field]]
99+
) -> LogicNode:
100+
match root:
101+
case Plan(bodies):
102+
return Plan(tuple(_propagate_fields(b, fields) for b in bodies))
103+
case Query(lhs, rhs):
104+
rhs = _propagate_fields(rhs, fields)
105+
assert isinstance(rhs, WithFields)
106+
fields[lhs] = rhs.get_fields()
107+
return Query(lhs, rhs)
108+
case Alias() as a:
109+
return Relabel(a, tuple(fields[a]))
110+
case node if node.is_expr():
111+
return node.make_term(
112+
node.head(), *[_propagate_fields(c, fields) for c in node.children()]
113+
)
114+
case node:
115+
return node
116+
117+
118+
def propagate_fields(root: LogicNode) -> LogicNode:
119+
return _propagate_fields(root, fields={})
120+
121+
92122
class DefaultLogicOptimizer:
93123
def __init__(self, ctx: LogicCompiler):
94124
self.ctx = ctx

src/finch/finch_logic/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Reorder,
1515
Subquery,
1616
Table,
17+
WithFields,
1718
)
1819

1920
__all__ = [
@@ -34,4 +35,5 @@
3435
"Reformat",
3536
"Subquery",
3637
"Table",
38+
"WithFields",
3739
]

src/finch/finch_logic/nodes.py

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,14 @@ def make_term(cls, *args: Term) -> Self:
6868

6969

7070
@dataclass(eq=True, frozen=True)
71-
class Immediate(LogicNode):
71+
class WithFields(LogicNode):
72+
@abstractmethod
73+
def get_fields(self) -> tuple[Field, ...]:
74+
"""Get this node's fields."""
75+
76+
77+
@dataclass(eq=True, frozen=True)
78+
class Immediate(WithFields):
7279
"""
7380
Represents a logical AST expression for the literal value `val`.
7481
@@ -97,6 +104,10 @@ def fill_value(self):
97104
def children(self):
98105
raise TypeError(f"`{type(self).__name__}` doesn't support `.children()`.")
99106

107+
def get_fields(self) -> tuple[Field, ...]:
108+
"""Returns fields of the node."""
109+
return ()
110+
100111

101112
@dataclass(eq=True, frozen=True)
102113
class Deferred(LogicNode):
@@ -183,7 +194,7 @@ def children(self):
183194

184195

185196
@dataclass(eq=True, frozen=True)
186-
class Table(LogicNode):
197+
class Table(WithFields):
187198
"""
188199
Represents a logical AST expression for a tensor object `tns`, indexed by fields
189200
`idxs...`. A table is a tensor with named dimensions.
@@ -210,9 +221,17 @@ def children(self):
210221
"""Returns the children of the node."""
211222
return [self.tns, *self.idxs]
212223

224+
def get_fields(self) -> tuple[Field, ...]:
225+
"""Returns fields of the node."""
226+
return self.idxs
227+
228+
@classmethod
229+
def make_term(cls, head, tns, *idxs):
230+
return head(tns, idxs)
231+
213232

214233
@dataclass(eq=True, frozen=True)
215-
class MapJoin(LogicNode):
234+
class MapJoin(WithFields):
216235
"""
217236
Represents a logical AST expression for mapping the function `op` across `args...`.
218237
Dimensions which are not present are broadcasted. Dimensions which are
@@ -225,7 +244,7 @@ class MapJoin(LogicNode):
225244
"""
226245

227246
op: Immediate
228-
args: tuple[LogicNode, ...]
247+
args: tuple[WithFields, ...]
229248

230249
@staticmethod
231250
def is_expr():
@@ -241,8 +260,19 @@ def children(self):
241260
"""Returns the children of the node."""
242261
return [self.op, *self.args]
243262

263+
def get_fields(self) -> tuple[Field, ...]:
264+
"""Returns fields of the node."""
265+
# (mtsokol) I'm not sure if this comment still applies - the order is preserved.
266+
# TODO: this is wrong here: the overall order should at least be concordant with
267+
# the args if the args are concordant
268+
fs: list[Field] = []
269+
for arg in self.args:
270+
fs.extend(arg.get_fields())
271+
272+
return tuple(fs)
273+
244274
@classmethod
245-
def make_term(cls, op: Immediate, *args: LogicNode) -> Self: # type: ignore[override]
275+
def make_term(cls, op: Immediate, *args: WithFields) -> Self: # type: ignore[override]
246276
return cls(op, tuple(args))
247277

248278

@@ -261,7 +291,7 @@ class Aggregate(LogicNode):
261291

262292
op: Immediate
263293
init: Immediate
264-
arg: LogicNode
294+
arg: WithFields
265295
idxs: tuple[Field, ...]
266296

267297
@staticmethod
@@ -278,9 +308,17 @@ def children(self):
278308
"""Returns the children of the node."""
279309
return [self.op, self.init, self.arg, *self.idxs]
280310

311+
def get_fields(self) -> tuple[Field, ...]:
312+
"""Returns fields of the node."""
313+
return tuple(field for field in self.arg.get_fields() if field not in self.idxs)
314+
315+
@classmethod
316+
def make_term(cls, head, op, init, arg, *idxs):
317+
return head(op, init, arg, idxs)
318+
281319

282320
@dataclass(eq=True, frozen=True)
283-
class Reorder(LogicNode):
321+
class Reorder(WithFields):
284322
"""
285323
Represents a logical AST statement that reorders the dimensions of `arg` to be
286324
`idxs...`. Dimensions known to be length 1 may be dropped. Dimensions that do not
@@ -308,9 +346,17 @@ def children(self):
308346
"""Returns the children of the node."""
309347
return [self.arg, *self.idxs]
310348

349+
def get_fields(self) -> tuple[Field, ...]:
350+
"""Returns fields of the node."""
351+
return self.idxs
352+
353+
@classmethod
354+
def make_term(cls, head, arg, *idxs):
355+
return head(arg, idxs)
356+
311357

312358
@dataclass(eq=True, frozen=True)
313-
class Relabel(LogicNode):
359+
class Relabel(WithFields):
314360
"""
315361
Represents a logical AST statement that relabels the dimensions of `arg` to be
316362
`idxs...`.
@@ -337,9 +383,13 @@ def children(self):
337383
"""Returns the children of the node."""
338384
return [self.arg, *self.idxs]
339385

386+
def get_fields(self) -> tuple[Field, ...]:
387+
"""Returns fields of the node."""
388+
return self.idxs
389+
340390

341391
@dataclass(eq=True, frozen=True)
342-
class Reformat(LogicNode):
392+
class Reformat(WithFields):
343393
"""
344394
Represents a logical AST statement that reformats `arg` into the tensor `tns`.
345395
@@ -349,7 +399,7 @@ class Reformat(LogicNode):
349399
"""
350400

351401
tns: Immediate
352-
arg: LogicNode
402+
arg: WithFields
353403

354404
@staticmethod
355405
def is_expr():
@@ -365,20 +415,24 @@ def children(self):
365415
"""Returns the children of the node."""
366416
return [self.tns, self.arg]
367417

418+
def get_fields(self) -> tuple[Field, ...]:
419+
"""Returns fields of the node."""
420+
return self.arg.get_fields()
421+
368422

369423
@dataclass(eq=True, frozen=True)
370-
class Subquery(LogicNode):
424+
class Subquery(WithFields):
371425
"""
372426
Represents a logical AST statement that evaluates `rhs`, binding the result to
373427
`lhs`, and returns `rhs`.
374428
375429
Attributes:
376430
lhs: The left-hand side of the binding.
377-
rhs: The argument to evaluate.
431+
arg: The argument to evaluate.
378432
"""
379433

380434
lhs: LogicNode
381-
arg: LogicNode
435+
arg: WithFields
382436

383437
@staticmethod
384438
def is_expr():
@@ -394,6 +448,10 @@ def children(self):
394448
"""Returns the children of the node."""
395449
return [self.lhs, self.arg]
396450

451+
def get_fields(self) -> tuple[Field, ...]:
452+
"""Returns fields of the node."""
453+
return self.arg.get_fields()
454+
397455

398456
@dataclass(eq=True, frozen=True)
399457
class Query(LogicNode):

src/finch/symbolic/term.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ def is_expr(self) -> bool:
4040
`children()` if `True`."""
4141

4242
@abstractmethod
43-
def make_term(self, head: Any, *children: Term) -> Self:
43+
def make_term(self, *children: Term) -> Self:
4444
"""
45-
Construct a new term in the same family of terms with the given head type and
45+
Construct a new term in the same family of terms with the given
4646
children. This function should satisfy
4747
`x == x.make_term(*x.children())`
4848
"""

tests/test_scheduler.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
from finch.autoschedule import lift_subqueries, propagate_map_queries
1+
from finch.autoschedule import lift_subqueries, propagate_fields, propagate_map_queries
22
from finch.finch_logic import (
33
Aggregate,
44
Alias,
5+
Field,
56
Immediate,
67
MapJoin,
78
Plan,
89
Produces,
910
Query,
11+
Relabel,
1012
Subquery,
13+
Table,
1114
)
1215

1316

@@ -91,3 +94,40 @@ def test_lift_subqueries():
9194

9295
result = lift_subqueries(plan)
9396
assert result == expected
97+
98+
99+
def test_propagate_fields():
100+
plan = Plan(
101+
(
102+
Query(
103+
Alias("A10"),
104+
MapJoin(
105+
Immediate("op"),
106+
(
107+
Table(Immediate("tbl1"), (Field("A1"), Field("A2"))),
108+
Table(Immediate("tbl2"), (Field("A2"), Field("A3"))),
109+
),
110+
),
111+
),
112+
Alias("A10"),
113+
)
114+
)
115+
116+
expected = Plan(
117+
(
118+
Query(
119+
Alias("A10"),
120+
MapJoin(
121+
Immediate("op"),
122+
(
123+
Table(Immediate("tbl1"), (Field("A1"), Field("A2"))),
124+
Table(Immediate("tbl2"), (Field("A2"), Field("A3"))),
125+
),
126+
),
127+
),
128+
Relabel(Alias("A10"), (Field("A1"), Field("A2"), Field("A3"))),
129+
)
130+
)
131+
132+
result = propagate_fields(plan)
133+
assert result == expected

0 commit comments

Comments
 (0)