Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ readme = "README.md"
packages = [{include = "finch", from = "src"}]

[tool.poetry.dependencies]
python = "^3.10"
python = "^3.11"
numpy = ">=1.19"

[tool.poetry.group.test.dependencies]
Expand Down
4 changes: 2 additions & 2 deletions src/finch/autoschedule/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .finch_logic import (
from ..finch_logic import (
Aggregate,
Alias,
Deferred,
Expand All @@ -15,7 +15,7 @@
Table,
)
from .optimize import optimize, propagate_map_queries
from .rewrite_tools import PostOrderDFS, PostWalk, PreWalk
from ..symbolic import PostOrderDFS, PostWalk, PreWalk

__all__ = [
"Aggregate",
Expand Down
2 changes: 1 addition & 1 deletion src/finch/autoschedule/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from textwrap import dedent
from typing import Any

from .finch_logic import (
from ..finch_logic import (
Alias,
Deferred,
Field,
Expand Down
2 changes: 1 addition & 1 deletion src/finch/autoschedule/executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .compiler import LogicCompiler
from .rewrite_tools import gensym
from ..symbolic import gensym


class LogicExecutor:
Expand Down
6 changes: 3 additions & 3 deletions src/finch/autoschedule/optimize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .compiler import LogicCompiler
from .finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query
from .rewrite_tools import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite
from ..finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query
from ..symbolic import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite


def optimize(prgm: LogicNode) -> LogicNode:
Expand Down Expand Up @@ -52,4 +52,4 @@ def __init__(self, ctx: LogicCompiler):

def __call__(self, prgm: LogicNode):
prgm = optimize(prgm)
return self.ctx(prgm)
return self.ctx(prgm)
8 changes: 8 additions & 0 deletions src/finch/finch_logic/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ def children(self):
"""Returns the children of the node."""
return [self.op, *self.args]

@classmethod
def make_term(cls, head, op, *args):
return head(op, args)


@dataclass(eq=True, frozen=True)
class Aggregate(LogicNode):
Expand Down Expand Up @@ -412,3 +416,7 @@ def is_stateful():
def children(self):
"""Returns the children of the node."""
return [*self.bodies]

@classmethod
def make_term(cls, head, *val):
return head(val)
2 changes: 1 addition & 1 deletion src/finch/symbolic/rewriters.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __call__(self, x: Term) -> Term | None:
new_args = list(map(self, args))
if all(arg is None for arg in new_args):
return self.rw(x)
y = x.make_term(*map(lambda x1, x2: default_rewrite(x1, x2), new_args, args))
y = x.make_term(x.head(), *map(lambda x1, x2: default_rewrite(x1, x2), new_args, args))
return default_rewrite(self.rw(y), y)
return self.rw(x)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from finch.autoschedule import propagate_map_queries
from finch.finch_logic import *


def test_propagate_map_queries_simple():
plan = Plan(
(
Query(Alias("A10"), Aggregate(Immediate("+"), Immediate(0), Immediate("[1,2,3]"), ())),
Query(Alias("A11"), Alias("A10")),
Produces((Alias("11"),)),
)
)
expected = Plan(
(
Query(Alias("A11"), MapJoin(Immediate("+"), (Immediate(0), Immediate("[1,2,3]")))),
Produces((Alias("11"),)),
)
)

result = propagate_map_queries(plan)
assert result == expected