Skip to content

Commit fecdd12

Browse files
committed
Add lift_subqueries scheduler pass
1 parent 624d76c commit fecdd12

File tree

6 files changed

+236
-43
lines changed

6 files changed

+236
-43
lines changed

.pre-commit-config.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
repos:
2+
- repo: https://github.com/pre-commit/pre-commit-hooks
3+
rev: v5.0.0
4+
hooks:
5+
- id: end-of-file-fixer
6+
- id: trailing-whitespace
7+
- id: mixed-line-ending
8+
- id: name-tests-test
9+
args: ["--pytest-test-first"]
10+
- id: no-commit-to-branch
11+
12+
- repo: https://github.com/astral-sh/ruff-pre-commit
13+
rev: v0.11.8
14+
hooks:
15+
- id: ruff
16+
args: ["--fix"]
17+
types_or: [ python, pyi, jupyter ]
18+
- id: ruff-format
19+
types_or: [ python, pyi, jupyter ]

src/finch/autoschedule/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
Subquery,
1515
Table,
1616
)
17-
from .optimize import optimize, propagate_map_queries
17+
from .optimize import (
18+
optimize,
19+
propagate_map_queries,
20+
lift_subqueries,
21+
)
1822
from ..symbolic import PostOrderDFS, PostWalk, PreWalk
1923

