|
64 | 64 | * FiniteDomainSort |
65 | 65 | * Fixedpoint API |
66 | 66 | * SMT2 file support |
67 | | - * recursive functions |
68 | 67 | * Not missing, but different |
69 | 68 | * Options |
70 | 69 | * as expected |
@@ -150,6 +149,8 @@ def __init__(self): |
150 | 149 | self.vars = {} |
151 | 150 | # An increasing identifier used to make fresh identifiers |
152 | 151 | self.next_fresh_var = 0 |
| 152 | + # Function definitions to be added to a solver once it is created |
| 153 | + self.defined_functions = [] |
153 | 154 |
|
154 | 155 | def __del__(self): |
155 | 156 | self.tm = None |
@@ -937,6 +938,48 @@ def FreshFunction(*sig): |
937 | 938 | return Function(name, *sig) |
938 | 939 |
|
939 | 940 |
|
| 941 | +def RecFunction(name, *sig): |
| 942 | + """Create a new SMT uninterpreted function with the given sorts.""" |
| 943 | + return Function(name, sig) |
| 944 | + |
| 945 | + |
| 946 | +def RecAddDefinition(func, args, body): |
| 947 | + """Define a new SMT recursive function with the given function declaration. |
| 948 | + Replaces constants in `args` with bound variables. |
| 949 | +
|
| 950 | + >>> fact = Function('fact', IntSort(), IntSort()) |
| 951 | + >>> n = Int('n') |
| 952 | + >>> RecAddDefinition(fact, n, If(n == 0, 1, n * fact(n - 1))) |
| 953 | + >>> solve(Not(fact(5) == 120)) |
| 954 | + unsat |
| 955 | + """ |
| 956 | + if is_app(args): |
| 957 | + args = [args] |
| 958 | + ctx = func.ctx |
| 959 | + consts = [a.ast for a in args] |
| 960 | + vars_ = [ctx.tm.mkVar(a.sort().ast, str(a)) for a in args] |
| 961 | + subbed_body = body.ast.substitute(consts, vars_) |
| 962 | + ctx.defined_functions.append(((func.ast, vars_, subbed_body), True)) |
| 963 | + |
| 964 | + |
| 965 | +def AddDefinition(name, args, body): |
| 966 | + """Define a new SMT function with the given function declaration. |
| 967 | + Replaces constants in `args` with bound variables. |
| 968 | +
|
| 969 | + >>> x, y = Ints('x y') |
| 970 | + >>> minus = AddDefinition(minus, [x, y], x - y) |
| 971 | + >>> solve(Not(minus(10, 5) == 5)) |
| 972 | + unsat |
| 973 | + """ |
| 974 | + if is_app(args): |
| 975 | + args = [args] |
| 976 | + ctx = body.ctx |
| 977 | + consts = [a.ast for a in args] |
| 978 | + vars_ = [ctx.tm.mkVar(a.sort().ast, str(a)) for a in args] |
| 979 | + subbed_body = body.ast.substitute(consts, vars_) |
| 980 | + ctx.defined_functions.append(((name, vars_, subbed_body.getSort(), subbed_body), False)) |
| 981 | + |
| 982 | + |
940 | 983 | ######################################### |
941 | 984 | # |
942 | 985 | # Expressions |
@@ -6006,6 +6049,16 @@ def initFromLogic(self): |
6006 | 6049 | self.solver.setLogic(self.logic) |
6007 | 6050 | self.solver.setOption("produce-models", "true") |
6008 | 6051 |
|
| 6052 | + def add_func_definitions(self): |
| 6053 | + """Add function definitions present in the current context""" |
| 6054 | + # FIXME: This is a temporary fix and should be removed once the base |
| 6055 | + # API have the proper solution in place. |
| 6056 | + for func in self.ctx.defined_functions: |
| 6057 | + if func[1]: |
| 6058 | + self.solver.defineFunRec(*func[0]) |
| 6059 | + else: |
| 6060 | + self.solver.defineFun(*func[0]) |
| 6061 | + |
6009 | 6062 | def __del__(self): |
6010 | 6063 | if self.solver is not None: |
6011 | 6064 | self.solver = None |
@@ -6190,6 +6243,7 @@ def check(self, *assumptions): |
6190 | 6243 | unsat |
6191 | 6244 | >>> s.resetAssertions() |
6192 | 6245 | """ |
| 6246 | + self.add_func_definitions() |
6193 | 6247 | assumptions = _get_args(assumptions) |
6194 | 6248 | r = CheckSatResult(self.solver.checkSatAssuming(*[a.ast for a in assumptions])) |
6195 | 6249 | self.last_result = r |
|
0 commit comments