Skip to content

Commit fe25a75

Browse files
committed
restructure constraint solver for disjunction + write tests
1 parent b7de72e commit fe25a75

19 files changed

+787
-289
lines changed

src/exo/backend/LoopIR_interpreter.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,20 @@ def _eshape(typ, env):
3131
return tuple(r if is_pos_int(r) else env[r] for r in typ.shape())
3232

3333

34-
def run_interpreter(proc, kwargs):
35-
Interpreter(proc, kwargs)
34+
def run_interpreter(proc, kwargs, ctxt=None):
35+
Interpreter(proc, kwargs, ctxt)
3636

3737

3838
class Interpreter:
39-
def __init__(self, proc, kwargs, use_randomization=False):
39+
def __init__(self, proc, kwargs, ctxt=None, use_randomization=False):
4040
if not isinstance(proc, LoopIR.proc):
4141
raise TypeError(f"Expected {proc.name} to be of type proc")
4242

4343
self.env = ChainMap()
4444
self.use_randomization = use_randomization
45-
self.ctxt = defaultdict(dict)
45+
self.ctxt = {}
46+
if ctxt is not None:
47+
self.ctxt |= ctxt
4648

4749
self.eval_proc(proc, kwargs)
4850

@@ -52,13 +54,16 @@ def _new_scope(self):
5254
def _del_scope(self):
5355
self.env = self.env.parents
5456

57+
def lookup_kwarg(self, kwargs, name):
58+
return kwargs[name] if name in kwargs else kwargs[str(name)]
59+
5560
def typecheck_input_buffer(self, proc_arg, kwargs):
5661
nm = proc_arg.name
5762
if not proc_arg.type.is_numeric():
5863
raise TypeError(f"arg {nm} is expected to be numeric")
5964

6065
basetype = proc_arg.type.basetype()
61-
buf = kwargs[str(proc_arg.name)]
66+
buf = self.lookup_kwarg(kwargs, proc_arg.name)
6267

6368
pre = f"bad argument '{nm}'"
6469
if not isinstance(buf, np.ndarray):
@@ -108,46 +113,50 @@ def typecheck_input_buffer(self, proc_arg, kwargs):
108113

109114
def eval_proc(self, proc, kwargs):
110115
proc = ParallelAnalysis().run(proc)
111-
proc = PrecisionAnalysis().run(proc) # TODO: need this?
116+
# proc = PrecisionAnalysis().run(proc) # TODO: need this?
112117
proc = WindowAnalysis().apply_proc(proc)
113-
proc = MemoryAnalysis().run(proc) # TODO: need this?
118+
# proc = MemoryAnalysis().run(proc) # TODO: need this?
114119

115120
for a in proc.args:
116-
if not str(a.name) in kwargs:
121+
if not str(a.name) in kwargs and not a.name in kwargs:
117122
raise TypeError(f"expected argument '{a.name}' to be supplied")
118123

124+
kwarg_val = self.lookup_kwarg(kwargs, a.name)
119125
if a.type is T.size:
120-
if not is_pos_int(kwargs[str(a.name)]):
126+
if not is_pos_int(kwarg_val):
121127
raise TypeError(
122128
f"expected size '{a.name}' to have positive integer value"
123129
)
124-
self.env[a.name] = kwargs[str(a.name)]
130+
self.env[a.name] = kwarg_val
125131
elif a.type is T.index:
126-
if type(kwargs[str(a.name)]) is not int:
132+
if type(kwarg_val) is not int:
127133
raise TypeError(
128134
f"expected index variable '{a.name}' to be an integer"
129135
)
130-
self.env[a.name] = kwargs[str(a.name)]
136+
self.env[a.name] = kwarg_val
131137
elif a.type is T.bool:
132-
if type(kwargs[str(a.name)]) is not bool:
138+
if type(kwarg_val) is not bool:
133139
raise TypeError(f"expected bool variable '{a.name}' to be a bool")
134-
self.env[a.name] = kwargs[str(a.name)]
140+
self.env[a.name] = kwarg_val
135141
elif a.type is T.stride:
136-
if type(kwargs[str(a.name)]) is not int:
142+
if type(kwarg_val) is not int:
137143
raise TypeError(
138144
f"expected stride variable '{a.name}' to be an integer"
139145
)
140-
self.env[a.name] = kwargs[str(a.name)]
146+
self.env[a.name] = kwarg_val
141147
else:
142148
self.typecheck_input_buffer(a, kwargs)
143-
self.env[a.name] = kwargs[str(a.name)]
149+
self.env[a.name] = kwarg_val
144150

145151
# evaluate preconditions
146152
for pred in proc.preds:
147153
if isinstance(pred, LoopIR.Const):
148154
continue
149155
else:
150-
assert self.eval_e(pred), "precondition not satisfied"
156+
predv = self.eval_e(pred)
157+
if not predv:
158+
print("hi")
159+
assert predv, "precondition not satisfied"
151160

152161
# eval statements
153162
self.eval_stmts(proc.body)
@@ -176,7 +185,10 @@ def eval_s(self, s):
176185
elif isinstance(s, LoopIR.WriteConfig):
177186
nm = s.config.name()
178187
rhs = self.eval_e(s.rhs)
179-
self.ctxt[nm][s.field] = rhs
188+
if s.rhs.type.is_numeric():
189+
self.ctxt[(nm, s.field)] = np.array([rhs])
190+
else:
191+
self.ctxt[(nm, s.field)] = rhs
180192

181193
elif isinstance(s, LoopIR.WindowStmt):
182194
# nm = rbuf[...]
@@ -222,7 +234,7 @@ def eval_s(self, s):
222234

