19 changes: 19 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,19 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: end-of-file-fixer
      - id: trailing-whitespace
      - id: mixed-line-ending
      - id: name-tests-test
        args: ["--pytest-test-first"]
      - id: no-commit-to-branch

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.8
    hooks:
      - id: ruff
        args: ["--fix"]
        types_or: [ python, pyi, jupyter ]
      - id: ruff-format
        types_or: [ python, pyi, jupyter ]
7 changes: 6 additions & 1 deletion src/finch/autoschedule/__init__.py
@@ -14,7 +14,11 @@
    Subquery,
    Table,
)
from .optimize import optimize, propagate_map_queries
from .optimize import (
    optimize,
    propagate_map_queries,
    lift_subqueries,
)
from ..symbolic import PostOrderDFS, PostWalk, PreWalk

__all__ = [
@@ -34,6 +38,7 @@
"Table",
"optimize",
"propagate_map_queries",
"lift_subqueries",
"PostOrderDFS",
"PostWalk",
"PreWalk",
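With this change, lift_subqueries is exported from the package alongside optimize and propagate_map_queries. A minimal usage sketch (not part of the diff), assuming the top-level package is importable as finch per the src/finch/ layout:

# Hedged sketch: importing the newly exported helper from the package root.
# The package name `finch` is inferred from the src/finch/ directory layout.
from finch.autoschedule import lift_subqueries, optimize, propagate_map_queries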
51 changes: 47 additions & 4 deletions src/finch/autoschedule/optimize.py
@@ -1,14 +1,57 @@
from .compiler import LogicCompiler
from ..finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query
from ..finch_logic import (
    Aggregate,
    Alias,
    LogicNode,
    MapJoin,
    Plan,
    Produces,
    Query,
    Subquery,
)
from ..symbolic import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite


def optimize(prgm: LogicNode) -> LogicNode:
    # ...
    return propagate_map_queries(prgm)
    prgm = lift_subqueries(prgm)
    prgm = propagate_map_queries(prgm)
    return prgm


def get_productions(root: LogicNode) -> LogicNode:
def _lift_subqueries_expr(node: LogicNode, bindings: dict) -> LogicNode:
    match node:
        case Subquery(lhs, arg):
            if lhs not in bindings:
                arg_2 = _lift_subqueries_expr(arg, bindings)
                bindings[lhs] = arg_2
            return lhs
        case any if any.is_expr():
            return any.make_term(
                any.head(),
                *map(lambda x: _lift_subqueries_expr(x, bindings), any.children()),
            )
        case _:
            return node


def lift_subqueries(node: LogicNode) -> LogicNode:
    match node:
        case Plan(bodies):
            return Plan(tuple(map(lift_subqueries, bodies)))
        case Query(lhs, rhs):
            bindings = {}
            rhs_2 = _lift_subqueries_expr(rhs, bindings)
            return Plan(
                (*[Query(lhs, rhs) for lhs, rhs in bindings.items()], Query(lhs, rhs_2))
            )
        case Produces() as p:
            return p
        case _:
            raise Exception(f"Invalid node: {node}")


def _get_productions(root: LogicNode) -> list[LogicNode]:
    for node in PostOrderDFS(root):
        if isinstance(node, Produces):
            return [arg for arg in PostOrderDFS(node) if isinstance(arg, Alias)]
@@ -22,7 +65,7 @@ def rule_agg_to_mapjoin(ex):
                return MapJoin(op, (init, arg))

    root = Rewrite(PostWalk(rule_agg_to_mapjoin))(root)
    rets = get_productions(root)
    rets = _get_productions(root)
    props = {}
    for node in PostOrderDFS(root):
        match node:
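For context on the new pass: _lift_subqueries_expr records each Subquery(lhs, arg) it meets in a bindings dict and replaces the subquery with its alias, and lift_subqueries then wraps the rewritten query in a Plan that defines those aliases first. A hedged sketch of the observable behaviour (not part of the diff), assuming the positional constructors implied by the match patterns above (Query(lhs, rhs), Subquery(lhs, arg), Plan(bodies), MapJoin(op, args), Immediate(val)) and an Alias(name) constructor:

import operator

from finch.autoschedule import lift_subqueries
from finch.finch_logic import Alias, Immediate, MapJoin, Plan, Query, Subquery

a = Alias("A")      # alias bound by the subquery (constructor assumed)
out = Alias("OUT")  # alias produced by the outer query

# OUT := (A := 1) + (A := 1); both operands name the same subquery.
prgm = Query(
    out,
    MapJoin(
        Immediate(operator.add),
        (Subquery(a, Immediate(1)), Subquery(a, Immediate(1))),
    ),
)

# lift_subqueries should rewrite this into roughly
#   Plan((Query(A, Immediate(1)), Query(OUT, MapJoin(add, (A, A)))))
# binding A once and reusing the alias in the outer query.
print(lift_subqueries(prgm))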
57 changes: 41 additions & 16 deletions src/finch/finch_logic/interpreter.py
@@ -1,36 +1,53 @@
from __future__ import annotations
from dataclasses import dataclass
from itertools import product
from typing import Iterable, Any
import numpy as np
from ..finch_logic import *
from ..symbolic import Term
from ..algebra import *
from ..finch_logic import (
    Immediate,
    Deferred,
    Field,
    Alias,
    Table,
    MapJoin,
    Aggregate,
    Query,
    Plan,
    Produces,
    Subquery,
    Relabel,
    Reorder,
)
from ..algebra import return_type, fill_value, element_type, fixpoint_type


@dataclass(eq=True, frozen=True)
class TableValue():
class TableValue:
    tns: Any
    idxs: Iterable[Any]

    def __post_init__(self):
        if isinstance(self.tns, TableValue):
            raise ValueError("The tensor (tns) cannot be a TableValue")

from typing import Any, Type

class FinchLogicInterpreter:
    def __init__(self, *, make_tensor=np.full):
        self.verbose = False
        self.bindings = {}
        self.make_tensor = make_tensor  # Added make_tensor argument

    def __call__(self, node):
        # Example implementation for evaluating an expression
        if self.verbose:
            print(f"Evaluating: {expression}")
            print(f"Evaluating: {node}")
        # Placeholder for actual logic
        head = node.head()
        if head == Immediate:
            return node.val
        elif head == Deferred:
            raise ValueError("The interpreter cannot evaluate a deferred node, a compiler might generate code for it")
            raise ValueError(
                "The interpreter cannot evaluate a deferred node, a compiler might generate code for it"
            )
        elif head == Field:
            raise ValueError("Fields cannot be used in expressions")
        elif head == Alias:
@@ -59,7 +76,9 @@ def __call__(self, node):
                    dims[idx] = dim
            fill_val = op(*[fill_value(arg.tns) for arg in args])
            dtype = return_type(op, *[element_type(arg.tns) for arg in args])
            result = self.make_tensor(tuple(dims[idx] for idx in idxs), fill_val, dtype = dtype)
            result = self.make_tensor(
                tuple(dims[idx] for idx in idxs), fill_val, dtype=dtype
            )
            for crds in product(*[range(dims[idx]) for idx in idxs]):
                idx_crds = {idx: crd for (idx, crd) in zip(idxs, crds)}
                vals = [arg.tns[*[idx_crds[idx] for idx in arg.idxs]] for arg in args]
@@ -74,10 +93,16 @@ def __call__(self, node):
            init = node.init.val
            op = node.op.val
            dtype = fixpoint_type(op, init, element_type(arg.tns))
            new_shape = [dim for (dim, idx) in zip(arg.tns.shape, arg.idxs) if not idx in node.idxs]
            new_shape = [
                dim
                for (dim, idx) in zip(arg.tns.shape, arg.idxs)
                if idx not in node.idxs
            ]
            result = self.make_tensor(new_shape, init, dtype=dtype)
            for crds in product(*[range(dim) for dim in arg.tns.shape]):
                out_crds = [crd for (crd, idx) in zip(crds, arg.idxs) if not idx in node.idxs]
                out_crds = [
                    crd for (crd, idx) in zip(crds, arg.idxs) if idx not in node.idxs
                ]
                result[*out_crds] = op(result[*out_crds], arg.tns[*crds])
            return TableValue(result, [idx for idx in arg.idxs if idx not in node.idxs])
        elif head == Relabel:
@@ -92,7 +117,7 @@ def __call__(self, node):
raise ValueError("Trying to drop a dimension that is not 1")
arg_dims = {idx: dim for idx, dim in zip(arg.idxs, arg.tns.shape)}
dims = [arg_dims.get(idx, 1) for idx in node.idxs]
result = self.make_tensor(dims, fill_value(arg.tns), dtype = arg.tns.dtype)
result = self.make_tensor(dims, fill_value(arg.tns), dtype=arg.tns.dtype)
for crds in product(*[range(dim) for dim in dims]):
node_crds = {idx: crd for (idx, crd) in zip(node.idxs, crds)}
in_crds = [node_crds.get(idx, 0) for idx in arg.idxs]
@@ -110,8 +135,8 @@ def __call__(self, node):
        elif head == Produces:
            return tuple(self(arg).tns for arg in node.args)
        elif head == Subquery:
            if not node.lhs in self.bindings:
                self.bindings[node.lhs] = self(node.rhs)
            if node.lhs not in self.bindings:
                self.bindings[node.lhs] = self(node.arg)
            return self.bindings[node.lhs]
        else:
            raise ValueError(f"Unknown expression type: {head}")
            raise ValueError(f"Unknown expression type: {head}")