diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 6ae13601..d2e863c6 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -1,9 +1,10 @@
+defaults:
+  run:
+    shell: bash -el {0}
+
 name: CI
 jobs:
   test:
-    defaults:
-      run:
-        shell: bash -el {0}
     strategy:
       matrix:
         os: [ubuntu-latest]
@@ -17,7 +18,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     steps:
       - name: Checkout Repo
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
@@ -29,7 +30,25 @@ jobs:
      - name: Run tests
        run: |
          poetry run pytest --junit-xml=test-${{ matrix.os }}-Python-${{ matrix.python }}.xml
-      - uses: codecov/codecov-action@v3
+      - uses: codecov/codecov-action@v5
+  mypy:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout Repo
+        uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: 3.12
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+      - name: Install package
+        run: |
+          poetry install --with test
+      - name: Run mypy
+        run: |
+          poetry run mypy .
+
+
 on:
   # Trigger the workflow on push or pull request,
diff --git a/pixi.toml b/pixi.toml
index fdea7989..69667f54 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -24,6 +24,7 @@ sparse = "*"
 numba = ">=0.60"
 scipy = "*"
 numpy = "==2.*"
+mypy = ">=1.15.0,<2"

 [feature.test.tasks]
 test = { cmd = "pytest", depends-on = ["compile"] }
diff --git a/pyproject.toml b/pyproject.toml
index 37abaa6e..a9401a56 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,7 +18,28 @@ pytest-cov = "^4.1.0"
 sparse = "^0.16.0"
 scipy = "^1.7"
 numba = "^0.61.0"
+mypy = "^1.15.0"

 [build-system]
 requires = ["poetry-core>=1.0.8"]
 build-backend = "poetry.core.masonry.api"
+
+[tool.ruff.lint]
+select = ["F", "E", "W", "I", "B", "UP", "YTT", "BLE", "C4", "T10", "ISC", "ICN", "PIE", "PYI", "RSE", "RET", "SIM", "PGH", "FLY", "NPY", "PERF"]
+
+[tool.ruff.lint.isort.sections]
+numpy = ["numpy", "numpy.*", "scipy", "scipy.*"]
+
+[tool.ruff.format]
+quote-style = "double"
+docstring-code-format = true
+
+[tool.ruff.lint.isort]
+section-order = [
+    "future",
+    "standard-library",
+    "numpy",
+    "third-party",
+    "first-party",
+    "local-folder",
+]
diff --git a/src/finch/__init__.py b/src/finch/__init__.py
index 4b628146..8e8276a2 100644
--- a/src/finch/__init__.py
+++ b/src/finch/__init__.py
@@ -1,2 +1,31 @@
 from . import finch_logic
-from .interface import *
\ No newline at end of file
+from .interface import (
+    compute,
+    elementwise,
+    expand_dims,
+    fuse,
+    fused,
+    identify,
+    lazy,
+    multiply,
+    permute_dims,
+    prod,
+    reduce,
+    squeeze,
+)
+
+__all__ = [
+    "lazy",
+    "compute",
+    "finch_logic",
+    "fuse",
+    "fused",
+    "permute_dims",
+    "expand_dims",
+    "squeeze",
+    "identify",
+    "reduce",
+    "elementwise",
+    "prod",
+    "multiply",
+]
diff --git a/src/finch/algebra/__init__.py b/src/finch/algebra/__init__.py
index 1070fd19..e7a09201 100644
--- a/src/finch/algebra/__init__.py
+++ b/src/finch/algebra/__init__.py
@@ -1,4 +1,13 @@
-from .algebra import *
+from .algebra import (
+    element_type,
+    fill_value,
+    fixpoint_type,
+    init_value,
+    is_associative,
+    query_property,
+    register_property,
+    return_type,
+)

 __all__ = [
     "fill_value",
@@ -9,4 +18,4 @@
     "is_associative",
     "query_property",
     "register_property",
-]
\ No newline at end of file
+]
diff --git a/src/finch/algebra/algebra.py b/src/finch/algebra/algebra.py
index 0b719a66..4b30e458 100644
--- a/src/finch/algebra/algebra.py
+++ b/src/finch/algebra/algebra.py
@@ -1,6 +1,3 @@
-from typing import Any, Type
-from collections.abc import Hashable
-
 """
 Finch performs extensive rewriting and defining of functions. The Finch
 compiler is designed to inspect objects and functions defined by other
@@ -22,7 +19,8 @@
 ```python
 from finch import register_property
-register_property(complex, '__add__', 'is_associative', lambda obj: True)
+
+register_property(complex, "__add__", "is_associative", lambda obj: True)
 ```

 Finch includes a convenience functions to query each property as well,
@@ -30,25 +28,31 @@
 ```python
 from finch import query_property
 from operator import add
-query_property(complex, '__add__', 'is_associative')
+
+query_property(complex, "__add__", "is_associative")
 # True
 is_associative(add, complex, complex)
 # True
 ```

-Properties can be inherited in the same way as methods. First we check whether properties have been defined for the object itself (in the case of functions), then we check For example, if you
-register a property for a class, all subclasses of that class will inherit
-that property. This allows you to define properties for a class and have
-them automatically apply to all subclasses, without having to register the
-property for each subclass individually.
+Properties can be inherited in the same way as methods. First we check whether
+properties have been defined for the object itself (in the case of functions), then we
+check its class and each of its superclasses. For example, if you register a property
+for a class, all subclasses of that class will inherit that property. This allows you
+to define properties for a class and have them automatically apply to all subclasses,
+without having to register the property for each subclass individually.
 """
+
 import operator
-from typing import Union
+from collections.abc import Callable, Hashable
+from typing import Any
+
 import numpy as np

-_properties = {}
+_properties: dict[tuple[type | Hashable, str, str], Any] = {}
+

-def query_property(obj, attr, prop, *args):
+def query_property(obj: type | Hashable, attr: str, prop: str, *args: Any) -> Any:
     """Queries a property of an attribute of an object or class.
     Properties can be overridden by calling register_property on the object
     or it's class.
@@ -64,22 +68,31 @@ def query_property(obj, attr, prop, *args):

     Raises:
         NotImplementedError: If the property is not implemented for the given type.
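+
+    Example (an illustrative doctest mirroring the module docstring above):
+
+        >>> register_property(complex, "__add__", "is_associative", lambda obj: True)
+        >>> query_property(complex, "__add__", "is_associative")
+        True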
""" - if isinstance(obj, type): - T = obj - else: - if isinstance(obj, Hashable): - if (obj, attr, prop) in _properties: - return _properties[(obj, attr, prop)](obj, *args) + T = obj + if not isinstance(obj, Hashable): T = type(obj) - while True: - if (T, attr, prop) in _properties: - return _properties[(T, attr, prop)](obj, *args) - if T is object: - break - T = T.__base__ + to_query = {T} + to_query_new: set[type | Hashable] = set() + queried: set[type | Hashable] = set() + while len(to_query) != 0: + for o in to_query: + if o in queried: + continue + method = _properties.get((o, attr, prop), None) + if method is not None: + return method(obj, *args) + queried.add(o) + if not isinstance(o, type): + to_query_new.add(type(o)) + continue + to_query_new.update(o.__mro__) + to_query.clear() + to_query, to_query_new = to_query_new, to_query + raise NotImplementedError(f"Property {prop} not implemented for {type(obj)}") -def register_property(cls, attr, prop, f): + +def register_property(cls: type | Hashable, attr: str, prop: str, f: Callable) -> None: """Registers a property for a class or object. Args: @@ -90,6 +103,7 @@ def register_property(cls, attr, prop, f): """ _properties[(cls, attr, prop)] = f + def fill_value(arg: Any) -> Any: """The fill value for the given argument. The fill value is the default value for a tensor when it is created with a given shape and dtype, @@ -104,11 +118,15 @@ def fill_value(arg: Any) -> Any: Raises: NotImplementedError: If the fill value is not implemented for the given type. """ - return query_property(arg, '__self__', 'fill_value') + return query_property(arg, "__self__", "fill_value") -register_property(np.ndarray, '__self__', 'fill_value', lambda x: np.zeros((), dtype=x.dtype)[()]) -def element_type(arg: Any) -> Type: +register_property( + np.ndarray, "__self__", "fill_value", lambda x: np.zeros((), dtype=x.dtype)[()] +) + + +def element_type(arg: Any) -> type: """The element type of the given argument. The element type is the scalar type of the elements in a tensor, which may be different from the data type of the tensor. @@ -122,9 +140,16 @@ def element_type(arg: Any) -> Type: Raises: NotImplementedError: If the element type is not implemented for the given type. """ - return query_property(arg, '__self__', 'element_type') + return query_property(arg, "__self__", "element_type") + + +register_property( + np.ndarray, + "__self__", + "element_type", + lambda x: type(np.zeros((), dtype=x.dtype)[()]), +) -register_property(np.ndarray, '__self__', 'element_type', lambda x: type(np.zeros((), dtype=x.dtype)[()])) def return_type(op: Any, *args: Any) -> Any: """The return type of the given function on the given argument types. 
@@ -136,7 +161,8 @@ def return_type(op: Any, *args: Any) -> Any: Returns: The return type of op(*args: arg_types) """ - return query_property(op, '__call__', 'return_type', *args) + return query_property(op, "__call__", "return_type", *args) + StableNumber = (np.number, bool, int, float, complex) @@ -158,17 +184,31 @@ def return_type(op: Any, *args: Any) -> Any: } for op, (meth, rmeth) in _reflexive_operators.items(): - register_property(op, '__call__', 'return_type', lambda op, a, b: query_property(a, meth, 'return_type', b) if hasattr(a, meth) else query_property(b, rmeth, 'return_type', a)), + register_property( + op, + "__call__", + "return_type", + lambda op, a, b, meth=meth, rmeth=rmeth: query_property( + a, meth, "return_type", b + ) + if hasattr(a, meth) + else query_property(b, rmeth, "return_type", a), + ) + def _return_type(meth): def _return_type_closure(a, b): if issubclass(b, StableNumber): return type(getattr(a(True), meth)(b(True))) - else: - raise TypeError(f"Unsupported operand type for {type(a)}.{meth}: {type(b)}") + raise TypeError( + f"Unsupported operand type for {type(a)}.{meth}: {type(b)}" + ) + return _return_type_closure + for T in StableNumber: - register_property(T, meth, 'return_type', _return_type(meth)) - register_property(T, rmeth, 'return_type', _return_type(rmeth)) + register_property(T, meth, "return_type", _return_type(meth)) + register_property(T, rmeth, "return_type", _return_type(rmeth)) + def is_associative(op: Any) -> bool: """Returns whether the given function is associative, that is, whether the @@ -176,16 +216,18 @@ def is_associative(op: Any) -> bool: Args: op: The function to check. - + Returns: True if the function can be proven to be associative, False otherwise. """ - return query_property(op, '__call__', 'is_associative') + return query_property(op, "__call__", "is_associative") + for op in [operator.add, operator.mul, operator.and_, operator.xor, operator.or_]: - register_property(op, '__call__', 'is_associative', lambda op: True) + register_property(op, "__call__", "is_associative", lambda op: True) + -def fixpoint_type(op: Any, z: Any, T: Type) -> Type: +def fixpoint_type(op: Any, z: Any, T: type) -> type: """Determines the fixpoint type after repeated calling the given operation. Args: @@ -200,9 +242,12 @@ def fixpoint_type(op: Any, z: Any, T: Type) -> Type: R = type(z) while R not in S: S.add(R) - R = return_type(op, type(z), T) # Assuming `op` is a callable that takes `z` and `T` as arguments + R = return_type( + op, type(z), T + ) # Assuming `op` is a callable that takes `z` and `T` as arguments return R + def init_value(op, arg) -> Any: """Returns the initial value for a reduction operation on the given type. @@ -214,17 +259,24 @@ def init_value(op, arg) -> Any: The initial value for the given operation and type. Raises: - NotImplementedError: If the initial value is not implemented for the given type and operation. + NotImplementedError: If the initial value is not implemented for the given type + and operation. 
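+
+    Example (an illustrative doctest; relies on the registrations below, where the
+        additive identity for ``int`` is ``int(False) == 0`` and the multiplicative
+        identity is ``int(True) == 1``):
+
+        >>> import operator
+        >>> init_value(operator.add, int)
+        0
+        >>> init_value(operator.mul, int)
+        1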
""" - return query_property(op, '__call__', 'init_value', arg) + return query_property(op, "__call__", "init_value", arg) + for op in [operator.add, operator.mul, operator.and_, operator.xor, operator.or_]: (meth, rmeth) = _reflexive_operators[op] - register_property(op, '__call__', 'init_value', lambda op, arg: query_property(arg, meth, 'init_value', arg)) + register_property( + op, + "__call__", + "init_value", + lambda op, arg, meth=meth: query_property(arg, meth, "init_value", arg), + ) for T in StableNumber: - register_property(T, '__add__', 'init_value', lambda a, b: a(False)) - register_property(T, '__mul__', 'init_value', lambda a, b: a(True)) - register_property(T, '__and__', 'init_value', lambda a, b: a(True)) - register_property(T, '__xor__', 'init_value', lambda a, b: a(False)) - register_property(T, '__or__', 'init_value', lambda a, b: a(False)) \ No newline at end of file + register_property(T, "__add__", "init_value", lambda a, b: a(False)) + register_property(T, "__mul__", "init_value", lambda a, b: a(True)) + register_property(T, "__and__", "init_value", lambda a, b: a(True)) + register_property(T, "__xor__", "init_value", lambda a, b: a(False)) + register_property(T, "__or__", "init_value", lambda a, b: a(False)) diff --git a/src/finch/autoschedule/__init__.py b/src/finch/autoschedule/__init__.py index 502bfde7..23e15d57 100644 --- a/src/finch/autoschedule/__init__.py +++ b/src/finch/autoschedule/__init__.py @@ -14,13 +14,13 @@ Subquery, Table, ) +from ..symbolic import PostOrderDFS, PostWalk, PreWalk from .optimize import ( + lift_subqueries, optimize, propagate_fields, propagate_map_queries, - lift_subqueries, ) -from ..symbolic import PostOrderDFS, PostWalk, PreWalk __all__ = [ "Aggregate", diff --git a/src/finch/autoschedule/compiler.py b/src/finch/autoschedule/compiler.py index a97f9c4e..8c1e28f3 100644 --- a/src/finch/autoschedule/compiler.py +++ b/src/finch/autoschedule/compiler.py @@ -1,12 +1,12 @@ -from collections.abc import Hashable from textwrap import dedent -from typing import Any +from typing import TypeVar from ..finch_logic import ( Alias, Deferred, Field, Immediate, + LogicExpression, LogicNode, MapJoin, Query, @@ -16,29 +16,43 @@ Subquery, Table, ) +from ..symbolic import Term +T = TypeVar("T", bound="LogicNode") -def get_or_insert(dictionary: dict[Hashable, Any], key: Hashable, default: Any) -> Any: - if key in dictionary: - return dictionary[key] - dictionary[key] = default - return default +def get_or_insert(dictionary: dict[str, T], key: str, default: T) -> T: + return dictionary.setdefault(key, default) -def get_structure(node: LogicNode, fields: dict[str, LogicNode], aliases: dict[str, LogicNode]) -> LogicNode: + +def get_structure( + node: LogicNode, fields: dict[str, LogicNode], aliases: dict[str, LogicNode] +) -> LogicNode: match node: case Field(name): return get_or_insert(fields, name, Immediate(len(fields) + len(aliases))) case Alias(name): return get_or_insert(aliases, name, Immediate(len(fields) + len(aliases))) case Subquery(Alias(name) as lhs, arg): - if name in aliases: - return aliases[name] - return Subquery(get_structure(lhs, fields, aliases), get_structure(arg, fields, aliases)) + alias = aliases.get(name) + if alias is not None: + return alias + in_lhs = get_structure(lhs, fields, aliases) + assert isinstance(in_lhs, LogicExpression) + in_arg = get_structure(arg, fields, aliases) + assert isinstance(in_arg, LogicExpression) + return Subquery(in_lhs, in_arg) case Table(tns, idxs): - return Table(Immediate(type(tns.val)), 
tuple(get_structure(idx, fields, aliases) for idx in idxs))
-        case any if any.is_tree():
-            return any.from_arguments(*[get_structure(arg, fields, aliases) for arg in any.get_arguments()])
+            assert all(isinstance(idx, Field) for idx in idxs)
+            return Table(
+                Immediate(type(tns.val)),
+                tuple(get_structure(idx, fields, aliases) for idx in idxs),  # type: ignore[misc]
+            )
+        case LogicExpression() as expr:
+            return expr.make_term(
+                expr.head(),
+                *[get_structure(arg, fields, aliases) for arg in expr.children()],
+            )
         case _:
             return node
@@ -53,7 +67,11 @@ def __call__(self, ex):
                 return f":({val}({','.join([self(arg) for arg in args])}))"
             case Reorder(Relabel(Alias(name), idxs_1), idxs_2):
                 self.bound_idxs.append(idxs_1)
-                return f":({name}[{','.join([idx.name if idx in idxs_2 else 1 for idx in idxs_1])}])"
+                return (
+                    f":({name}["
+                    + ",".join([idx.name if idx in idxs_2 else "1" for idx in idxs_1])
+                    + "])"
+                )
             case Reorder(Immediate(val), _):
                 return val
             case Immediate(val):
                 return val
@@ -68,7 +86,7 @@ def compile_pointwise_logic(ex: LogicNode) -> tuple:
     return (code, ctx.bound_idxs)


-def compile_logic_constant(ex: LogicNode) -> str:
+def compile_logic_constant(ex: Term) -> str:
     match ex:
         case Immediate(val):
             return val
@@ -94,15 +112,23 @@ class LogicLowerer:
     def __init__(self, mode: str = "fast"):
         self.mode = mode

-    def __call__(self, ex: LogicNode):
+    def __call__(self, ex: Term) -> str:
         match ex:
             case Query(Alias(name), Table(tns, _)):
                 return f":({name} = {compile_logic_constant(tns)})"

-            case Query(Alias(_) as lhs, Reformat(tns, Reorder(Relabel(Alias(_) as arg, idxs_1), idxs_2))):
-                loop_idxs = [idx.name for idx in with_subsequence(intersect(idxs_1, idxs_2), idxs_2)]
+            case Query(
+                Alias(_) as lhs,
+                Reformat(tns, Reorder(Relabel(Alias(_) as arg, idxs_1), idxs_2)),
+            ):
+                loop_idxs = [
+                    idx.name
+                    for idx in with_subsequence(intersect(idxs_1, idxs_2), idxs_2)
+                ]
                 lhs_idxs = [idx.name for idx in idxs_2]
-                (rhs, rhs_idxs) = compile_pointwise_logic(Reorder(Relabel(arg, idxs_1), idxs_2))
+                (rhs, rhs_idxs) = compile_pointwise_logic(
+                    Reorder(Relabel(arg, idxs_1), idxs_2)
+                )
                 body = f":({lhs.name}[{','.join(lhs_idxs)}] = {rhs})"
                 for idx in loop_idxs:
                     if Field(idx) in rhs_idxs:
@@ -132,6 +158,6 @@ class LogicCompiler:
     def __init__(self):
         self.ll = LogicLowerer()

-    def __call__(self, prgm):
+    def __call__(self, prgm: Term) -> str:
         # prgm = format_queries(prgm, True)  # noqa: F821
         return self.ll(prgm)
diff --git a/src/finch/autoschedule/executor.py b/src/finch/autoschedule/executor.py
index 148aed28..8bc15eca 100644
--- a/src/finch/autoschedule/executor.py
+++ b/src/finch/autoschedule/executor.py
@@ -1,5 +1,5 @@
-from .compiler import LogicCompiler
 from ..symbolic import gensym
+from .compiler import LogicCompiler


 class LogicExecutor:
diff --git a/src/finch/autoschedule/optimize.py b/src/finch/autoschedule/optimize.py
index 06f7f22b..5ee40089 100644
--- a/src/finch/autoschedule/optimize.py
+++ b/src/finch/autoschedule/optimize.py
@@ -1,9 +1,10 @@
-from typing import Any, Iterable
+from collections.abc import Iterable

-from .compiler import LogicCompiler
 from ..finch_logic import (
     Aggregate,
     Alias,
+    Field,
+    LogicExpression,
     LogicNode,
     MapJoin,
     Plan,
@@ -13,26 +14,28 @@
     Subquery,
 )
 from ..symbolic import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite
+from .compiler import LogicCompiler


 def optimize(prgm: LogicNode) -> LogicNode:
     # ...
     prgm = lift_subqueries(prgm)
-    prgm = propagate_map_queries(prgm)
-    return prgm
+    return propagate_map_queries(prgm)


-def _lift_subqueries_expr(node: LogicNode, bindings: dict) -> LogicNode:
+def _lift_subqueries_expr(
+    node: LogicNode, bindings: dict[LogicNode, LogicNode]
+) -> LogicNode:
     match node:
         case Subquery(lhs, arg):
             if lhs not in bindings:
                 arg_2 = _lift_subqueries_expr(arg, bindings)
                 bindings[lhs] = arg_2
             return lhs
-        case any if any.is_expr():
-            return any.make_term(
-                any.head(),
-                *map(lambda x: _lift_subqueries_expr(x, bindings), any.children()),
+        case LogicExpression() as expr:
+            return expr.make_term(
+                expr.head(),
+                *tuple(_lift_subqueries_expr(x, bindings) for x in expr.children()),
             )
         case _:
             return node
@@ -43,7 +46,7 @@ def lift_subqueries(node: LogicNode) -> LogicNode:
         case Plan(bodies):
             return Plan(tuple(map(lift_subqueries, bodies)))
         case Query(lhs, rhs):
-            bindings = {}
+            bindings: dict[LogicNode, LogicNode] = {}
             rhs_2 = _lift_subqueries_expr(rhs, bindings)
             return Plan(
                 (*[Query(lhs, rhs) for lhs, rhs in bindings.items()], Query(lhs, rhs_2))
@@ -93,20 +96,21 @@ def rule_2(ex):


 def _propagate_fields(
-    root: LogicNode, fields: dict[LogicNode, Iterable[LogicNode]]
+    root: LogicNode, fields: dict[LogicNode, Iterable[Field]]
 ) -> LogicNode:
     match root:
         case Plan(bodies):
             return Plan(tuple(_propagate_fields(b, fields) for b in bodies))
         case Query(lhs, rhs):
             rhs = _propagate_fields(rhs, fields)
+            assert isinstance(rhs, LogicExpression)
             fields[lhs] = rhs.get_fields()
             return Query(lhs, rhs)
         case Alias() as a:
             return Relabel(a, tuple(fields[a]))
-        case node if node.is_expr():
-            return node.make_term(
-                node.head(), *[_propagate_fields(c, fields) for c in node.children()]
+        case LogicExpression() as expr:
+            return expr.make_term(
+                expr.head(), *[_propagate_fields(c, fields) for c in expr.children()]
             )
         case node:
             return node
@@ -120,6 +124,6 @@ class DefaultLogicOptimizer:
     def __init__(self, ctx: LogicCompiler):
         self.ctx = ctx

-    def __call__(self, prgm: LogicNode):
+    def __call__(self, prgm: LogicNode) -> str:
         prgm = optimize(prgm)
         return self.ctx(prgm)
diff --git a/src/finch/finch_logic/__init__.py b/src/finch/finch_logic/__init__.py
index e85b1350..824bc9a3 100644
--- a/src/finch/finch_logic/__init__.py
+++ b/src/finch/finch_logic/__init__.py
@@ -1,8 +1,23 @@
-from .nodes import *
-from .interpreter import FinchLogicInterpreter
+from .nodes import (
+    Aggregate,
+    Alias,
+    Deferred,
+    Field,
+    Immediate,
+    LogicExpression,
+    LogicNode,
+    MapJoin,
+    Plan,
+    Produces,
+    Query,
+    Reformat,
+    Relabel,
+    Reorder,
+    Subquery,
+    Table,
+)

 __all__ = [
-    "FinchLogicInterpreter",
     "LogicNode",
     "Deferred",
     "Aggregate",
@@ -13,8 +28,12 @@
     "Plan",
     "Produces",
     "Query",
+    "Reformat",
     "Relabel",
     "Reorder",
     "Subquery",
     "Table",
-]
\ No newline at end of file
+    "LogicExpression",
+]
diff --git a/src/finch/finch_logic/interpreter.py b/src/finch/finch_logic/interpreter.py
index 13d6a3b0..57f57523 100644
--- a/src/finch/finch_logic/interpreter.py
+++ b/src/finch/finch_logic/interpreter.py
@@ -1,23 +1,26 @@
+from collections.abc import Iterable
 from dataclasses import dataclass
 from itertools import product
-from typing import Iterable, Any
+from typing import Any
+
 import numpy as np
+
+from ..algebra import element_type, fill_value, fixpoint_type, return_type
 from ..finch_logic import (
-    Immediate,
+    Aggregate,
+    Alias,
     Deferred,
     Field,
-    Alias,
-    Table,
+    Immediate,
     MapJoin,
-    Aggregate,
-    Query,
     Plan,
     Produces,
-    Subquery,
+    Query,
     Relabel,
     Reorder,
+    Subquery,
+ Table, ) -from ..algebra import return_type, fill_value, element_type, fixpoint_type @dataclass(eq=True, frozen=True) @@ -41,102 +44,102 @@ def __call__(self, node): if self.verbose: print(f"Evaluating: {node}") # Placeholder for actual logic - head = node.head() - if head == Immediate: - return node.val - elif head == Deferred: - raise ValueError( - "The interpreter cannot evaluate a deferred node, a compiler might generate code for it" - ) - elif head == Field: - raise ValueError("Fields cannot be used in expressions") - elif head == Alias: - if node in self.bindings: - return self.bindings[node] - else: - raise ValueError(f"undefined tensor alias {node.val}") - elif head == Table: - if node.tns.head() != Immediate: - raise ValueError("The table data must be Immediate") - return TableValue(node.tns.val, node.idxs) - elif head == MapJoin: - if node.op.head() != Immediate: - raise ValueError("The mapjoin operator must be Immediate") - op = node.op.val - args = list(map(self, node.args)) - dims = {} - idxs = [] - for arg in args: - for idx, dim in zip(arg.idxs, arg.tns.shape): - if idx in dims: - if dims[idx] != dim: - raise ValueError("Dimensions mismatched in map") - else: - idxs.append(idx) - dims[idx] = dim - fill_val = op(*[fill_value(arg.tns) for arg in args]) - dtype = return_type(op, *[element_type(arg.tns) for arg in args]) - result = self.make_tensor( - tuple(dims[idx] for idx in idxs), fill_val, dtype=dtype - ) - for crds in product(*[range(dims[idx]) for idx in idxs]): - idx_crds = {idx: crd for (idx, crd) in zip(idxs, crds)} - vals = [arg.tns[*[idx_crds[idx] for idx in arg.idxs]] for arg in args] - result[*crds] = op(*vals) - return TableValue(result, idxs) - elif head == Aggregate: - if node.op.head() != Immediate: - raise ValueError("The aggregate operator must be Immediate") - if node.init.head() != Immediate: - raise ValueError("The aggregate initial value must be Immediate") - arg = self(node.arg) - init = node.init.val - op = node.op.val - dtype = fixpoint_type(op, init, element_type(arg.tns)) - new_shape = [ - dim - for (dim, idx) in zip(arg.tns.shape, arg.idxs) - if idx not in node.idxs - ] - result = self.make_tensor(new_shape, init, dtype=dtype) - for crds in product(*[range(dim) for dim in arg.tns.shape]): - out_crds = [ - crd for (crd, idx) in zip(crds, arg.idxs) if idx not in node.idxs - ] - result[*out_crds] = op(result[*out_crds], arg.tns[*crds]) - return TableValue(result, [idx for idx in arg.idxs if idx not in node.idxs]) - elif head == Relabel: - arg = self(node.arg) - if len(arg.idxs) != len(node.idxs): - raise ValueError("The number of indices in the relabel must match") - return TableValue(arg.tns, node.idxs) - elif head == Reorder: - arg = self(node.arg) - for idx, dim in zip(arg.idxs, arg.tns.shape): - if idx not in node.idxs and dim != 1: - raise ValueError("Trying to drop a dimension that is not 1") - arg_dims = {idx: dim for idx, dim in zip(arg.idxs, arg.tns.shape)} - dims = [arg_dims.get(idx, 1) for idx in node.idxs] - result = self.make_tensor(dims, fill_value(arg.tns), dtype=arg.tns.dtype) - for crds in product(*[range(dim) for dim in dims]): - node_crds = {idx: crd for (idx, crd) in zip(node.idxs, crds)} - in_crds = [node_crds.get(idx, 0) for idx in arg.idxs] - result[*crds] = arg.tns[*in_crds] - return TableValue(result, node.idxs) - elif head == Query: - rhs = self(node.rhs) - self.bindings[node.lhs] = rhs - return (rhs,) - elif head == Plan: - res = () - for body in node.bodies: - res = self(body) - return res - elif head == Produces: - return 
tuple(self(arg).tns for arg in node.args) - elif head == Subquery: - if node.lhs not in self.bindings: - self.bindings[node.lhs] = self(node.arg) - return self.bindings[node.lhs] - else: - raise ValueError(f"Unknown expression type: {head}") + match node: + case Immediate(val): + return val + case Deferred(_): + raise ValueError( + "The interpreter cannot evaluate a deferred node, a compiler might " + "generate code for it" + ) + case Field(_): + raise ValueError("Fields cannot be used in expressions") + case Alias(_): + alias = self.bindings.get(node, None) + if alias is None: + raise ValueError(f"undefined tensor alias {node}") + return alias + case Table(Immediate(val), idxs): + return TableValue(val, idxs) + case MapJoin(Immediate(op), args): + args = tuple(self(a) for a in args) + dims = {} + idxs = [] + for arg in args: + for idx, dim in zip(arg.idxs, arg.tns.shape, strict=True): + if idx in dims: + if dims[idx] != dim: + raise ValueError("Dimensions mismatched in map") + else: + idxs.append(idx) + dims[idx] = dim + fill_val = op(*[fill_value(arg.tns) for arg in args]) + dtype = return_type(op, *[element_type(arg.tns) for arg in args]) + result = self.make_tensor( + tuple(dims[idx] for idx in idxs), fill_val, dtype=dtype + ) + for crds in product(*[range(dims[idx]) for idx in idxs]): + idx_crds = dict(zip(idxs, crds, strict=True)) + vals = [ + arg.tns[*[idx_crds[idx] for idx in arg.idxs]] for arg in args + ] + result[*crds] = op(*vals) + return TableValue(result, idxs) + case Aggregate(op, init, arg, idxs): + arg = self(arg) + init = init.val + op = op.val + dtype = fixpoint_type(op, init, element_type(arg.tns)) + new_shape = tuple( + dim + for (dim, idx) in zip(arg.tns.shape, arg.idxs, strict=True) + if idx not in node.idxs + ) + result = self.make_tensor(new_shape, init, dtype=dtype) + for crds in product(*[range(dim) for dim in arg.tns.shape]): + out_crds = [ + crd + for (crd, idx) in zip(crds, arg.idxs, strict=True) + if idx not in node.idxs + ] + result[*out_crds] = op(result[*out_crds], arg.tns[*crds]) + return TableValue( + result, [idx for idx in arg.idxs if idx not in node.idxs] + ) + case Relabel(arg, idxs): + arg = self(arg) + if len(arg.idxs) != len(idxs): + raise ValueError("The number of indices in the relabel must match") + return TableValue(arg.tns, idxs) + case Reorder(arg, idxs): + arg = self(arg) + for idx, dim in zip(arg.idxs, arg.tns.shape, strict=True): + if idx not in idxs and dim != 1: + raise ValueError("Trying to drop a dimension that is not 1") + arg_dims = dict(zip(arg.idxs, arg.tns.shape, strict=True)) + dims = [arg_dims.get(idx, 1) for idx in idxs] + result = self.make_tensor( + dims, fill_value(arg.tns), dtype=arg.tns.dtype + ) + for crds in product(*[range(dim) for dim in dims]): + node_crds = dict(zip(idxs, crds, strict=True)) + in_crds = [node_crds.get(idx, 0) for idx in arg.idxs] + result[*crds] = arg.tns[*in_crds] + return TableValue(result, idxs) + case Query(lhs, rhs): + rhs = self(rhs) + self.bindings[lhs] = rhs + return (rhs,) + case Plan(bodies): + res = () + for body in bodies: + res = self(body) + return res + case Produces(args): + return tuple(self(arg).tns for arg in args) + case Subquery(lhs, arg): + if lhs not in self.bindings: + self.bindings[lhs] = self(arg) + return self.bindings[lhs] + case _: + raise ValueError(f"Unknown expression type: {type(node)}") diff --git a/src/finch/finch_logic/nodes.py b/src/finch/finch_logic/nodes.py index 0f7dce49..38e8cfdd 100644 --- a/src/finch/finch_logic/nodes.py +++ b/src/finch/finch_logic/nodes.py 
@@ -1,8 +1,34 @@ +from __future__ import annotations + from abc import abstractmethod +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Iterable +from typing import Any, Never, Self, TypeVar + from ..symbolic import Term +__all__ = [ + "LogicNode", + "LogicExpression", + "Immediate", + "Deferred", + "Field", + "Alias", + "Table", + "MapJoin", + "Aggregate", + "Reorder", + "Relabel", + "Reformat", + "Subquery", + "Query", + "Produces", + "Plan", +] + + +T = TypeVar("T", bound="LogicNode") + @dataclass(eq=True, frozen=True) class LogicNode(Term): @@ -16,39 +42,40 @@ class LogicNode(Term): differentiated by a `FinchLogic.LogicNodeKind` enum. """ - @staticmethod - @abstractmethod - def is_expr(): - """Determines if the node is expresion.""" - ... - @staticmethod @abstractmethod def is_stateful(): """Determines if the node is stateful.""" - ... @classmethod - def head(cls): + def head(cls) -> type[Self]: """Returns the head of the node.""" return cls - def children(self): - """Returns the children of the node.""" - raise Exception(f"`children` isn't supported for {self.__class__}.") - - def get_fields(self) -> Iterable["LogicNode"]: - """Returns fields of the node.""" - raise Exception(f"`fields` isn't supported for {self.__class__}.") - @classmethod - def make_term(cls, head: type, *args: Any) -> "LogicNode": + def make_term(cls, head: Callable[..., Self], *args: Term) -> Self: """Creates a term with the given head and arguments.""" return head(*args) @dataclass(eq=True, frozen=True) -class Immediate(LogicNode): +class LogicExpression(LogicNode): + @abstractmethod + def get_fields(self) -> list[Field]: + """Get this node's fields.""" + + @abstractmethod + def children(self) -> list[LogicNode]: + """Get this node's children.""" + + +@dataclass(eq=True, frozen=True) +class LogicLeaf(LogicNode): + pass + + +@dataclass(eq=True, frozen=True) +class Immediate(LogicLeaf): """ Represents a logical AST expression for the literal value `val`. @@ -58,23 +85,27 @@ class Immediate(LogicNode): val: Any - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return False - @staticmethod def is_stateful(): """Determines if the node is stateful.""" return False - def get_fields(self): + @property + def fill_value(self): + from ..algebra import fill_value + + return fill_value(self) + + def children(self) -> Never: + raise TypeError(f"`{type(self).__name__}` doesn't support `.children()`.") + + def get_fields(self) -> tuple[Field, ...]: """Returns fields of the node.""" - return [] + return () @dataclass(eq=True, frozen=True) -class Deferred(LogicNode): +class Deferred(LogicLeaf): """ Represents a logical AST expression for an expression `ex` of type `type`, yet to be evaluated. @@ -87,23 +118,18 @@ class Deferred(LogicNode): ex: Any type_: Any - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return False - @staticmethod def is_stateful(): """Determines if the node is stateful.""" return False - def children(self): + def children(self) -> list[Any]: """Returns the children of the node.""" - return [self.val, self.type_] + return [self.ex, self.type_] @dataclass(eq=True, frozen=True) -class Field(LogicNode): +class Field(LogicLeaf): """ Represents a logical AST expression for a field named `name`. Fields are used to name the dimensions of a tensor. 
The named @@ -115,23 +141,18 @@ class Field(LogicNode): name: str - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return False - @staticmethod def is_stateful(): """Determines if the node is stateful.""" return False - def children(self): + def children(self) -> list[str]: """Returns the children of the node.""" return [self.name] @dataclass(eq=True, frozen=True) -class Alias(LogicNode): +class Alias(LogicLeaf): """ Represents a logical AST expression for an alias named `name`. Aliases are used to refer to tables in the program. @@ -142,60 +163,55 @@ class Alias(LogicNode): name: str - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return False - @staticmethod def is_stateful(): """Determines if the node is stateful.""" return False - def children(self): + def children(self) -> list[str]: """Returns the children of the node.""" return [self.name] @dataclass(eq=True, frozen=True) -class Table(LogicNode): +class Table(LogicExpression): """ - Represents a logical AST expression for a tensor object `tns`, indexed by fields `idxs...`. - A table is a tensor with named dimensions. + Represents a logical AST expression for a tensor object `tns`, indexed by fields + `idxs...`. A table is a tensor with named dimensions. Attributes: tns: The tensor object. idxs: The fields indexing the tensor. """ - tns: LogicNode - idxs: tuple[LogicNode] - - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return True + tns: Immediate + idxs: tuple[Field, ...] @staticmethod def is_stateful(): """Determines if the node is stateful.""" return False - def children(self): + def children(self) -> list[LogicLeaf]: # type: ignore[override] """Returns the children of the node.""" return [self.tns, *self.idxs] - def get_fields(self): + def get_fields(self) -> list[Field]: """Returns fields of the node.""" - return self.idxs + return [*self.idxs] @classmethod - def make_term(cls, head, tns, *idxs): + def make_term( # type: ignore[override] + cls, + head: Callable[[Immediate, tuple[Field, ...]], Self], + tns: Immediate, + *idxs: Field, + ) -> Self: return head(tns, idxs) @dataclass(eq=True, frozen=True) -class MapJoin(LogicNode): +class MapJoin(LogicExpression): """ Represents a logical AST expression for mapping the function `op` across `args...`. Dimensions which are not present are broadcasted. Dimensions which are @@ -207,38 +223,41 @@ class MapJoin(LogicNode): args: The arguments to map the function across. """ - op: LogicNode - args: tuple[LogicNode] - - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return True + op: Immediate + args: tuple[LogicExpression, ...] @staticmethod def is_stateful(): """Determines if the node is stateful.""" return False - def children(self): + def children(self) -> list[LogicNode]: """Returns the children of the node.""" return [self.op, *self.args] - def get_fields(self): + def get_fields(self) -> list[Field]: """Returns fields of the node.""" # (mtsokol) I'm not sure if this comment still applies - the order is preserved. 
# TODO: this is wrong here: the overall order should at least be concordant with # the args if the args are concordant - fields = [f for fs in map(lambda x: x.get_fields(), self.args) for f in fs] - return list(dict.fromkeys(fields)) + fs: list[Field] = [] + for arg in self.args: + fs.extend(arg.get_fields()) + + return list(dict.fromkeys(fs)) @classmethod - def make_term(cls, head, op, *args): - return head(op, args) + def make_term( # type: ignore[override] + cls, + head: Callable[[Immediate, tuple[LogicExpression, ...]], Self], + op: Immediate, + *args: LogicExpression, + ) -> Self: + return head(op, tuple(args)) @dataclass(eq=True, frozen=True) -class Aggregate(LogicNode): +class Aggregate(LogicExpression): """ Represents a logical AST statement that reduces `arg` using `op`, starting with `init`. `idxs` are the dimensions to reduce. May happen in any order. @@ -250,40 +269,44 @@ class Aggregate(LogicNode): idxs: The dimensions to reduce. """ - op: LogicNode - init: LogicNode - arg: LogicNode - idxs: tuple[LogicNode] - - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return True + op: Immediate + init: Immediate + arg: LogicExpression + idxs: tuple[Field, ...] @staticmethod def is_stateful(): """Determines if the node is stateful.""" return False - def children(self): + def children(self) -> list[LogicNode]: """Returns the children of the node.""" return [self.op, self.init, self.arg, *self.idxs] - def get_fields(self): + def get_fields(self) -> list[Field]: """Returns fields of the node.""" return [field for field in self.arg.get_fields() if field not in self.idxs] @classmethod - def make_term(cls, head, op, init, arg, *idxs): + def make_term( # type: ignore[override] + cls, + head: Callable[ + [Immediate, Immediate, LogicExpression, tuple[Field, ...]], Self + ], + op: Immediate, + init: Immediate, + arg: LogicExpression, + *idxs: Field, + ) -> Self: return head(op, init, arg, idxs) @dataclass(eq=True, frozen=True) -class Reorder(LogicNode): +class Reorder(LogicExpression): """ - Represents a logical AST statement that reorders the dimensions of `arg` to be `idxs...`. - Dimensions known to be length 1 may be dropped. Dimensions that do not exist in - `arg` may be added. + Represents a logical AST statement that reorders the dimensions of `arg` to be + `idxs...`. Dimensions known to be length 1 may be dropped. Dimensions that do not + exist in `arg` may be added. Attributes: arg: The argument to reorder. @@ -291,35 +314,36 @@ class Reorder(LogicNode): """ arg: LogicNode - idxs: tuple[LogicNode] - - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return True + idxs: tuple[Field, ...] @staticmethod def is_stateful(): """Determines if the node is stateful.""" return False - def children(self): + def children(self) -> list[LogicNode]: """Returns the children of the node.""" return [self.arg, *self.idxs] - def get_fields(self): + def get_fields(self) -> list[Field]: """Returns fields of the node.""" - return self.idxs + return [*self.idxs] @classmethod - def make_term(cls, head, arg, *idxs): - return head(arg, idxs) + def make_term( # type: ignore[override] + cls, + head: Callable[[LogicNode, tuple[Field, ...]], Self], + arg: LogicNode, + *idxs: Field, + ) -> Self: + return head(arg, tuple(idxs)) @dataclass(eq=True, frozen=True) -class Relabel(LogicNode): +class Relabel(LogicExpression): """ - Represents a logical AST statement that relabels the dimensions of `arg` to be `idxs...`. 
+ Represents a logical AST statement that relabels the dimensions of `arg` to be + `idxs...`. Attributes: arg: The argument to relabel. @@ -327,29 +351,24 @@ class Relabel(LogicNode): """ arg: LogicNode - idxs: tuple[LogicNode] - - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return True + idxs: tuple[Field, ...] @staticmethod def is_stateful(): """Determines if the node is stateful.""" return False - def children(self): + def children(self) -> list[LogicNode]: """Returns the children of the node.""" return [self.arg, *self.idxs] - def get_fields(self): + def get_fields(self) -> list[Field]: """Returns fields of the node.""" - return self.idxs + return [*self.idxs] @dataclass(eq=True, frozen=True) -class Reformat(LogicNode): +class Reformat(LogicExpression): """ Represents a logical AST statement that reformats `arg` into the tensor `tns`. @@ -358,33 +377,28 @@ class Reformat(LogicNode): arg: The argument to reformat. """ - tns: LogicNode - arg: LogicNode - - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return True + tns: Immediate + arg: LogicExpression @staticmethod def is_stateful(): """Determines if the node is stateful.""" return False - def children(self): + def children(self) -> list[LogicNode]: """Returns the children of the node.""" return [self.tns, self.arg] - def get_fields(self): + def get_fields(self) -> list[Field]: """Returns fields of the node.""" return self.arg.get_fields() @dataclass(eq=True, frozen=True) -class Subquery(LogicNode): +class Subquery(LogicExpression): """ - Represents a logical AST statement that evaluates `rhs`, binding the result to `lhs`, - and returns `rhs`. + Represents a logical AST statement that evaluates `rhs`, binding the result to + `lhs`, and returns `rhs`. Attributes: lhs: The left-hand side of the binding. @@ -392,31 +406,27 @@ class Subquery(LogicNode): """ lhs: LogicNode - arg: LogicNode - - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return True + arg: LogicExpression @staticmethod def is_stateful(): """Determines if the node is stateful.""" return False - def children(self): + def children(self) -> list[LogicNode]: """Returns the children of the node.""" return [self.lhs, self.arg] - def get_fields(self): + def get_fields(self) -> list[Field]: """Returns fields of the node.""" return self.arg.get_fields() @dataclass(eq=True, frozen=True) -class Query(LogicNode): +class Query(LogicExpression): """ - Represents a logical AST statement that evaluates `rhs`, binding the result to `lhs`. + Represents a logical AST statement that evaluates `rhs`, binding the result to + `lhs`. Attributes: lhs: The left-hand side of the binding. @@ -426,23 +436,21 @@ class Query(LogicNode): lhs: LogicNode rhs: LogicNode - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return True - @staticmethod def is_stateful(): """Determines if the node is stateful.""" return True - def children(self): + def children(self) -> list[LogicNode]: """Returns the children of the node.""" return [self.lhs, self.rhs] + def get_fields(self) -> list[Field]: + raise NotImplementedError + @dataclass(eq=True, frozen=True) -class Produces(LogicNode): +class Produces(LogicExpression): """ Represents a logical AST statement that returns `args...` from the current plan. Halts execution of the program. @@ -451,53 +459,53 @@ class Produces(LogicNode): args: The arguments to return. 
""" - args: tuple[LogicNode] - - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return True + args: tuple[LogicNode, ...] @staticmethod def is_stateful(): """Determines if the node is stateful.""" return True - def children(self): + def children(self) -> list[LogicNode]: """Returns the children of the node.""" return [*self.args] + def get_fields(self) -> list[Field]: + raise NotImplementedError + @classmethod - def make_term(cls, head, *args): - return head(args) + def make_term( # type: ignore[override] + cls, head: Callable[[tuple[LogicNode, ...]], Self], *args: LogicNode + ) -> Self: + return head(tuple(args)) @dataclass(eq=True, frozen=True) -class Plan(LogicNode): +class Plan(LogicExpression): """ - Represents a logical AST statement that executes a sequence of statements `bodies...`. - Returns the last statement. + Represents a logical AST statement that executes a sequence of statements + `bodies...`. Returns the last statement. Attributes: bodies: The sequence of statements to execute. """ - bodies: tuple[LogicNode] = () - - @staticmethod - def is_expr(): - """Determines if the node is an expression.""" - return True + bodies: tuple[LogicNode, ...] = () @staticmethod def is_stateful(): """Determines if the node is stateful.""" return True - def children(self): + def children(self) -> list[LogicNode]: """Returns the children of the node.""" return [*self.bodies] + def get_fields(self) -> list[Field]: + raise NotImplementedError + @classmethod - def make_term(cls, head, *val): - return head(val) + def make_term( # type: ignore[override] + cls, head: Callable[[tuple[LogicNode, ...]], Self], *bodies: LogicNode + ) -> Self: + return head(tuple(bodies)) diff --git a/src/finch/interface/__init__.py b/src/finch/interface/__init__.py index edd4b308..6667301b 100644 --- a/src/finch/interface/__init__.py +++ b/src/finch/interface/__init__.py @@ -1,6 +1,17 @@ -from .lazy import * -#from .tensor import * -from .fuse import * +# from .tensor import * +from .fuse import fuse, fused +from .lazy import ( + compute, + elementwise, + expand_dims, + identify, + lazy, + multiply, + permute_dims, + prod, + reduce, + squeeze, +) __all__ = [ "lazy", @@ -15,4 +26,4 @@ "elementwise", "prod", "multiply", -] \ No newline at end of file +] diff --git a/src/finch/interface/fuse.py b/src/finch/interface/fuse.py index fe30832d..ac7bc008 100644 --- a/src/finch/interface/fuse.py +++ b/src/finch/interface/fuse.py @@ -1,14 +1,12 @@ -from .lazy import * -from ..finch_logic import * -from ..symbolic import gensym -from ..algebra import * """ -This module provides functionality for array fusion and computation using lazy evaluation. +This module provides functionality for array fusion and computation using lazy +evaluation. Overview: --------- -Array fusion allows composing multiple array operations into a single kernel, enabling significant -performance optimizations by letting the compiler optimize the entire operation at once. +Array fusion allows composing multiple array operations into a single kernel, enabling +significant performance optimizations by letting the compiler optimize the entire +operation at once. Key Functions: -------------- @@ -25,13 +23,15 @@ >>> E = (C + D) / 2 >>> compute(E) - In this example, `E` represents a fused operation that adds `C` and `D` together and divides - the result by 2. The `compute` function optimizes and executes the operation efficiently. 
+ In this example, `E` represents a fused operation that adds `C` and `D` together and + divides the result by 2. The `compute` function optimizes and executes the operation + efficiently. 2. Using `fuse` as a higher-order function: >>> result = fuse(lambda x, y: (x + y) / 2, A, B) - Here, `fuse` combines the addition and division operations into a single fused kernel. + Here, `fuse` combines the addition and division operations into a single fused + kernel. 3. Using the `fused` decorator: >>> @fused @@ -44,35 +44,44 @@ Performance: ------------ - Using `lazy` and `compute` results in faster execution due to operation fusion. -- Different optimizers can be used with `compute`, such as the Galley optimizer, which adapts to - the sparsity patterns of the inputs. -- The optimizer can be set using the `ctx` argument in `compute`, or via `set_scheduler` or `with_scheduler`. +- Different optimizers can be used with `compute`, such as the Galley optimizer, which + adapts to the sparsity patterns of the inputs. +- The optimizer can be set using the `ctx` argument in `compute`, or via `set_scheduler` + or `with_scheduler`. """ -def fuse(f, *args, ctx=get_default_scheduler()): +from .lazy import compute, get_default_scheduler, lazy + + +def fuse(f, *args, ctx=None): """ - fuse(f, *args, ctx=get_default_scheduler()): - Fuses multiple array operations into a single kernel. This function allows for composing - operations and executing them efficiently. + Fuses multiple array operations into a single kernel. This function allows for + composing operations and executing them efficiently. Parameters: - - f: The function representing the operation to be fused, returning a tensor or tuple of tensor results. + - f: The function representing the operation to be fused, returning a tensor or + tuple of tensor results. - *args: The input arrays or LazyTensors to be fused. - - ctx: The scheduler to use for computation. Defaults to the result of `get_default_scheduler()`. + - ctx: The scheduler to use for computation. Defaults to the result of + `get_default_scheduler()`. Returns: - The result of the fused operation, a tensor or tuple of tensors. """ + if ctx is None: + ctx = get_default_scheduler() args = [lazy(arg) for arg in args] if len(args) == 1: return f(args[0]) return compute(f(*args), ctx=ctx) -def fused(f, /, ctx=get_default_scheduler()): + +def fused(f, /, ctx=None): """ - fused(f): - A decorator that marks a function as fused. This allows the function to be used with the - `fuse` function for automatic fusion of operations. + A decorator that marks a function as fused. This allows the function to be used with + the `fuse` function for automatic fusion of operations. Parameters: - f: The function to be marked as fused. @@ -80,6 +89,10 @@ def fused(f, /, ctx=get_default_scheduler()): Returns: - A wrapper function that applies the fusion mechanism to the original function. 
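+
+    Example (illustrative; `A` and `B` stand for any input arrays, as in the
+    module docstring above):
+
+        >>> @fused
+        ... def mean_pair(x, y):
+        ...     return (x + y) / 2
+        >>> result = mean_pair(A, B)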
""" + if ctx is None: + ctx = get_default_scheduler() + def wrapper(*args): return fuse(f, *args, ctx=ctx) - return wrapper \ No newline at end of file + + return wrapper diff --git a/src/finch/interface/lazy.py b/src/finch/interface/lazy.py index cd97a645..610dae97 100644 --- a/src/finch/interface/lazy.py +++ b/src/finch/interface/lazy.py @@ -1,79 +1,124 @@ -import operator import builtins +import itertools +import operator from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Tuple, Iterable -from itertools import accumulate -from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple -from ..algebra import * - -from ..finch_logic import * +from typing import Any + +from numpy.core.numeric import normalize_axis_tuple + +from ..algebra import ( + element_type, + fill_value, + fixpoint_type, + init_value, + register_property, + return_type, +) +from ..finch_logic import ( + Aggregate, + Alias, + Field, + Immediate, + LogicNode, + MapJoin, + Plan, + Produces, + Query, + Relabel, + Reorder, + Subquery, + Table, +) +from ..finch_logic.interpreter import FinchLogicInterpreter from ..symbolic import gensym +__all__ = [ + "LazyTensor", + "lazy", + "get_default_scheduler", + "compute", + "permute_dims", + "identify", + "expand_dims", + "squeeze", + "reduce", + "elementwise", + "prod", + "multiply", +] + + @dataclass class LazyTensor: data: LogicNode - shape: Tuple - fill_value: Any + shape: tuple[int, ...] + fill_value: Immediate element_type: Any @property def ndim(self) -> int: return len(self.shape) + def lazy(arr) -> LazyTensor: """ - - lazy(arr) -> LazyTensor: - Converts an array into a LazyTensor. If the input is already a LazyTensor, it is returned as-is. - Otherwise, it creates a LazyTensor representation of the input array. + - lazy(arr) -> LazyTensor: + Converts an array into a LazyTensor. If the input is already a LazyTensor, it is + returned as-is. Otherwise, it creates a LazyTensor representation of the input + array. - Parameters: - - arr: The input array to be converted into a LazyTensor. + Parameters: + - arr: The input array to be converted into a LazyTensor. - Returns: - - LazyTensor: A lazy representation of the input array. + Returns: + - LazyTensor: A lazy representation of the input array. """ if isinstance(arr, LazyTensor): return arr name = Alias(gensym("A")) - idxs = [Field(gensym("i")) for _ in range(arr.ndim)] + idxs = tuple(Field(gensym("i")) for _ in range(arr.ndim)) shape = tuple(arr.shape) tns = Subquery(name, Table(Immediate(arr), idxs)) return LazyTensor(tns, shape, fill_value(arr), element_type(arr)) + def get_default_scheduler(): return FinchLogicInterpreter() -def compute(arg, ctx=get_default_scheduler()): + +def compute(arg, ctx=None): """ - compute(arg, ctx=get_default_scheduler()): - Executes a fused operation represented by LazyTensors. This function evaluates the entire - operation in an optimized manner using the provided scheduler. + Executes a fused operation represented by LazyTensors. This function evaluates the + entire operation in an optimized manner using the provided scheduler. Parameters: - - arg: A lazy tensor or a tuple of lazy tensors representing the fused operation to be computed. - - ctx: The scheduler to use for computation. Defaults to the result of `get_default_scheduler()`. + - arg: A lazy tensor or a tuple of lazy tensors representing the fused operation to + be computed. + - ctx: The scheduler to use for computation. Defaults to the result of + `get_default_scheduler()`. 
Returns: - A tensor or a list of tensors computed by the fused operation. """ - if isinstance(arg, tuple): - args = arg - else: - args = (arg,) + if ctx is None: + ctx = get_default_scheduler() + args = arg if isinstance(arg, tuple) else (arg,) vars = tuple(Alias(gensym("A")) for _ in args) bodies = tuple(map(lambda arg, var: Query(var, arg.data), args, vars)) prgm = Plan(bodies + (Produces(vars),)) res = ctx(prgm) if isinstance(arg, tuple): return tuple(res) - else: - return res[0] + return res[0] + register_property(LazyTensor, "__self__", "fill_value", lambda x: x.fill_value) register_property(LazyTensor, "__self__", "element_type", lambda x: x.element_type) -def permute_dims(arg: LazyTensor, /, axis: Tuple[int, ...]) -> LazyTensor: + +def permute_dims(arg: LazyTensor, /, axis: tuple[int, ...]) -> LazyTensor: """ Permutes the axes (dimensions) of an array ``x``. @@ -82,40 +127,51 @@ def permute_dims(arg: LazyTensor, /, axis: Tuple[int, ...]) -> LazyTensor: x: array input array. axes: Tuple[int, ...] - tuple containing a permutation of ``(0, 1, ..., N-1)`` where ``N`` is the number of axes (dimensions) of ``x``. + tuple containing a permutation of ``(0, 1, ..., N-1)`` where ``N`` is the number + of axes (dimensions) of ``x``. Returns ------- out: array - an array containing the axes permutation. The returned array must have the same data type as ``x``. + an array containing the axes permutation. The returned array must have the same + data type as ``x``. """ axis = normalize_axis_tuple(axis, arg.ndim + len(axis)) - idxs = [Field(gensym("i")) for _ in range(arg.ndim)] + idxs = tuple(Field(gensym("i")) for _ in range(arg.ndim)) return LazyTensor( - Reorder(Relabel(arg.data, idxs), [idxs[i] for i in axis]), - [arg.shape[i] for i in axis], + Reorder(Relabel(arg.data, idxs), tuple(idxs[i] for i in axis)), + tuple(arg.shape[i] for i in axis), arg.fill_value, arg.element_type, ) + def identify(data): lhs = Alias(gensym("A")) return Subquery(lhs, data) + def expand_dims( x: LazyTensor, /, axis: int | tuple[int, ...] = 0, ) -> LazyTensor: """ - Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``. + Expands the shape of an array by inserting a new axis (dimension) of size one at the + position specified by ``axis``. Parameters ---------- x: array input array. axis: int - axis position (zero-based). If ``x`` has rank (i.e, number of dimensions) ``N``, a valid ``axis`` must reside on the closed-interval ``[-N-1, N]``. If provided a negative ``axis``, the axis position at which to insert a singleton dimension must be computed as ``N + axis + 1``. Hence, if provided ``-1``, the resolved axis position must be ``N`` (i.e., a singleton dimension must be appended to the input array ``x``). If provided ``-N-1``, the resolved axis position must be ``0`` (i.e., a singleton dimension must be prepended to the input array ``x``). + axis position (zero-based). If ``x`` has rank (i.e, number of dimensions) ``N``, + a valid ``axis`` must reside on the closed-interval ``[-N-1, N]``. If provided a + negative ``axis``, the axis position at which to insert a singleton dimension + must be computed as ``N + axis + 1``. Hence, if provided ``-1``, the resolved + axis position must be ``N`` (i.e., a singleton dimension must be appended to the + input array ``x``). If provided ``-N-1``, the resolved axis position must be + ``0`` (i.e., a singleton dimension must be prepended to the input array ``x``). 
Returns ------- @@ -130,24 +186,25 @@ def expand_dims( if isinstance(axis, int): axis = (axis,) axis = normalize_axis_tuple(axis, x.ndim + len(axis)) + assert isinstance(axis, tuple) assert len(axis) == len(set(axis)), "axis must be unique" assert set(axis).issubset(range(x.ndim + len(axis))), "Invalid axis" offset = [0] * (x.ndim + len(axis)) for d in axis: offset[d] = 1 - offset = list(accumulate(offset)) - idxs_1 = [Field(gensym("i")) for _ in range(x.ndim)] - idxs_2 = [ + offset = list(itertools.accumulate(offset)) + idxs_1 = tuple(Field(gensym("i")) for _ in range(x.ndim)) + idxs_2 = tuple( Field(gensym("i")) if n in axis else idxs_1[n - offset[n]] for n in range(x.ndim + len(axis)) - ] + ) data_2 = Reorder(Relabel(x.data, idxs_1), idxs_2) shape_2 = tuple( - 1 if n in axis else x.shape[n - offset[n]] - for n in range(x.ndim + len(axis)) + 1 if n in axis else x.shape[n - offset[n]] for n in range(x.ndim + len(axis)) ) return LazyTensor(data_2, shape_2, x.fill_value, x.element_type) + def squeeze( x: LazyTensor, /, @@ -177,25 +234,28 @@ def squeeze( if isinstance(axis, int): axis = (axis,) axis = normalize_axis_tuple(axis, x.ndim) + assert isinstance(axis, tuple) assert len(axis) == len(set(axis)), "axis must be unique" assert set(axis).issubset(range(x.ndim)), "Invalid axis" assert all(x.shape[d] == 1 for d in axis), "axis to drop must have size 1" newaxis = [n for n in range(x.ndim) if n not in axis] - idxs_1 = [Field(gensym("i")) for _ in range(x.ndim)] - idxs_2 = [idxs_1[n] for n in newaxis] + idxs_1 = tuple(Field(gensym("i")) for _ in range(x.ndim)) + idxs_2 = tuple(idxs_1[n] for n in newaxis) data_2 = Reorder(Relabel(x.data, idxs_1), idxs_2) shape_2 = tuple(x.shape[n] for n in newaxis) return LazyTensor(data_2, shape_2, x.fill_value, x.element_type) + def reduce( op: Callable, x: LazyTensor, /, *, axis: int | tuple[int, ...] | None = None, - dtype = None, + dtype=None, keepdims: bool = False, - init = None): + init=None, +): """ Reduces the input array ``x`` with the binary operator ``op``. Reduces along the specified `axis`, with an initial value `init`. @@ -205,36 +265,57 @@ def reduce( x: array input array. Should have a numeric data type. axis: Optional[Union[int, Tuple[int, ...]]] - axis or axes along which reduction must be computed. By default, the reduction must be computed over the entire array. If a tuple of integers, reductions must be computed over multiple axes. Default: ``None``. + axis or axes along which reduction must be computed. By default, the reduction + must be computed over the entire array. If a tuple of integers, reductions must + be computed over multiple axes. Default: ``None``. dtype: Optional[dtype] - data type of the returned array. If ``None``, a suitable data type will be calculated. + data type of the returned array. If ``None``, a suitable data type will be + calculated. keepdims: bool - if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions, and, accordingly, the result must be compatible with the input array (see :ref:`broadcasting`). Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result. Default: ``False``. - + if ``True``, the reduced axes (dimensions) must be included in the result as + singleton dimensions, and, accordingly, the result must be compatible with the + input array (see :ref:`broadcasting`). Otherwise, if ``False``, the reduced axes + (dimensions) must not be included in the result. Default: ``False``. 
+ init: Optional - Initial value for the reduction. If ``None``, a suitable initial value will be calculated. The initial value must be compatible with the operation defined by ``op``. For example, if ``op`` is addition, the initial value should be zero; if ``op`` is multiplication, the initial value should be one. + Initial value for the reduction. If ``None``, a suitable initial value will be + calculated. The initial value must be compatible with the operation defined by + ``op``. For example, if ``op`` is addition, the initial value should be zero; if + ``op`` is multiplication, the initial value should be one. Returns ------- out: array - If the reduction was computed over the entire array, a zero-dimensional array containing the reduction; otherwise, a non-zero-dimensional array containing the reduction. The returned array must have a data type as described by the ``dtype`` parameter above. + If the reduction was computed over the entire array, a zero-dimensional array + containing the reduction; otherwise, a non-zero-dimensional array containing the + reduction. The returned array must have a data type as described by the + ``dtype`` parameter above. """ if init is None: init = init_value(op, x.element_type) axis = normalize_axis_tuple(axis, x.ndim) + assert isinstance(axis, tuple) shape = tuple(x.shape[n] for n in range(x.ndim) if n not in axis) - fields = [Field(gensym("i")) for _ in range(x.ndim)] - data = Aggregate(Immediate(op), Immediate(init), Relabel(x.data, fields), [fields[i] for i in axis]) + fields = tuple(Field(gensym("i")) for _ in range(x.ndim)) + data: Aggregate | Reorder = Aggregate( + Immediate(op), + Immediate(init), + Relabel(x.data, fields), + tuple(fields[i] for i in axis), + ) if keepdims: - keeps = [fields[i] if i in axis else Field(gensym("j")) for i in range(x.ndim)] + keeps = tuple( + Field(gensym("j")) if i in axis else fields[i] for i in range(x.ndim) + ) data = Reorder(data, keeps) - shape = [shape[i] if i in axis else 1 for i in range(x.ndim)] + shape = tuple(1 if i in axis else x.shape[i] for i in range(x.ndim)) if dtype is None: dtype = fixpoint_type(op, init, x.element_type) return LazyTensor(identify(data), shape, init, dtype) + def elementwise(f: Callable, *args) -> LazyTensor: """ elementwise(f, *args) -> LazyTensor: @@ -252,16 +333,21 @@ def elementwise(f: Callable, *args) -> LazyTensor: - f: The function to apply elementwise. - *args: The tensors to apply the function to. These tensors should be compatible for broadcasting. - + Returns: - LazyTensor: The tensor, `out`, of results from applying `f` elementwise to - the input tensors. After broadcasting the arguments to the same shape, for + the input tensors. After broadcasting the arguments to the same shape, for each index `i`, `out[*i] = f(args[0][*i], args[1][*i], ...)`. 
""" largs = list(map(lazy, args)) ndim = builtins.max([arg.ndim for arg in largs]) shape = tuple( - builtins.max([arg.shape[i - ndim + arg.ndim] if i - ndim + arg.ndim >= 0 else 1 for arg in largs]) + builtins.max( + [ + arg.shape[i - ndim + arg.ndim] if i - ndim + arg.ndim >= 0 else 1 + for arg in largs + ] + ) for i in range(ndim) ) idxs = [Field(gensym("i")) for _ in range(ndim)] @@ -269,7 +355,7 @@ def elementwise(f: Callable, *args) -> LazyTensor: for arg in largs: idims = [] odims = [] - for i in range(ndim - arg.ndim,ndim): + for i in range(ndim - arg.ndim, ndim): if arg.shape[i - ndim + arg.ndim] == shape[i]: idims.append(idxs[i]) odims.append(idxs[i]) @@ -283,7 +369,8 @@ def elementwise(f: Callable, *args) -> LazyTensor: new_element_type = return_type(f, *[x.element_type for x in largs]) return LazyTensor(identify(data), shape, new_fill_value, new_element_type) -def prod(arr: LazyTensor, dims) -> LazyTensor: + +def prod(arr: LazyTensor, dims: tuple[int, ...]) -> LazyTensor: """ Calculates the product of input array ``x`` elements. @@ -292,23 +379,42 @@ def prod(arr: LazyTensor, dims) -> LazyTensor: x: array input array. Should have a numeric data type. axis: Optional[Union[int, Tuple[int, ...]]] - axis or axes along which products must be computed. By default, the product must be computed over the entire array. If a tuple of integers, products must be computed over multiple axes. Default: ``None``. + axis or axes along which products must be computed. By default, the product must + be computed over the entire array. If a tuple of integers, products must be + computed over multiple axes. Default: ``None``. dtype: Optional[dtype] - data type of the returned array. If ``None``, the returned array must have the same data type as ``x``, unless ``x`` has an integer data type supporting a smaller range of values than the default integer data type (e.g., ``x`` has an ``int16`` or ``uint32`` data type and the default integer data type is ``int64``). In those latter cases: - - - if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type. - - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type). - - If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum (rationale: the ``dtype`` keyword argument is intended to help prevent overflows). Default: ``None``. + data type of the returned array. If ``None``, the returned array must have the + same data type as ``x``, unless ``x`` has an integer data type supporting a + smaller range of values than the default integer data type (e.g., ``x`` has an + ``int16`` or ``uint32`` data type and the default integer data type is + ``int64``). In those latter cases: + + - if ``x`` has a signed integer data type (e.g., ``int16``), the returned + array must have the default integer data type. + - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned + array must have an unsigned integer data type having the same number of bits + as the default integer data type (e.g., if the default integer data type is + ``int32``, the returned array must have a ``uint32`` data type). 
+ + If the data type (either specified or resolved) differs from the data type of + ``x``, the input array should be cast to the specified data type before + computing the product (rationale: the ``dtype`` keyword argument is intended to + help prevent overflows). Default: ``None``. keepdims: bool - if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions, and, accordingly, the result must be compatible with the input array (see :ref:`broadcasting`). Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result. Default: ``False``. + if ``True``, the reduced axes (dimensions) must be included in the result as + singleton dimensions, and, accordingly, the result must be compatible with the + input array (see :ref:`broadcasting`). Otherwise, if ``False``, the reduced axes + (dimensions) must not be included in the result. Default: ``False``. Returns ------- out: array - if the product was computed over the entire array, a zero-dimensional array containing the product; otherwise, a non-zero-dimensional array containing the products. The returned array must have a data type as described by the ``dtype`` parameter above. + if the product was computed over the entire array, a zero-dimensional array + containing the product; otherwise, a non-zero-dimensional array containing the + products. The returned array must have a data type as described by the ``dtype`` + parameter above. Notes ----- @@ -319,15 +425,12 @@ def prod(arr: LazyTensor, dims) -> LazyTensor: - If ``N`` is ``0``, the product is `1` (i.e., the empty product). - For both real-valued and complex floating-point operands, special cases must be handled as if the operation is implemented by successive application of :func:`~array_api.multiply`. - - .. versionchanged:: 2022.12 - Added complex data type support. - - .. versionchanged:: 2023.12 - Required the function to return a floating-point array having the same data type as the input array when provided a floating-point array. + For both real-valued and complex floating-point operands, special cases must be + handled as if the operation is implemented by successive application of + :func:`~array_api.multiply`. """ - return reduce(operator.mul, arr, dims, arr.fill_value) + return reduce(operator.mul, arr, axis=dims) + def multiply(x1: LazyTensor, x2: LazyTensor) -> LazyTensor: - return elementwise(operator.mul, x1, x2) \ No newline at end of file + return elementwise(operator.mul, x1, x2) diff --git a/src/finch/interface/tensor.py b/src/finch/interface/tensor.py index 8a594dcb..2df208e0 100644 --- a/src/finch/interface/tensor.py +++ b/src/finch/interface/tensor.py @@ -1,22 +1,21 @@ from abc import ABC, abstractmethod + from . 
import lazy -from .fuse import * +from .fuse import compute + class EagerTensor(ABC): @abstractmethod def shape(self): """Return the shape of the tensor.""" - pass @abstractmethod def dtype(self): """Return the data type of the tensor.""" - pass @abstractmethod def to_numpy(self): """Convert the tensor to a NumPy array.""" - pass @abstractmethod def __add__(self, other): @@ -25,10 +24,9 @@ def __add__(self, other): @abstractmethod def __mul__(self, other): """Define multiplication for tensors.""" - pass + def prod(arr, /, axis=None): if arr.is_lazy(): return lazy.prod(arr, axis=axis) - else: - return compute(lazy.prod(lazy.lazy(arr), axis=axis)) + return compute(lazy.prod(lazy.lazy(arr), axis=axis)) diff --git a/src/finch/symbolic/__init__.py b/src/finch/symbolic/__init__.py index 4eeeba1e..e04e3b5e 100644 --- a/src/finch/symbolic/__init__.py +++ b/src/finch/symbolic/__init__.py @@ -1,14 +1,14 @@ -from .rewriters import * -from .term import * -from .gensym import * +from .gensym import gensym +from .rewriters import Chain, PostWalk, PreWalk, Rewrite +from .term import PostOrderDFS, PreOrderDFS, Term __all__ = [ "PostOrderDFS", "PreOrderDFS", "Term", - "Rewriter", + "Rewrite", "PreWalk", "PostWalk", "Chain", - "gensym" -] \ No newline at end of file + "gensym", +] diff --git a/src/finch/symbolic/gensym.py b/src/finch/symbolic/gensym.py index 843b7612..353bfd06 100644 --- a/src/finch/symbolic/gensym.py +++ b/src/finch/symbolic/gensym.py @@ -1,4 +1,7 @@ -from typing import Callable +from collections.abc import Callable + +__all__ = ["gensym"] + class SymbolGenerator: counter: int = 0 @@ -9,5 +12,6 @@ def gensym(cls, name: str) -> str: cls.counter += 1 return sym + _sg = SymbolGenerator() -gensym: Callable[[str], str] = _sg.gensym \ No newline at end of file +gensym: Callable[[str], str] = _sg.gensym diff --git a/src/finch/symbolic/rewriters.py b/src/finch/symbolic/rewriters.py index 190ebab8..219ff7fc 100644 --- a/src/finch/symbolic/rewriters.py +++ b/src/finch/symbolic/rewriters.py @@ -1,8 +1,8 @@ """ This module provides a set of classes and utilities for symbolic term rewriting. Rewriters transform terms based on specific rules. A rewriter is any callable -which takes a Term and returns a Term or `None`. A rewriter can return `None` -if there are no changes applicable to the input Term. The module includes +which takes a term and returns a term or `None`. A rewriter can return `None` +if there are no changes applicable to the input term. The module includes various strategies for applying rewriters, such as recursive rewriting, chaining multiple rewriters, and caching results. @@ -20,13 +20,31 @@ Memo: Caches the results of a rewriter to avoid redundant computations. 
""" +from __future__ import annotations + from collections.abc import Callable, Iterable +from typing import TypeVar + from .term import Term -RwCallable = Callable[[Term], Term | None] +T = TypeVar("T", bound="Term") +U = TypeVar("U", bound="Term") + +RwCallable = Callable[[T], T | None] + +__all__ = [ + "default_rewrite", + "Rewrite", + "PreWalk", + "PostWalk", + "Chain", + "Fixpoint", + "Prestep", + "Memo", +] -def default_rewrite(x: Term | None, y: Term) -> Term: +def default_rewrite(x: T | None, y: U) -> T | U: return x if x is not None else y @@ -41,7 +59,7 @@ class Rewrite: def __init__(self, rw: RwCallable): self.rw = rw - def __call__(self, x: Term) -> Term: + def __call__(self, x: T) -> T: return default_rewrite(self.rw(x), x) @@ -58,24 +76,25 @@ class PreWalk: def __init__(self, rw: RwCallable): self.rw = rw - def __call__(self, x: Term) -> Term | None: + def __call__(self, x: T) -> T | None: + from ..finch_logic import LogicExpression + y = self.rw(x) - if y is not None: - if y.is_expr(): - args = y.children() - return y.make_term( - y.head(), *[default_rewrite(self(arg), arg) for arg in args] - ) - return y - if x.is_expr(): - args = x.children() - new_args = list(map(self, args)) - if not all(arg is None for arg in new_args): - return x.make_term( - x.head(), - *map(lambda x1, x2: default_rewrite(x1, x2), new_args, args), + match y: + case LogicExpression() as expr: + args = expr.children() + return expr.make_term( # type: ignore[return-value] + expr.head(), *tuple(default_rewrite(self(arg), arg) for arg in args) ) - return None + match x: + case LogicExpression() as expr: + args = expr.children() + new_args = list(map(self, args)) + if not all(arg is None for arg in args): + return expr.make_term( # type: ignore[return-value] + expr.head(), + *map(lambda x1, x2: default_rewrite(x1, x2), new_args, args), + ) return None @@ -92,16 +111,20 @@ class PostWalk: def __init__(self, rw: RwCallable): self.rw = rw - def __call__(self, x: Term) -> Term | None: - if x.is_expr(): - args = x.children() - new_args = list(map(self, args)) - if all(arg is None for arg in new_args): - return self.rw(x) - y = x.make_term( - x.head(), *map(lambda x1, x2: default_rewrite(x1, x2), new_args, args) - ) - return default_rewrite(self.rw(y), y) + def __call__(self, x: T) -> T | None: + from ..finch_logic import LogicExpression + + match x: + case LogicExpression() as expr: + args = expr.children() + new_args = list(map(self, args)) + if all(arg is None for arg in new_args): + return self.rw(expr) + y = expr.make_term( + expr.head(), + *map(lambda x1, x2: default_rewrite(x1, x2), new_args, args), + ) + return default_rewrite(self.rw(y), y) # type: ignore[return-value] return self.rw(x) @@ -117,7 +140,7 @@ class Chain: def __init__(self, rws: Iterable[RwCallable]): self.rws = rws - def __call__(self, x: Term) -> Term | None: + def __call__(self, x: T) -> T | None: is_success = False for rw in self.rws: y = rw(x) @@ -141,15 +164,12 @@ class Fixpoint: def __init__(self, rw: RwCallable): self.rw = rw - def __call__(self, x: Term) -> Term | None: + def __call__(self, x: T) -> T | None: y = self.rw(x) - if y is not None: - while y is not None and x != y: - x = y - y = self.rw(x) - return x - else: - return None + while y is not None and x != y: + x = y + y = self.rw(x) + return y class Prestep: @@ -164,16 +184,17 @@ class Prestep: def __init__(self, rw: RwCallable): self.rw = rw - def __call__(self, x: Term) -> Term | None: + def __call__(self, x: T) -> T | None: + from ..finch_logic import LogicExpression + 
y = self.rw(x) - if y is not None: - if y.is_expr(): - y_args = y.children() - return y.make_term( - y.head(), *[default_rewrite(self(arg), arg) for arg in y_args] + match y: + case LogicExpression() as expr: + args = expr.children() + return expr.make_term( # type: ignore[return-value] + expr.head(), *(default_rewrite(self(arg), arg) for arg in args) ) - return y - return None + return y class Memo: """ @@ -186,11 +207,11 @@ cache (dict): A dictionary to store cached results. """ - def __init__(self, rw: RwCallable, cache: dict = None): + def __init__(self, rw: RwCallable, cache: dict | None = None): self.rw = rw self.cache = cache if cache is not None else {} - def __call__(self, x: Term) -> Term | None: + def __call__(self, x: T) -> T | None: if x not in self.cache: self.cache[x] = self.rw(x) return self.cache[x] diff --git a/src/finch/symbolic/term.py b/src/finch/symbolic/term.py index ef4d34c5..4e6bd98e 100644 --- a/src/finch/symbolic/term.py +++ b/src/finch/symbolic/term.py @@ -1,41 +1,44 @@ -from typing import Any, Iterator -from abc import ABC, abstractmethod - """ -This module contains definitions for common functions that are useful for symbolic expression manipulation. -Its purpose is to provide a shared interface between various symbolic programming in Finch. +This module contains definitions for common functions that are useful for symbolic +expression manipulation. Its purpose is to provide a shared interface between the +various symbolic programming components in Finch. Classes: - Term (ABC): An abstract base class representing a symbolic term. It provides methods to access the head - of the term, its children, and to construct a new term with a similar structure. + Term (ABC): An abstract base class representing a symbolic term. It provides methods + to access the head of the term, its children, and to construct a new term with a + similar structure. """ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterator +from typing import Any, Self + +__all__ = ["Term", "PreOrderDFS", "PostOrderDFS"] + class Term(ABC): def __init__(self): self._hashcache = None # Private field to cache the hash value - @abstractmethod - def head(self) -> Any: + @classmethod + def head(cls) -> Callable[..., Self]: """Return the head type of the S-expression.""" - pass - - def children(self) -> list["Term"]: - """Return the children (AKA tail) of the S-expression.""" - pass + raise NotImplementedError @abstractmethod - def is_expr(self) -> bool: - """Return True if the term is an expression tree, False otherwise. Must implement children() if True.""" - pass + def children(self) -> list[Any]: + """Return the children (AKA tail) of the S-expression.""" - @abstractmethod - def make_term(self, head: Any, children: list["Term"]) -> "Term": + @classmethod + def make_term(cls, head: Callable[..., Self], *children: Term) -> Self: """ - Construct a new term in the same family of terms with the given head type and children. - This function should satisfy `x == x.make_term(x.head(), *x.children())` + Construct a new term in the same family of terms with the given head and + children. 
This function should satisfy + `x == x.make_term(x.head(), *x.children())` """ - pass + raise NotImplementedError def __hash__(self) -> int: """Return the hash value of the term.""" @@ -45,19 +48,27 @@ def __hash__(self) -> int: ) return self._hashcache - def __eq__(self, other: "Term") -> bool: - self.head() == other.head() and self.children() == other.children() + def __eq__(self, other: object) -> bool: + if not isinstance(other, Term): + return NotImplemented + return self.head() is other.head() and self.children() == other.children() def PostOrderDFS(node: Term) -> Iterator[Term]: - if node.is_expr(): - for arg in node.children(): - yield from PostOrderDFS(arg) + from ..finch_logic import LogicExpression + + match node: + case LogicExpression() as expr: + for arg in expr.children(): + yield from PostOrderDFS(arg) yield node def PreOrderDFS(node: Term) -> Iterator[Term]: yield node - if node.is_expr(): - for arg in node.children(): - yield from PreOrderDFS(arg) + from ..finch_logic import LogicExpression + + match node: + case LogicExpression() as expr: + for arg in expr.children(): + yield from PreOrderDFS(arg) diff --git a/src/finch/tensor/tensor.py b/src/finch/tensor/tensor.py index c36df69a..3dc1ba45 100644 --- a/src/finch/tensor/tensor.py +++ b/src/finch/tensor/tensor.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod + class AbstractTensor(ABC): @abstractmethod def shape(self): @@ -21,10 +22,13 @@ def __add__(self, other): def __mul__(self, other): pass + def fill_value(arg): - if isinstance(arg, LazyTensor): + from ..finch_logic import Immediate + from ..interface.lazy import LazyTensor + + if isinstance(arg, LazyTensor | Immediate): return arg.fill_value - elif isinstance(arg, (int, float)): + if isinstance(arg, int | float | bool | complex): return arg - else: - raise ValueError("Unsupported type for fill_value") \ No newline at end of file + raise TypeError("Unsupported type for fill_value") diff --git a/tests/test_interface.py b/tests/test_interface.py index 55ebac34..84dff39a 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -1,8 +1,12 @@ +from operator import add, mul + import numpy as np from numpy.testing import assert_equal + import pytest + import finch -from operator import add, mul + @pytest.mark.parametrize( "a, b", @@ -12,8 +16,14 @@ ], ) def test_matrix_multiplication(a, b): - result = finch.fuse(lambda a, b: finch.reduce(add, finch.elementwise(mul, finch.expand_dims(a, 2), b), axis=1), a, b) + result = finch.fuse( + lambda a, b: finch.reduce( + add, finch.elementwise(mul, finch.expand_dims(a, 2), b), axis=1 + ), + a, + b, + ) expected = np.matmul(a, b) - - assert_equal(result, expected) \ No newline at end of file + + assert_equal(result, expected) diff --git a/tests/test_logic_interpreter.py b/tests/test_logic_interpreter.py index ae4747d3..b54416a2 100644 --- a/tests/test_logic_interpreter.py +++ b/tests/test_logic_interpreter.py @@ -1,8 +1,24 @@ +from operator import add, mul + import numpy as np from numpy.testing import assert_equal + import pytest -from finch.finch_logic import * -from operator import add, mul + +from finch.finch_logic import ( + Aggregate, + Alias, + Field, + Immediate, + MapJoin, + Plan, + Produces, + Query, + Reorder, + Table, +) +from finch.finch_logic.interpreter import FinchLogicInterpreter + @pytest.mark.parametrize( "a, b", @@ -16,16 +32,23 @@ def test_matrix_multiplication(a, b): j = Field("j") k = Field("k") - p = Plan([ - Query(Alias("A"), Table(Immediate(a), (i, k))), - Query(Alias("B"), 
Table(Immediate(b), (k, j))), - Query(Alias("AB"), MapJoin(Immediate(mul), (Alias("A"), Alias("B")))), - Query(Alias("C"), Reorder(Aggregate(Immediate(add), Immediate(0), Alias("AB"), (k,)), (i, j))), - Produces((Alias("C"),)), - ]) + p = Plan( + [ + Query(Alias("A"), Table(Immediate(a), (i, k))), + Query(Alias("B"), Table(Immediate(b), (k, j))), + Query(Alias("AB"), MapJoin(Immediate(mul), (Alias("A"), Alias("B")))), + Query( + Alias("C"), + Reorder( + Aggregate(Immediate(add), Immediate(0), Alias("AB"), (k,)), (i, j) + ), + ), + Produces((Alias("C"),)), + ] + ) result = FinchLogicInterpreter()(p)[0] expected = np.matmul(a, b) - - assert_equal(result, expected) \ No newline at end of file + + assert_equal(result, expected) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 881dcf5e..8ce894de 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,15 +1,13 @@ -from finch.autoschedule import ( - propagate_fields, propagate_map_queries, lift_subqueries -) +from finch.autoschedule import lift_subqueries, propagate_fields, propagate_map_queries from finch.finch_logic import ( - Plan, - Query, + Aggregate, Alias, Field, - Aggregate, Immediate, MapJoin, + Plan, Produces, + Query, Relabel, Subquery, Table, @@ -99,33 +97,37 @@ def test_lift_subqueries(): def test_propagate_fields(): - plan = Plan(( - Query( - Alias("A10"), - MapJoin( - Immediate("op"), - ( - Table(Immediate("tbl1"), (Field("A1"), Field("A2"))), - Table(Immediate("tbl2"), (Field("A2"), Field("A3"))), + plan = Plan( + ( + Query( + Alias("A10"), + MapJoin( + Immediate("op"), + ( + Table(Immediate("tbl1"), (Field("A1"), Field("A2"))), + Table(Immediate("tbl2"), (Field("A2"), Field("A3"))), + ), ), ), - ), - Alias("A10"), - )) - - expected = Plan(( - Query( Alias("A10"), - MapJoin( - Immediate("op"), - ( - Table(Immediate("tbl1"), (Field("A1"), Field("A2"))), - Table(Immediate("tbl2"), (Field("A2"), Field("A3"))), + ) + ) + + expected = Plan( + ( + Query( + Alias("A10"), + MapJoin( + Immediate("op"), + ( + Table(Immediate("tbl1"), (Field("A1"), Field("A2"))), + Table(Immediate("tbl2"), (Field("A2"), Field("A3"))), + ), ), ), - ), - Relabel(Alias("A10"), (Field("A1"), Field("A2"), Field("A3"))), - )) + Relabel(Alias("A10"), (Field("A1"), Field("A2"), Field("A3"))), + ) + ) result = propagate_fields(plan) assert result == expected
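
Note for reviewers: the rewriter changes above rely on the protocol described in `src/finch/symbolic/rewriters.py`, where a rewriter is any callable that takes a term and returns a rewritten term, or `None` when no rule applies, and walkers such as `PostWalk` thread that convention through a tree bottom-up. Here is a minimal, self-contained sketch of that convention; the toy `Lit`/`Add` terms and `fold_add` rule are hypothetical stand-ins for illustration only, not finch's actual `Term`/`LogicExpression` classes.

```python
from dataclasses import dataclass


# Toy terms for this sketch only; finch's real Term/LogicExpression differ.
@dataclass(frozen=True)
class Lit:
    value: int


@dataclass(frozen=True)
class Add:
    lhs: "Lit | Add"
    rhs: "Lit | Add"


def fold_add(node):
    """A rewriter: fold Add(Lit, Lit) into a Lit, or return None if no rule applies."""
    match node:
        case Add(Lit(a), Lit(b)):
            return Lit(a + b)
    return None


def post_walk(rw, node):
    """Apply rw bottom-up, in the spirit of PostWalk: children first, then the node."""
    match node:
        case Add(lhs, rhs):
            node = Add(post_walk(rw, lhs), post_walk(rw, rhs))
    y = rw(node)
    return y if y is not None else node


print(post_walk(fold_add, Add(Add(Lit(1), Lit(2)), Lit(3))))  # Lit(value=6)
```

Unlike this sketch, the real `PostWalk` returns `None` when no rule fired anywhere in the tree, which is what lets combinators such as `Fixpoint` and `Chain` detect that a pass made no progress.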