Skip to content

Commit 624d76c

Browse files
authored
Merge pull request #6 from mtsokol/test-propagate-map-queries
TEST: Test `propagate_map_queries` pass
2 parents 931597c + dd003bd commit 624d76c

File tree

9 files changed

+43
-11
lines changed

9 files changed

+43
-11
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ readme = "README.md"
77
packages = [{include = "finch", from = "src"}]
88

99
[tool.poetry.dependencies]
10-
python = "^3.10"
10+
python = "^3.11"
1111
numpy = ">=1.19"
1212

1313
[tool.poetry.group.test.dependencies]

src/finch/autoschedule/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .finch_logic import (
1+
from ..finch_logic import (
22
Aggregate,
33
Alias,
44
Deferred,
@@ -15,7 +15,7 @@
1515
Table,
1616
)
1717
from .optimize import optimize, propagate_map_queries
18-
from .rewrite_tools import PostOrderDFS, PostWalk, PreWalk
18+
from ..symbolic import PostOrderDFS, PostWalk, PreWalk
1919

2020
__all__ = [
2121
"Aggregate",

src/finch/autoschedule/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from textwrap import dedent
33
from typing import Any
44

5-
from .finch_logic import (
5+
from ..finch_logic import (
66
Alias,
77
Deferred,
88
Field,

src/finch/autoschedule/executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .compiler import LogicCompiler
2-
from .rewrite_tools import gensym
2+
from ..symbolic import gensym
33

44

55
class LogicExecutor:

src/finch/autoschedule/optimize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .compiler import LogicCompiler
2-
from .finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query
3-
from .rewrite_tools import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite
2+
from ..finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query
3+
from ..symbolic import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite
44

55

66
def optimize(prgm: LogicNode) -> LogicNode:
@@ -52,4 +52,4 @@ def __init__(self, ctx: LogicCompiler):
5252

5353
def __call__(self, prgm: LogicNode):
5454
prgm = optimize(prgm)
55-
return self.ctx(prgm)
55+
return self.ctx(prgm)

src/finch/finch_logic/nodes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ def children(self):
191191
"""Returns the children of the node."""
192192
return [self.op, *self.args]
193193

194+
@classmethod
195+
def make_term(cls, head, op, *args):
196+
return head(op, args)
197+
194198

195199
@dataclass(eq=True, frozen=True)
196200
class Aggregate(LogicNode):
@@ -412,3 +416,7 @@ def is_stateful():
412416
def children(self):
413417
"""Returns the children of the node."""
414418
return [*self.bodies]
419+
420+
@classmethod
421+
def make_term(cls, head, *val):
422+
return head(val)

src/finch/symbolic/rewriters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __call__(self, x: Term) -> Term | None:
8787
new_args = list(map(self, args))
8888
if all(arg is None for arg in new_args):
8989
return self.rw(x)
90-
y = x.make_term(*map(lambda x1, x2: default_rewrite(x1, x2), new_args, args))
90+
y = x.make_term(x.head(), *map(lambda x1, x2: default_rewrite(x1, x2), new_args, args))
9191
return default_rewrite(self.rw(y), y)
9292
return self.rw(x)
9393

src/finch/symbolic/term.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ def is_expr(self) -> bool:
3030

3131
@abstractmethod
3232
def make_term(self, head: Any, children: List[Term]) -> Term:
33-
"""Construct a new term in the same family of terms with the given head type and children."""
33+
"""
34+
Construct a new term in the same family of terms with the given head type and children.
35+
This function should satisfy `x == x.make_term(x.head(), *x.children())`
36+
"""
3437
pass
3538

3639
def __hash__(self) -> int:
@@ -52,4 +55,4 @@ def PreOrderDFS(node: Term) -> Iterator[Term]:
5255
yield node
5356
if node.is_expr():
5457
for arg in node.children():
55-
yield from PreOrderDFS(arg)
58+
yield from PreOrderDFS(arg)

tests/test_scheduler.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from finch.autoschedule import propagate_map_queries
2+
from finch.finch_logic import *
3+
4+
5+
def test_propagate_map_queries_simple():
6+
plan = Plan(
7+
(
8+
Query(Alias("A10"), Aggregate(Immediate("+"), Immediate(0), Immediate("[1,2,3]"), ())),
9+
Query(Alias("A11"), Alias("A10")),
10+
Produces((Alias("11"),)),
11+
)
12+
)
13+
expected = Plan(
14+
(
15+
Query(Alias("A11"), MapJoin(Immediate("+"), (Immediate(0), Immediate("[1,2,3]")))),
16+
Produces((Alias("11"),)),
17+
)
18+
)
19+
20+
result = propagate_map_queries(plan)
21+
assert result == expected

0 commit comments

Comments
 (0)