Skip to content

Commit 6a225f5

Browse files
committed
Some typing fixes.
1 parent 9bea059 commit 6a225f5

File tree

5 files changed

+69
-57
lines changed

5 files changed

+69
-57
lines changed

src/finch/autoschedule/compiler.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,38 +17,40 @@
1717
)
1818
from ..symbolic import Term
1919

20-
T = TypeVar("T", bound=Term)
20+
T = TypeVar("T", bound="LogicNode")
2121

2222

23-
def get_or_insert(dictionary: dict[str, Term], key: str, default: Any) -> Any:
23+
def get_or_insert(dictionary: dict[str, LogicNode], key: str, default: Any) -> Any:
2424
if key in dictionary:
2525
return dictionary[key]
2626
dictionary[key] = default
2727
return default
2828

2929

30-
def get_structure(node: T, fields: dict[str, Term], aliases: dict[str, Term]) -> T:
30+
def get_structure(
31+
node: LogicNode, fields: dict[str, LogicNode], aliases: dict[str, LogicNode]
32+
) -> LogicNode:
3133
match node:
3234
case Field(name):
3335
return get_or_insert(fields, name, Immediate(len(fields) + len(aliases)))
3436
case Alias(name):
3537
return get_or_insert(aliases, name, Immediate(len(fields) + len(aliases)))
3638
case Subquery(Alias(name) as lhs, arg):
3739
if name in aliases:
38-
return aliases[name] # type: ignore[return-value]
39-
return Subquery( # type: ignore[return-value]
40+
return aliases[name]
41+
return Subquery(
4042
get_structure(lhs, fields, aliases), get_structure(arg, fields, aliases)
4143
)
4244
case Table(tns, idxs):
43-
return Table( # type: ignore[return-value]
45+
assert all(isinstance(idx, Field) for idx in idxs)
46+
return Table(
4447
Immediate(type(tns.val)),
45-
tuple(get_structure(idx, fields, aliases) for idx in idxs),
48+
tuple(get_structure(idx, fields, aliases) for idx in idxs), # type: ignore[misc]
49+
)
50+
case any if any.is_expr():
51+
return any.make_term(
52+
*[get_structure(arg, fields, aliases) for arg in any.children()]
4653
)
47-
# TODO: `is_tree` isn't defined
48-
# case any if any.is_tree():
49-
# return any.from_arguments(
50-
# *[get_structure(arg, fields, aliases) for arg in any.get_arguments()]
51-
# )
5254
case _:
5355
return node
5456

src/finch/autoschedule/optimize.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ..finch_logic import (
22
Aggregate,
33
Alias,
4+
LogicNode,
45
MapJoin,
56
Plan,
67
Produces,
@@ -17,7 +18,9 @@ def optimize(prgm: Term) -> Term:
1718
return propagate_map_queries(prgm)
1819

1920

20-
def _lift_subqueries_expr(node: Term, bindings: dict[Term, Term]) -> Term:
21+
def _lift_subqueries_expr(
22+
node: LogicNode, bindings: dict[LogicNode, LogicNode]
23+
) -> LogicNode:
2124
match node:
2225
case Subquery(lhs, arg):
2326
if lhs not in bindings:
@@ -26,19 +29,18 @@ def _lift_subqueries_expr(node: Term, bindings: dict[Term, Term]) -> Term:
2629
return lhs
2730
case any if any.is_expr():
2831
return any.make_term(
29-
any.head(),
30-
[_lift_subqueries_expr(x, bindings) for x in any.children()],
32+
*tuple(_lift_subqueries_expr(x, bindings) for x in any.children()),
3133
)
3234
case _:
3335
return node
3436

3537

36-
def lift_subqueries(node: Term) -> Term:
38+
def lift_subqueries(node: Term) -> LogicNode:
3739
match node:
3840
case Plan(bodies):
3941
return Plan(tuple(map(lift_subqueries, bodies)))
4042
case Query(lhs, rhs):
41-
bindings: dict[Term, Term] = {}
43+
bindings: dict[LogicNode, LogicNode] = {}
4244
rhs_2 = _lift_subqueries_expr(rhs, bindings)
4345
return Plan(
4446
(*[Query(lhs, rhs) for lhs, rhs in bindings.items()], Query(lhs, rhs_2))
@@ -91,6 +93,6 @@ class DefaultLogicOptimizer:
9193
def __init__(self, ctx: LogicCompiler):
9294
self.ctx = ctx
9395

94-
def __call__(self, prgm: Term):
96+
def __call__(self, prgm: Term) -> str:
9597
prgm = optimize(prgm)
9698
return self.ctx(prgm)

src/finch/finch_logic/nodes.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@
22

33
from abc import abstractmethod
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Any
5+
from typing import Any, Self, TypeVar
66

77
from ..symbolic import Term
88

9-
if TYPE_CHECKING:
10-
pass
11-
129
__all__ = [
1310
"LogicNode",
1411
"Immediate",
@@ -28,6 +25,9 @@
2825
]
2926

3027

28+
T = TypeVar("T", bound="LogicNode")
29+
30+
3131
@dataclass(eq=True, frozen=True)
3232
class LogicNode(Term):
3333
"""
@@ -57,14 +57,14 @@ def head(cls):
5757
"""Returns the head of the node."""
5858
return cls
5959

60-
def children(self) -> list[Term]:
60+
@abstractmethod
61+
def children(self) -> list[LogicNode]:
6162
"""Returns the children of the node."""
62-
raise Exception(f"`children` isn't supported for {self.__class__}.")
6363

6464
@classmethod
65-
def make_term(cls, head, args):
65+
def make_term(cls, *args: Term) -> Self:
6666
"""Creates a term with the given head and arguments."""
67-
return head(*args)
67+
return cls(*args)
6868

6969

7070
@dataclass(eq=True, frozen=True)
@@ -94,6 +94,9 @@ def fill_value(self):
9494

9595
return fill_value(self)
9696

97+
def children(self):
98+
raise TypeError(f"`{type(self).__name__}` doesn't support `.children()`.")
99+
97100

98101
@dataclass(eq=True, frozen=True)
99102
class Deferred(LogicNode):
@@ -222,7 +225,7 @@ class MapJoin(LogicNode):
222225
"""
223226

224227
op: Immediate
225-
args: tuple[Term, ...]
228+
args: tuple[LogicNode, ...]
226229

227230
@staticmethod
228231
def is_expr():
@@ -239,8 +242,8 @@ def children(self):
239242
return [self.op, *self.args]
240243

241244
@classmethod
242-
def make_term(cls, head, args):
243-
return head(args[0], tuple(args[1:]))
245+
def make_term(cls, op: Immediate, *args: LogicNode) -> Self: # type: ignore[override]
246+
return cls(op, tuple(args))
244247

245248

246249
@dataclass(eq=True, frozen=True)
@@ -258,7 +261,7 @@ class Aggregate(LogicNode):
258261

259262
op: Immediate
260263
init: Immediate
261-
arg: Term
264+
arg: LogicNode
262265
idxs: tuple[Field, ...]
263266

264267
@staticmethod
@@ -288,7 +291,7 @@ class Reorder(LogicNode):
288291
idxs: The new order of dimensions.
289292
"""
290293

291-
arg: Term
294+
arg: LogicNode
292295
idxs: tuple[Field, ...]
293296

294297
@staticmethod
@@ -317,7 +320,7 @@ class Relabel(LogicNode):
317320
idxs: The new labels for dimensions.
318321
"""
319322

320-
arg: Term
323+
arg: LogicNode
321324
idxs: tuple[Field, ...]
322325

323326
@staticmethod
@@ -346,7 +349,7 @@ class Reformat(LogicNode):
346349
"""
347350

348351
tns: Immediate
349-
arg: Term
352+
arg: LogicNode
350353

351354
@staticmethod
352355
def is_expr():
@@ -374,8 +377,8 @@ class Subquery(LogicNode):
374377
rhs: The argument to evaluate.
375378
"""
376379

377-
lhs: Term
378-
arg: Term
380+
lhs: LogicNode
381+
arg: LogicNode
379382

380383
@staticmethod
381384
def is_expr():
@@ -403,8 +406,8 @@ class Query(LogicNode):
403406
rhs: The right-hand side to evaluate.
404407
"""
405408

406-
lhs: Term
407-
rhs: Term
409+
lhs: LogicNode
410+
rhs: LogicNode
408411

409412
@staticmethod
410413
def is_expr():
@@ -431,7 +434,7 @@ class Produces(LogicNode):
431434
args: The arguments to return.
432435
"""
433436

434-
args: tuple[Term, ...]
437+
args: tuple[LogicNode, ...]
435438

436439
@staticmethod
437440
def is_expr():
@@ -445,11 +448,11 @@ def is_stateful():
445448

446449
def children(self):
447450
"""Returns the children of the node."""
448-
return list(self.args)
451+
return [*self.args]
449452

450453
@classmethod
451-
def make_term(cls, head, args):
452-
return head(tuple(args))
454+
def make_term(cls, *args: LogicNode) -> Self: # type: ignore[override]
455+
return cls(tuple(args))
453456

454457

455458
@dataclass(eq=True, frozen=True)
@@ -462,7 +465,7 @@ class Plan(LogicNode):
462465
bodies: The sequence of statements to execute.
463466
"""
464467

465-
bodies: tuple[Term, ...] = ()
468+
bodies: tuple[LogicNode, ...] = ()
466469

467470
@staticmethod
468471
def is_expr():
@@ -479,5 +482,5 @@ def children(self):
479482
return tuple(self.bodies)
480483

481484
@classmethod
482-
def make_term(cls, head, val):
483-
return head(tuple(val))
485+
def make_term(cls, *bodies: LogicNode) -> Self: # type: ignore[override]
486+
return cls(tuple(bodies))

src/finch/symbolic/rewriters.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"""
2222

2323
from collections.abc import Callable, Iterable
24+
from typing import TypeVar
2425

2526
from .term import Term
2627

@@ -35,10 +36,13 @@
3536
"Memo",
3637
]
3738

39+
T = TypeVar("T", bound="Term")
40+
U = TypeVar("U", bound="Term")
41+
3842
RwCallable = Callable[[Term], Term | None]
3943

4044

41-
def default_rewrite(x: Term | None, y: Term) -> Term:
45+
def default_rewrite(x: T | None, y: U) -> T | U:
4246
return x if x is not None else y
4347

4448

@@ -76,16 +80,15 @@ def __call__(self, x: Term) -> Term | None:
7680
if y.is_expr():
7781
args = y.children()
7882
return y.make_term(
79-
y.head(), [default_rewrite(self(arg), arg) for arg in args]
83+
*tuple(default_rewrite(self(arg), arg) for arg in args)
8084
)
8185
return y
8286
if x.is_expr():
8387
args = x.children()
8488
new_args = list(map(self, args))
8589
if not all(arg is None for arg in new_args):
8690
return x.make_term(
87-
x.head(),
88-
list(map(lambda x1, x2: default_rewrite(x1, x2), new_args, args)),
91+
*map(lambda x1, x2: default_rewrite(x1, x2), new_args, args),
8992
)
9093
return None
9194
return None
@@ -111,8 +114,7 @@ def __call__(self, x: Term) -> Term | None:
111114
if all(arg is None for arg in new_args):
112115
return self.rw(x)
113116
y = x.make_term(
114-
x.head(),
115-
list(map(lambda x1, x2: default_rewrite(x1, x2), new_args, args)),
117+
*map(lambda x1, x2: default_rewrite(x1, x2), new_args, args),
116118
)
117119
return default_rewrite(self.rw(y), y)
118120
return self.rw(x)
@@ -179,9 +181,7 @@ def __call__(self, x: Term) -> Term | None:
179181
if y is not None:
180182
if y.is_expr():
181183
y_args = y.children()
182-
return y.make_term(
183-
y.head(), [default_rewrite(self(arg), arg) for arg in y_args]
184-
)
184+
return y.make_term(*(default_rewrite(self(arg), arg) for arg in y_args))
185185
return y
186186
return None
187187

src/finch/symbolic/term.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313

1414
from abc import ABC, abstractmethod
1515
from collections.abc import Iterator
16-
from typing import Any
16+
from typing import TYPE_CHECKING, Any, Self
1717

1818
__all__ = ["Term", "PreOrderDFS", "PostOrderDFS"]
1919

20+
if TYPE_CHECKING:
21+
from ..finch_logic import LogicNode
22+
2023

2124
class Term(ABC):
2225
def __init__(self):
@@ -27,7 +30,7 @@ def head(self) -> Any:
2730
"""Return the head type of the S-expression."""
2831

2932
@abstractmethod
30-
def children(self) -> list[Term]:
33+
def children(self) -> list[LogicNode]:
3134
"""Return the children (AKA tail) of the S-expression."""
3235

3336
@abstractmethod
@@ -37,11 +40,11 @@ def is_expr(self) -> bool:
3740
`children()` if `True`."""
3841

3942
@abstractmethod
40-
def make_term(self, head: Any, children: list[Term]) -> Term:
43+
def make_term(self, head: Any, *children: Term) -> Self:
4144
"""
4245
Construct a new term in the same family of terms with the given head type and
4346
children. This function should satisfy
44-
`x == x.make_term(x.head(), *x.children())`
47+
`x == x.make_term(*x.children())`
4548
"""
4649

4750
def __hash__(self) -> int:
@@ -60,6 +63,7 @@ def __eq__(self, other: object) -> bool:
6063

6164
def PostOrderDFS(node: Term) -> Iterator[Term]:
6265
if node.is_expr():
66+
arg: Term
6367
for arg in node.children():
6468
yield from PostOrderDFS(arg)
6569
yield node
@@ -68,5 +72,6 @@ def PostOrderDFS(node: Term) -> Iterator[Term]:
6872
def PreOrderDFS(node: Term) -> Iterator[Term]:
6973
yield node
7074
if node.is_expr():
75+
arg: Term
7176
for arg in node.children():
7277
yield from PreOrderDFS(arg)

0 commit comments

Comments
 (0)