From fecdd12f8d669fa3eab29bfa8d96b809a351c99b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Mon, 5 May 2025 15:51:28 +0000 Subject: [PATCH] Add `lift_subqueries` scheduler pass --- .pre-commit-config.yaml | 19 +++++++ src/finch/autoschedule/__init__.py | 7 ++- src/finch/autoschedule/optimize.py | 51 +++++++++++++++-- src/finch/finch_logic/interpreter.py | 57 +++++++++++++------ src/finch/finch_logic/nodes.py | 61 ++++++++++++++------ tests/test_scheduler.py | 84 ++++++++++++++++++++++++++-- 6 files changed, 236 insertions(+), 43 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..23a3a9c9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,19 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace + - id: mixed-line-ending + - id: name-tests-test + args: ["--pytest-test-first"] + - id: no-commit-to-branch + +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.8 + hooks: + - id: ruff + args: ["--fix"] + types_or: [ python, pyi, jupyter ] + - id: ruff-format + types_or: [ python, pyi, jupyter ] diff --git a/src/finch/autoschedule/__init__.py b/src/finch/autoschedule/__init__.py index 5e18633c..6730353d 100644 --- a/src/finch/autoschedule/__init__.py +++ b/src/finch/autoschedule/__init__.py @@ -14,7 +14,11 @@ Subquery, Table, ) -from .optimize import optimize, propagate_map_queries +from .optimize import ( + optimize, + propagate_map_queries, + lift_subqueries, +) from ..symbolic import PostOrderDFS, PostWalk, PreWalk __all__ = [ @@ -34,6 +38,7 @@ "Table", "optimize", "propagate_map_queries", + "lift_subqueries", "PostOrderDFS", "PostWalk", "PreWalk", diff --git a/src/finch/autoschedule/optimize.py b/src/finch/autoschedule/optimize.py index 0b1e91f3..4e2a36f7 100644 --- a/src/finch/autoschedule/optimize.py +++ b/src/finch/autoschedule/optimize.py @@ -1,14 +1,57 @@ from .compiler import LogicCompiler -from ..finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query +from ..finch_logic import ( + Aggregate, + Alias, + LogicNode, + MapJoin, + Plan, + Produces, + Query, + Subquery, +) from ..symbolic import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite def optimize(prgm: LogicNode) -> LogicNode: # ... - return propagate_map_queries(prgm) + prgm = lift_subqueries(prgm) + prgm = propagate_map_queries(prgm) + return prgm -def get_productions(root: LogicNode) -> LogicNode: +def _lift_subqueries_expr(node: LogicNode, bindings: dict) -> LogicNode: + match node: + case Subquery(lhs, arg): + if lhs not in bindings: + arg_2 = _lift_subqueries_expr(arg, bindings) + bindings[lhs] = arg_2 + return lhs + case any if any.is_expr(): + return any.make_term( + any.head(), + *map(lambda x: _lift_subqueries_expr(x, bindings), any.children()), + ) + case _: + return node + + +def lift_subqueries(node: LogicNode) -> LogicNode: + match node: + case Plan(bodies): + return Plan(tuple(map(lift_subqueries, bodies))) + case Query(lhs, rhs): + bindings = {} + rhs_2 = _lift_subqueries_expr(rhs, bindings) + return Plan( + (*[Query(lhs, rhs) for lhs, rhs in bindings.items()], Query(lhs, rhs_2)) + ) + case Produces() as p: + return p + case _: + raise Exception(f"Invalid node: {node}") + + +def _get_productions(root: LogicNode) -> list[LogicNode]: for node in PostOrderDFS(root): if isinstance(node, Produces): return [arg for arg in PostOrderDFS(node) if isinstance(arg, Alias)] @@ -22,7 +65,7 @@ def rule_agg_to_mapjoin(ex): return MapJoin(op, (init, arg)) root = Rewrite(PostWalk(rule_agg_to_mapjoin))(root) - rets = get_productions(root) + rets = _get_productions(root) props = {} for node in PostOrderDFS(root): match node: diff --git a/src/finch/finch_logic/interpreter.py b/src/finch/finch_logic/interpreter.py index b747dc7f..13d6a3b0 100644 --- a/src/finch/finch_logic/interpreter.py +++ b/src/finch/finch_logic/interpreter.py @@ -1,36 +1,53 @@ -from __future__ import annotations +from dataclasses import dataclass from itertools import product +from typing import Iterable, Any import numpy as np -from ..finch_logic import * -from ..symbolic import Term -from ..algebra import * +from ..finch_logic import ( + Immediate, + Deferred, + Field, + Alias, + Table, + MapJoin, + Aggregate, + Query, + Plan, + Produces, + Subquery, + Relabel, + Reorder, +) +from ..algebra import return_type, fill_value, element_type, fixpoint_type + @dataclass(eq=True, frozen=True) -class TableValue(): +class TableValue: tns: Any idxs: Iterable[Any] + def __post_init__(self): if isinstance(self.tns, TableValue): raise ValueError("The tensor (tns) cannot be a TableValue") -from typing import Any, Type class FinchLogicInterpreter: def __init__(self, *, make_tensor=np.full): self.verbose = False self.bindings = {} self.make_tensor = make_tensor # Added make_tensor argument - + def __call__(self, node): # Example implementation for evaluating an expression if self.verbose: - print(f"Evaluating: {expression}") + print(f"Evaluating: {node}") # Placeholder for actual logic head = node.head() if head == Immediate: return node.val elif head == Deferred: - raise ValueError("The interpreter cannot evaluate a deferred node, a compiler might generate code for it") + raise ValueError( + "The interpreter cannot evaluate a deferred node, a compiler might generate code for it" + ) elif head == Field: raise ValueError("Fields cannot be used in expressions") elif head == Alias: @@ -59,7 +76,9 @@ def __call__(self, node): dims[idx] = dim fill_val = op(*[fill_value(arg.tns) for arg in args]) dtype = return_type(op, *[element_type(arg.tns) for arg in args]) - result = self.make_tensor(tuple(dims[idx] for idx in idxs), fill_val, dtype = dtype) + result = self.make_tensor( + tuple(dims[idx] for idx in idxs), fill_val, dtype=dtype + ) for crds in product(*[range(dims[idx]) for idx in idxs]): idx_crds = {idx: crd for (idx, crd) in zip(idxs, crds)} vals = [arg.tns[*[idx_crds[idx] for idx in arg.idxs]] for arg in args] @@ -74,10 +93,16 @@ def __call__(self, node): init = node.init.val op = node.op.val dtype = fixpoint_type(op, init, element_type(arg.tns)) - new_shape = [dim for (dim, idx) in zip(arg.tns.shape, arg.idxs) if not idx in node.idxs] + new_shape = [ + dim + for (dim, idx) in zip(arg.tns.shape, arg.idxs) + if idx not in node.idxs + ] result = self.make_tensor(new_shape, init, dtype=dtype) for crds in product(*[range(dim) for dim in arg.tns.shape]): - out_crds = [crd for (crd, idx) in zip(crds, arg.idxs) if not idx in node.idxs] + out_crds = [ + crd for (crd, idx) in zip(crds, arg.idxs) if idx not in node.idxs + ] result[*out_crds] = op(result[*out_crds], arg.tns[*crds]) return TableValue(result, [idx for idx in arg.idxs if idx not in node.idxs]) elif head == Relabel: @@ -92,7 +117,7 @@ def __call__(self, node): raise ValueError("Trying to drop a dimension that is not 1") arg_dims = {idx: dim for idx, dim in zip(arg.idxs, arg.tns.shape)} dims = [arg_dims.get(idx, 1) for idx in node.idxs] - result = self.make_tensor(dims, fill_value(arg.tns), dtype = arg.tns.dtype) + result = self.make_tensor(dims, fill_value(arg.tns), dtype=arg.tns.dtype) for crds in product(*[range(dim) for dim in dims]): node_crds = {idx: crd for (idx, crd) in zip(node.idxs, crds)} in_crds = [node_crds.get(idx, 0) for idx in arg.idxs] @@ -110,8 +135,8 @@ def __call__(self, node): elif head == Produces: return tuple(self(arg).tns for arg in node.args) elif head == Subquery: - if not node.lhs in self.bindings: - self.bindings[node.lhs] = self(node.rhs) + if node.lhs not in self.bindings: + self.bindings[node.lhs] = self(node.arg) return self.bindings[node.lhs] else: - raise ValueError(f"Unknown expression type: {head}") \ No newline at end of file + raise ValueError(f"Unknown expression type: {head}") diff --git a/src/finch/finch_logic/nodes.py b/src/finch/finch_logic/nodes.py index 29776794..60ebdfb9 100644 --- a/src/finch/finch_logic/nodes.py +++ b/src/finch/finch_logic/nodes.py @@ -1,20 +1,27 @@ from abc import abstractmethod -from collections.abc import Iterable from dataclasses import dataclass from typing import Any from ..symbolic import Term + @dataclass(eq=True, frozen=True) class LogicNode(Term): """ LogicNode - Represents a Finch Logic IR node. Finch uses a variant of Concrete Field Notation + Represents a Finch Logic IR node. Finch uses a variant of Concrete Field Notation as an intermediate representation. - The LogicNode struct represents many different Finch IR nodes. The nodes are + The LogicNode struct represents many different Finch IR nodes. The nodes are differentiated by a `FinchLogic.LogicNodeKind` enum. """ + + @staticmethod + @abstractmethod + def is_expr(): + """Determines if the node is expresion.""" + ... + @staticmethod @abstractmethod def is_stateful(): @@ -26,6 +33,10 @@ def head(cls): """Returns the head of the node.""" return cls + def children(self): + """Returns the children of the node.""" + raise Exception(f"`children` isn't supported for {self.__class__}.") + @classmethod def make_term(cls, head, *args): """Creates a term with the given head and arguments.""" @@ -40,6 +51,7 @@ class Immediate(LogicNode): Attributes: val: The literal value. """ + val: Any @staticmethod @@ -56,13 +68,14 @@ def is_stateful(): @dataclass(eq=True, frozen=True) class Deferred(LogicNode): """ - Represents a logical AST expression for an expression `ex` of type `type`, + Represents a logical AST expression for an expression `ex` of type `type`, yet to be evaluated. Attributes: ex: The expression to be evaluated. type_: The type of the expression. """ + ex: Any type_: Any @@ -91,6 +104,7 @@ class Field(LogicNode): Attributes: name: The name of the field. """ + name: str @staticmethod @@ -117,6 +131,7 @@ class Alias(LogicNode): Attributes: name: The name of the alias. """ + name: str @staticmethod @@ -144,8 +159,9 @@ class Table(LogicNode): tns: The tensor object. idxs: The fields indexing the tensor. """ + tns: LogicNode - idxs: Iterable[LogicNode] + idxs: tuple[LogicNode] @staticmethod def is_expr(): @@ -174,8 +190,9 @@ class MapJoin(LogicNode): op: The function to map. args: The arguments to map the function across. """ + op: LogicNode - args: Iterable[LogicNode] + args: tuple[LogicNode] @staticmethod def is_expr(): @@ -208,10 +225,11 @@ class Aggregate(LogicNode): arg: The argument to reduce. idxs: The dimensions to reduce. """ + op: LogicNode init: LogicNode arg: LogicNode - idxs: Iterable[LogicNode] + idxs: tuple[LogicNode] @staticmethod def is_expr(): @@ -239,8 +257,9 @@ class Reorder(LogicNode): arg: The argument to reorder. idxs: The new order of dimensions. """ + arg: LogicNode - idxs: Iterable[LogicNode] + idxs: tuple[LogicNode] @staticmethod def is_expr(): @@ -266,8 +285,9 @@ class Relabel(LogicNode): arg: The argument to relabel. idxs: The new labels for dimensions. """ + arg: LogicNode - idxs: Iterable[LogicNode] + idxs: tuple[LogicNode] @staticmethod def is_expr(): @@ -293,6 +313,7 @@ class Reformat(LogicNode): tns: The target tensor. arg: The argument to reformat. """ + tns: LogicNode arg: LogicNode @@ -314,15 +335,16 @@ def children(self): @dataclass(eq=True, frozen=True) class Subquery(LogicNode): """ - Represents a logical AST statement that evaluates `rhs`, binding the result to `lhs`, + Represents a logical AST statement that evaluates `rhs`, binding the result to `lhs`, and returns `rhs`. Attributes: lhs: The left-hand side of the binding. rhs: The argument to evaluate. """ + lhs: LogicNode - rhs: LogicNode + arg: LogicNode @staticmethod def is_expr(): @@ -336,7 +358,7 @@ def is_stateful(): def children(self): """Returns the children of the node.""" - return [self.lhs, self.rhs] + return [self.lhs, self.arg] @dataclass(eq=True, frozen=True) @@ -348,6 +370,7 @@ class Query(LogicNode): lhs: The left-hand side of the binding. rhs: The right-hand side to evaluate. """ + lhs: LogicNode rhs: LogicNode @@ -369,13 +392,14 @@ def children(self): @dataclass(eq=True, frozen=True) class Produces(LogicNode): """ - Represents a logical AST statement that returns `args...` from the current plan. + Represents a logical AST statement that returns `args...` from the current plan. Halts execution of the program. Attributes: args: The arguments to return. """ - args: Iterable[LogicNode] + + args: tuple[LogicNode] @staticmethod def is_expr(): @@ -391,17 +415,22 @@ def children(self): """Returns the children of the node.""" return [*self.args] + @classmethod + def make_term(cls, head, *args): + return head(args) + @dataclass(eq=True, frozen=True) class Plan(LogicNode): """ - Represents a logical AST statement that executes a sequence of statements `bodies...`. + Represents a logical AST statement that executes a sequence of statements `bodies...`. Returns the last statement. Attributes: bodies: The sequence of statements to execute. """ - bodies: Iterable[LogicNode] = () + + bodies: tuple[LogicNode] = () @staticmethod def is_expr(): diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index b5bc285c..39e3ccf5 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,21 +1,93 @@ -from finch.autoschedule import propagate_map_queries -from finch.finch_logic import * +from finch.autoschedule import propagate_map_queries, lift_subqueries +from finch.finch_logic import ( + Plan, + Query, + Alias, + Aggregate, + Immediate, + MapJoin, + Produces, + Subquery, +) def test_propagate_map_queries_simple(): plan = Plan( ( - Query(Alias("A10"), Aggregate(Immediate("+"), Immediate(0), Immediate("[1,2,3]"), ())), + Query( + Alias("A10"), + Aggregate(Immediate("+"), Immediate(0), Immediate("[1,2,3]"), ()), + ), Query(Alias("A11"), Alias("A10")), - Produces((Alias("11"),)), + Produces((Alias("A11"),)), ) ) expected = Plan( ( - Query(Alias("A11"), MapJoin(Immediate("+"), (Immediate(0), Immediate("[1,2,3]")))), - Produces((Alias("11"),)), + Query( + Alias("A11"), + MapJoin(Immediate("+"), (Immediate(0), Immediate("[1,2,3]"))), + ), + Produces((Alias("A11"),)), ) ) result = propagate_map_queries(plan) assert result == expected + + +def test_lift_subqueries(): + plan = Plan( + ( + Query( + Alias("A10"), + Plan( + ( + Subquery(Alias("C10"), Immediate(0)), + Subquery( + Alias("B10"), + MapJoin( + Immediate("+"), + ( + Subquery(Alias("C10"), Immediate(0)), + Immediate("[1,2,3]"), + ), + ), + ), + Subquery(Alias("B10"), Immediate(0)), + Produces((Alias("B10"),)), + ) + ), + ), + Produces((Alias("A10"),)), + ) + ) + + expected = Plan( + ( + Plan( + ( + Query(Alias("C10"), Immediate(0)), + Query( + Alias("B10"), + MapJoin(Immediate("+"), (Alias("C10"), Immediate("[1,2,3]"))), + ), + Query( + Alias("A10"), + Plan( + ( + Alias("C10"), + Alias("B10"), + Alias("B10"), + Produces((Alias("B10"),)), + ) + ), + ), + ), + ), + Produces((Alias("A10"),)), + ) + ) + + result = lift_subqueries(plan) + assert result == expected