2024
__all__ = [
@@ -34,6 +38,7 @@
3438
"Table",
3539
"optimize",
3640
"propagate_map_queries",
41+
"lift_subqueries",
3742
"PostOrderDFS",
3843
"PostWalk",
3944
"PreWalk",

src/finch/autoschedule/optimize.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,57 @@
11
from .compiler import LogicCompiler
2-
from ..finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query
2+
from ..finch_logic import (
3+
Aggregate,
4+
Alias,
5+
LogicNode,
6+
MapJoin,
7+
Plan,
8+
Produces,
9+
Query,
10+
Subquery,
11+
)
312
from ..symbolic import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite
413

514

615
def optimize(prgm: LogicNode) -> LogicNode:
716
# ...
8-
return propagate_map_queries(prgm)
17+
prgm = lift_subqueries(prgm)
18+
prgm = propagate_map_queries(prgm)
19+
return prgm
920

1021

11-
def get_productions(root: LogicNode) -> LogicNode:
22+
def _lift_subqueries_expr(node: LogicNode, bindings: dict) -> LogicNode:
23+
match node:
24+
case Subquery(lhs, arg):
25+
if lhs not in bindings:
26+
arg_2 = _lift_subqueries_expr(arg, bindings)
27+
bindings[lhs] = arg_2
28+
return lhs
29+
case any if any.is_expr():
30+
return any.make_term(
31+
any.head(),
32+
*map(lambda x: _lift_subqueries_expr(x, bindings), any.children()),
33+
)
34+
case _:
35+
return node
36+
37+
38+
def lift_subqueries(node: LogicNode) -> LogicNode:
39+
match node:
40+
case Plan(bodies):
41+
return Plan(tuple(map(lift_subqueries, bodies)))
42+
case Query(lhs, rhs):
43+
bindings = {}
44+
rhs_2 = _lift_subqueries_expr(rhs, bindings)
45+
return Plan(
46+
(*[Query(lhs, rhs) for lhs, rhs in bindings.items()], Query(lhs, rhs_2))
47+
)
48+
case Produces() as p:
49+
return p
50+
case _:
51+
raise Exception(f"Invalid node: {node}")
52+
53+
54+
def _get_productions(root: LogicNode) -> list[LogicNode]:
1255
for node in PostOrderDFS(root):
1356
if isinstance(node, Produces):
1457
return [arg for arg in PostOrderDFS(node) if isinstance(arg, Alias)]
@@ -22,7 +65,7 @@ def rule_agg_to_mapjoin(ex):
2265
return MapJoin(op, (init, arg))
2366

2467
root = Rewrite(PostWalk(rule_agg_to_mapjoin))(root)
25-
rets = get_productions(root)
68+
rets = _get_productions(root)
2669
props = {}
2770
for node in PostOrderDFS(root):
2871
match node:

src/finch/finch_logic/interpreter.py

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,53 @@
1-
from __future__ import annotations
1+
from dataclasses import dataclass
22
from itertools import product
3+
from typing import Iterable, Any
34
import numpy as np
4-
from ..finch_logic import *
5-
from ..symbolic import Term
6-
from ..algebra import *
5+
from ..finch_logic import (
6+
Immediate,
7+
Deferred,
8+
Field,
9+
Alias,
10+
Table,
11+
MapJoin,
12+
Aggregate,
13+
Query,
14+
Plan,
15+
Produces,
16+
Subquery,
17+
Relabel,
18+
Reorder,
19+
)
20+
from ..algebra import return_type, fill_value, element_type, fixpoint_type
21+
722

823
@dataclass(eq=True, frozen=True)
9-
class TableValue():
24+
class TableValue:
1025
tns: Any
1126
idxs: Iterable[Any]
27+
1228
def __post_init__(self):
1329
if isinstance(self.tns, TableValue):
1430
raise ValueError("The tensor (tns) cannot be a TableValue")
1531

16-
from typing import Any, Type
1732

1833
class FinchLogicInterpreter:
1934
def __init__(self, *, make_tensor=np.full):
2035
self.verbose = False
2136
self.bindings = {}
2237
self.make_tensor = make_tensor # Added make_tensor argument
23-
38+
2439
def __call__(self, node):
2540
# Example implementation for evaluating an expression
2641
if self.verbose:
27-
print(f"Evaluating: {expression}")
42+
print(f"Evaluating: {node}")
2843
# Placeholder for actual logic
2944
head = node.head()
3045
if head == Immediate:
3146
return node.val
3247
elif head == Deferred:
33-
raise ValueError("The interpreter cannot evaluate a deferred node, a compiler might generate code for it")
48+
raise ValueError(
49+
"The interpreter cannot evaluate a deferred node, a compiler might generate code for it"
50+
)
3451
elif head == Field:
3552
raise ValueError("Fields cannot be used in expressions")
3653
elif head == Alias:
@@ -59,7 +76,9 @@ def __call__(self, node):
5976
dims[idx] = dim
6077
fill_val = op(*[fill_value(arg.tns) for arg in args])
6178
dtype = return_type(op, *[element_type(arg.tns) for arg in args])
62-
result = self.make_tensor(tuple(dims[idx] for idx in idxs), fill_val, dtype = dtype)
79+
result = self.make_tensor(
80+
tuple(dims[idx] for idx in idxs), fill_val, dtype=dtype
81+
)
6382
for crds in product(*[range(dims[idx]) for idx in idxs]):
6483
idx_crds = {idx: crd for (idx, crd) in zip(idxs, crds)}
6584
vals = [arg.tns[*[idx_crds[idx] for idx in arg.idxs]] for arg in args]
@@ -74,10 +93,16 @@ def __call__(self, node):
7493
init = node.init.val
7594
op = node.op.val
7695
dtype = fixpoint_type(op, init, element_type(arg.tns))
77-
new_shape = [dim for (dim, idx) in zip(arg.tns.shape, arg.idxs) if not idx in node.idxs]
96+
new_shape = [
97+
dim
98+
for (dim, idx) in zip(arg.tns.shape, arg.idxs)
99+
if idx not in node.idxs
100+
]
78101
result = self.make_tensor(new_shape, init, dtype=dtype)
79102
for crds in product(*[range(dim) for dim in arg.tns.shape]):
80-
out_crds = [crd for (crd, idx) in zip(crds, arg.idxs) if not idx in node.idxs]
103+
out_crds = [
104+
crd for (crd, idx) in zip(crds, arg.idxs) if idx not in node.idxs
105+
]
81106
result[*out_crds] = op(result[*out_crds], arg.tns[*crds])
82107
return TableValue(result, [idx for idx in arg.idxs if idx not in node.idxs])
83108
elif head == Relabel:
@@ -92,7 +117,7 @@ def __call__(self, node):
92117
raise ValueError("Trying to drop a dimension that is not 1")
93118
arg_dims = {idx: dim for idx, dim in zip(arg.idxs, arg.tns.shape)}
94119
dims = [arg_dims.get(idx, 1) for idx in node.idxs]
95-
result = self.make_tensor(dims, fill_value(arg.tns), dtype = arg.tns.dtype)
120+
result = self.make_tensor(dims, fill_value(arg.tns), dtype=arg.tns.dtype)
96121
for crds in product(*[range(dim) for dim in dims]):
97122
node_crds = {idx: crd for (idx, crd) in zip(node.idxs, crds)}
98123
in_crds = [node_crds.get(idx, 0) for idx in arg.idxs]
@@ -110,8 +135,8 @@ def __call__(self, node):
110135
elif head == Produces:
111136
return tuple(self(arg).tns for arg in node.args)
112137
elif head == Subquery:
113-
if not node.lhs in self.bindings:
114-
self.bindings[node.lhs] = self(node.rhs)
138+
if node.lhs not in self.bindings:
139+
self.bindings[node.lhs] = self(node.arg)
115140
return self.bindings[node.lhs]
116141
else:
117-
raise ValueError(f"Unknown expression type: {head}")
142+
raise ValueError(f"Unknown expression type: {head}")

0 commit comments

Comments
 (0)