223235
elif isinstance(s, LoopIR.Call):
224236
argvals = [self.eval_e(a, call_arg=True) for a in s.args]
225-
argnames = [str(a.name) for a in s.f.args]
237+
argnames = [a.name for a in s.f.args]
226238
kwargs = {nm: val for nm, val in zip(argnames, argvals)}
227239
self._new_scope()
228240
self.eval_proc(s.f, kwargs)
@@ -300,6 +312,8 @@ def stringify_w_access(a):
300312

301313
elif isinstance(e, LoopIR.USub):
302314
return -self.eval_e(e.arg)
315+
elif isinstance(e, LoopIR.Extern):
316+
return e.f.interpret([self.eval_e(arg) for arg in e.args])
303317

304318
# BuiltIns don't go to the interpreter, they are just called (via call) like a proc
305319
# TODO Discuss to make sure
@@ -316,7 +330,7 @@ def stringify_w_access(a):
316330

317331
elif isinstance(e, LoopIR.ReadConfig):
318332
nm = e.config.name()
319-
return self.ctxt[nm][e.field]
333+
return self.ctxt[(nm, e.field)]
320334

321335
else:
322336
print(e)

src/exo/libs/externs.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from exo.core.extern import Extern, _EErr
2+
import numpy as np
23

34

45
class _Sin(Extern):
@@ -20,8 +21,8 @@ def typecheck(self, args):
2021
def globl(self, prim_type):
2122
return "#include <math.h>"
2223

23-
# def interpret(self, args):
24-
# return math.sin(args[0])
24+
def interpret(self, args):
25+
return np.sin(args[0])
2526

2627
def compile(self, args, prim_type):
2728
return f"sin(({prim_type}){args[0]})"
@@ -55,11 +56,11 @@ def globl(self, prim_type):
5556
)
5657
return s
5758

58-
# def interpret(self, args):
59-
# if args[0] > 0:
60-
# return args[0]
61-
# else:
62-
# return 0
59+
def interpret(self, args):
60+
if args[0] > 0:
61+
return args[0]
62+
else:
63+
return 0
6364

6465
def compile(self, args, prim_type):
6566
return f"_relu_{prim_type}(({prim_type}){args[0]})"
@@ -95,15 +96,15 @@ def globl(self, prim_type):
9596
)
9697
return s
9798

98-
# def interpret(self, args):
99-
# x = args[0]
100-
# v = args[1]
101-
# y = args[2]
102-
# z = args[3]
103-
# if x < v:
104-
# return y
105-
# else:
106-
# return z
99+
def interpret(self, args):
100+
x = args[0]
101+
v = args[1]
102+
y = args[2]
103+
z = args[3]
104+
if x < v:
105+
return y
106+
else:
107+
return z
107108

108109
def compile(self, args, prim_type):
109110
return f"_select_{prim_type}(({prim_type}){args[0]}, ({prim_type}){args[1]}, ({prim_type}){args[2]}, ({prim_type}){args[3]})"
@@ -131,8 +132,8 @@ def typecheck(self, args):
131132
def globl(self, prim_type):
132133
return "#include <math.h>"
133134

134-
# def interpret(self, args):
135-
# return math.expf(args[0])
135+
def interpret(self, args):
136+
return np.exp(args[0])
136137

137138
def compile(self, args, prim_type):
138139
return f"expf(({prim_type})({args[0]}))"
@@ -161,8 +162,8 @@ def typecheck(self, args):
161162
def globl(self, prim_type):
162163
return "#include <math.h>"
163164

164-
# def interpret(self, args):
165-
# return math.fmaxf(args[0], args[1])
165+
def interpret(self, args):
166+
return np.nanmax([args[0], args[1]])
166167

167168
def compile(self, args, prim_type):
168169
return f"fmaxf(({prim_type})({args[0]}), ({prim_type})({args[1]}))"
@@ -195,8 +196,8 @@ def globl(self, prim_type):
195196
}}
196197
"""
197198

198-
# def interpret(self, args):
199-
# return math.sigmoid(args[0])
199+
def interpret(self, args):
200+
return 1 / (1 + np.exp(-args[0]))
200201

201202
def compile(self, args, prim_type):
202203
return f"sigmoid(({prim_type})({args[0]}))"
@@ -224,8 +225,8 @@ def typecheck(self, args):
224225
def globl(self, prim_type):
225226
return "#include <math.h>"
226227

227-
# def interpret(self, args):
228-
# return math.sqrt(args[0])
228+
def interpret(self, args):
229+
return np.sqrt(args[0])
229230

230231
def compile(self, args, prim_type):
231232
return f"sqrt(({prim_type})({args[0]}))"

src/exo/platforms/gemmini.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,8 @@ def ld_i8_block(
386386
assert n <= 16
387387
assert m <= 4
388388
assert stride(src, 1) == 1
389-
assert stride(dst, 0) == 16
390-
assert stride(dst, 1) == 1
389+
assert stride(dst, 1) == 16
390+
assert stride(dst, 2) == 1
391391

392392
for i in seq(0, n):
393393
for j in seq(0, m):
@@ -481,8 +481,8 @@ def zero_block_id2(
481481
):
482482
assert n <= 16
483483
assert m <= 4
484-
assert stride(dst, 0) == 16
485-
assert stride(dst, 1) == 1
484+
assert stride(dst, 1) == 16
485+
assert stride(dst, 2) == 1
486486

487487
for i in seq(0, n):
488488
for j in seq(0, m):

0 commit comments

Comments
 (0)