@@ -31,18 +31,20 @@ def _eshape(typ, env):
3131 return tuple (r if is_pos_int (r ) else env [r ] for r in typ .shape ())
3232
3333
34- def run_interpreter (proc , kwargs ):
35- Interpreter (proc , kwargs )
34+ def run_interpreter (proc , kwargs , ctxt = None ):
35+ Interpreter (proc , kwargs , ctxt )
3636
3737
3838class Interpreter :
39- def __init__ (self , proc , kwargs , use_randomization = False ):
39+ def __init__ (self , proc , kwargs , ctxt = None , use_randomization = False ):
4040 if not isinstance (proc , LoopIR .proc ):
4141 raise TypeError (f"Expected { proc .name } to be of type proc" )
4242
4343 self .env = ChainMap ()
4444 self .use_randomization = use_randomization
45- self .ctxt = defaultdict (dict )
45+ self .ctxt = {}
46+ if ctxt is not None :
47+ self .ctxt |= ctxt
4648
4749 self .eval_proc (proc , kwargs )
4850
@@ -52,13 +54,16 @@ def _new_scope(self):
5254 def _del_scope (self ):
5355 self .env = self .env .parents
5456
57+ def lookup_kwarg (self , kwargs , name ):
58+ return kwargs [name ] if name in kwargs else kwargs [str (name )]
59+
5560 def typecheck_input_buffer (self , proc_arg , kwargs ):
5661 nm = proc_arg .name
5762 if not proc_arg .type .is_numeric ():
5863 raise TypeError (f"arg { nm } is expected to be numeric" )
5964
6065 basetype = proc_arg .type .basetype ()
61- buf = kwargs [ str ( proc_arg .name )]
66+ buf = self . lookup_kwarg ( kwargs , proc_arg .name )
6267
6368 pre = f"bad argument '{ nm } '"
6469 if not isinstance (buf , np .ndarray ):
@@ -108,46 +113,50 @@ def typecheck_input_buffer(self, proc_arg, kwargs):
108113
109114 def eval_proc (self , proc , kwargs ):
110115 proc = ParallelAnalysis ().run (proc )
111- proc = PrecisionAnalysis ().run (proc ) # TODO: need this?
116+ # proc = PrecisionAnalysis().run(proc) # TODO: need this?
112117 proc = WindowAnalysis ().apply_proc (proc )
113- proc = MemoryAnalysis ().run (proc ) # TODO: need this?
118+ # proc = MemoryAnalysis().run(proc) # TODO: need this?
114119
115120 for a in proc .args :
116- if not str (a .name ) in kwargs :
121+ if not str (a .name ) in kwargs and not a . name in kwargs :
117122 raise TypeError (f"expected argument '{ a .name } ' to be supplied" )
118123
124+ kwarg_val = self .lookup_kwarg (kwargs , a .name )
119125 if a .type is T .size :
120- if not is_pos_int (kwargs [ str ( a . name )] ):
126+ if not is_pos_int (kwarg_val ):
121127 raise TypeError (
122128 f"expected size '{ a .name } ' to have positive integer value"
123129 )
124- self .env [a .name ] = kwargs [ str ( a . name )]
130+ self .env [a .name ] = kwarg_val
125131 elif a .type is T .index :
126- if type (kwargs [ str ( a . name )] ) is not int :
132+ if type (kwarg_val ) is not int :
127133 raise TypeError (
128134 f"expected index variable '{ a .name } ' to be an integer"
129135 )
130- self .env [a .name ] = kwargs [ str ( a . name )]
136+ self .env [a .name ] = kwarg_val
131137 elif a .type is T .bool :
132- if type (kwargs [ str ( a . name )] ) is not bool :
138+ if type (kwarg_val ) is not bool :
133139 raise TypeError (f"expected bool variable '{ a .name } ' to be a bool" )
134- self .env [a .name ] = kwargs [ str ( a . name )]
140+ self .env [a .name ] = kwarg_val
135141 elif a .type is T .stride :
136- if type (kwargs [ str ( a . name )] ) is not int :
142+ if type (kwarg_val ) is not int :
137143 raise TypeError (
138144 f"expected stride variable '{ a .name } ' to be an integer"
139145 )
140- self .env [a .name ] = kwargs [ str ( a . name )]
146+ self .env [a .name ] = kwarg_val
141147 else :
142148 self .typecheck_input_buffer (a , kwargs )
143- self .env [a .name ] = kwargs [ str ( a . name )]
149+ self .env [a .name ] = kwarg_val
144150
145151 # evaluate preconditions
146152 for pred in proc .preds :
147153 if isinstance (pred , LoopIR .Const ):
148154 continue
149155 else :
150- assert self .eval_e (pred ), "precondition not satisfied"
156+ predv = self .eval_e (pred )
157+ if not predv :
158+ print ("hi" )
159+ assert predv , "precondition not satisfied"
151160
152161 # eval statements
153162 self .eval_stmts (proc .body )
@@ -176,7 +185,10 @@ def eval_s(self, s):
176185 elif isinstance (s , LoopIR .WriteConfig ):
177186 nm = s .config .name ()
178187 rhs = self .eval_e (s .rhs )
179- self .ctxt [nm ][s .field ] = rhs
188+ if s .rhs .type .is_numeric ():
189+ self .ctxt [(nm , s .field )] = np .array ([rhs ])
190+ else :
191+ self .ctxt [(nm , s .field )] = rhs
180192
181193 elif isinstance (s , LoopIR .WindowStmt ):
182194 # nm = rbuf[...]
@@ -222,7 +234,7 @@ def eval_s(self, s):
222234
223235 elif isinstance (s , LoopIR .Call ):
224236 argvals = [self .eval_e (a , call_arg = True ) for a in s .args ]
225- argnames = [str ( a .name ) for a in s .f .args ]
237+ argnames = [a .name for a in s .f .args ]
226238 kwargs = {nm : val for nm , val in zip (argnames , argvals )}
227239 self ._new_scope ()
228240 self .eval_proc (s .f , kwargs )
@@ -300,6 +312,8 @@ def stringify_w_access(a):
300312
301313 elif isinstance (e , LoopIR .USub ):
302314 return - self .eval_e (e .arg )
315+ elif isinstance (e , LoopIR .Extern ):
316+ return e .f .interpret ([self .eval_e (arg ) for arg in e .args ])
303317
304318 # BuiltIns don't go to the interpreter, they are just called (via call) like a proc
305319 # TODO Discuss to make sure
@@ -316,7 +330,7 @@ def stringify_w_access(a):
316330
317331 elif isinstance (e , LoopIR .ReadConfig ):
318332 nm = e .config .name ()
319- return self .ctxt [nm ][ e .field ]
333+ return self .ctxt [( nm , e .field ) ]
320334
321335 else :
322336 print (e )
0 commit comments