1- from __future__ import annotations
1+ from dataclasses import dataclass
22from itertools import product
3+ from typing import Iterable , Any
34import numpy as np
4- from ..finch_logic import *
5- from ..symbolic import Term
6- from ..algebra import *
5+ from ..finch_logic import (
6+ Immediate ,
7+ Deferred ,
8+ Field ,
9+ Alias ,
10+ Table ,
11+ MapJoin ,
12+ Aggregate ,
13+ Query ,
14+ Plan ,
15+ Produces ,
16+ Subquery ,
17+ Relabel ,
18+ Reorder ,
19+ )
20+ from ..algebra import return_type , fill_value , element_type , fixpoint_type
21+
722
823@dataclass (eq = True , frozen = True )
9- class TableValue () :
24+ class TableValue :
1025 tns : Any
1126 idxs : Iterable [Any ]
27+
1228 def __post_init__ (self ):
1329 if isinstance (self .tns , TableValue ):
1430 raise ValueError ("The tensor (tns) cannot be a TableValue" )
1531
16- from typing import Any , Type
1732
1833class FinchLogicInterpreter :
1934 def __init__ (self , * , make_tensor = np .full ):
2035 self .verbose = False
2136 self .bindings = {}
2237 self .make_tensor = make_tensor # Added make_tensor argument
23-
38+
2439 def __call__ (self , node ):
2540 # Example implementation for evaluating an expression
2641 if self .verbose :
27- print (f"Evaluating: { expression } " )
42+ print (f"Evaluating: { node } " )
2843 # Placeholder for actual logic
2944 head = node .head ()
3045 if head == Immediate :
3146 return node .val
3247 elif head == Deferred :
33- raise ValueError ("The interpreter cannot evaluate a deferred node, a compiler might generate code for it" )
48+ raise ValueError (
49+ "The interpreter cannot evaluate a deferred node, a compiler might generate code for it"
50+ )
3451 elif head == Field :
3552 raise ValueError ("Fields cannot be used in expressions" )
3653 elif head == Alias :
@@ -59,7 +76,9 @@ def __call__(self, node):
5976 dims [idx ] = dim
6077 fill_val = op (* [fill_value (arg .tns ) for arg in args ])
6178 dtype = return_type (op , * [element_type (arg .tns ) for arg in args ])
62- result = self .make_tensor (tuple (dims [idx ] for idx in idxs ), fill_val , dtype = dtype )
79+ result = self .make_tensor (
80+ tuple (dims [idx ] for idx in idxs ), fill_val , dtype = dtype
81+ )
6382 for crds in product (* [range (dims [idx ]) for idx in idxs ]):
6483 idx_crds = {idx : crd for (idx , crd ) in zip (idxs , crds )}
6584 vals = [arg .tns [* [idx_crds [idx ] for idx in arg .idxs ]] for arg in args ]
@@ -74,10 +93,16 @@ def __call__(self, node):
7493 init = node .init .val
7594 op = node .op .val
7695 dtype = fixpoint_type (op , init , element_type (arg .tns ))
77- new_shape = [dim for (dim , idx ) in zip (arg .tns .shape , arg .idxs ) if not idx in node .idxs ]
96+ new_shape = [
97+ dim
98+ for (dim , idx ) in zip (arg .tns .shape , arg .idxs )
99+ if idx not in node .idxs
100+ ]
78101 result = self .make_tensor (new_shape , init , dtype = dtype )
79102 for crds in product (* [range (dim ) for dim in arg .tns .shape ]):
80- out_crds = [crd for (crd , idx ) in zip (crds , arg .idxs ) if not idx in node .idxs ]
103+ out_crds = [
104+ crd for (crd , idx ) in zip (crds , arg .idxs ) if idx not in node .idxs
105+ ]
81106 result [* out_crds ] = op (result [* out_crds ], arg .tns [* crds ])
82107 return TableValue (result , [idx for idx in arg .idxs if idx not in node .idxs ])
83108 elif head == Relabel :
@@ -92,7 +117,7 @@ def __call__(self, node):
92117 raise ValueError ("Trying to drop a dimension that is not 1" )
93118 arg_dims = {idx : dim for idx , dim in zip (arg .idxs , arg .tns .shape )}
94119 dims = [arg_dims .get (idx , 1 ) for idx in node .idxs ]
95- result = self .make_tensor (dims , fill_value (arg .tns ), dtype = arg .tns .dtype )
120+ result = self .make_tensor (dims , fill_value (arg .tns ), dtype = arg .tns .dtype )
96121 for crds in product (* [range (dim ) for dim in dims ]):
97122 node_crds = {idx : crd for (idx , crd ) in zip (node .idxs , crds )}
98123 in_crds = [node_crds .get (idx , 0 ) for idx in arg .idxs ]
@@ -110,8 +135,8 @@ def __call__(self, node):
110135 elif head == Produces :
111136 return tuple (self (arg ).tns for arg in node .args )
112137 elif head == Subquery :
113- if not node .lhs in self .bindings :
114- self .bindings [node .lhs ] = self (node .rhs )
138+ if node .lhs not in self .bindings :
139+ self .bindings [node .lhs ] = self (node .arg )
115140 return self .bindings [node .lhs ]
116141 else :
117- raise ValueError (f"Unknown expression type: { head } " )
142+ raise ValueError (f"Unknown expression type: { head } " )
0 commit comments