Skip to content

Commit 36c1add

Browse files
authored
Merge pull request #859 from o1-labs/feature/generalized-constraint
Abstract over the specific `Constraint.t` type used by the backing constraint system
2 parents b2442f4 + e09e89d commit 36c1add

16 files changed

+339
-345
lines changed

src/base/backend_extended.ml

+19-29
Original file line numberDiff line numberDiff line change
@@ -59,30 +59,31 @@ module type S = sig
5959
val to_constant : t -> Field.t option
6060
end
6161

62-
module R1CS_constraint_system : Constraint_system.S with module Field := Field
63-
6462
module Constraint : sig
65-
type t = (Cvar.t, Field.t) Constraint.t [@@deriving sexp]
66-
67-
type 'k with_constraint_args = ?label:string -> 'k
63+
type t [@@deriving sexp]
6864

69-
val boolean : (Cvar.t -> t) with_constraint_args
65+
val boolean : Cvar.t -> t
7066

71-
val equal : (Cvar.t -> Cvar.t -> t) with_constraint_args
67+
val equal : Cvar.t -> Cvar.t -> t
7268

73-
val r1cs : (Cvar.t -> Cvar.t -> Cvar.t -> t) with_constraint_args
69+
val r1cs : Cvar.t -> Cvar.t -> Cvar.t -> t
7470

75-
val square : (Cvar.t -> Cvar.t -> t) with_constraint_args
71+
val square : Cvar.t -> Cvar.t -> t
7672

77-
val annotation : t -> string
73+
val eval : t -> (Cvar.t -> Field.t) -> bool
7874

79-
val eval :
80-
(Cvar.t, Field.t) Constraint.basic_with_annotation
81-
-> (Cvar.t -> Field.t)
82-
-> bool
75+
val log_constraint : t -> (Cvar.t -> Field.t) -> string
8376
end
8477

85-
module Run_state : Run_state_intf.S
78+
module R1CS_constraint_system :
79+
Constraint_system.S
80+
with module Field := Field
81+
with type constraint_ = Constraint.t
82+
83+
module Run_state :
84+
Run_state_intf.S
85+
with type field := Field.t
86+
and type constraint_ := Constraint.t
8687
end
8788

8889
module Make (Backend : Backend_intf.S) :
@@ -91,7 +92,8 @@ module Make (Backend : Backend_intf.S) :
9192
and type Field.Vector.t = Backend.Field.Vector.t
9293
and type Bigint.t = Backend.Bigint.t
9394
and type R1CS_constraint_system.t = Backend.R1CS_constraint_system.t
94-
and type 'field Run_state.t = 'field Backend.Run_state.t = struct
95+
and type Run_state.t = Backend.Run_state.t
96+
and type Constraint.t = Backend.Constraint.t = struct
9597
open Backend
9698

9799
module Bigint = struct
@@ -207,19 +209,7 @@ module Make (Backend : Backend_intf.S) :
207209
None
208210
end
209211

210-
module Constraint = struct
211-
open Constraint
212-
include Constraint.T
213-
214-
type 'k with_constraint_args = ?label:string -> 'k
215-
216-
type t = (Cvar.t, Field.t) Constraint.t [@@deriving sexp]
217-
218-
let m = (module Field : Snarky_intf.Field.S with type t = Field.t)
219-
220-
let eval { basic; _ } get_value = Constraint.Basic.eval m get_value basic
221-
end
222-
212+
module Constraint = Constraint
223213
module R1CS_constraint_system = R1CS_constraint_system
224214
module Run_state = Run_state
225215
end

src/base/backend_intf.ml

+24-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,29 @@ module type S = sig
55

66
val field_size : Bigint.t
77

8-
module R1CS_constraint_system : Constraint_system.S with module Field := Field
8+
module Constraint : sig
9+
type t [@@deriving sexp]
910

10-
module Run_state : Run_state_intf.S
11+
val boolean : Field.t Cvar.t -> t
12+
13+
val equal : Field.t Cvar.t -> Field.t Cvar.t -> t
14+
15+
val r1cs : Field.t Cvar.t -> Field.t Cvar.t -> Field.t Cvar.t -> t
16+
17+
val square : Field.t Cvar.t -> Field.t Cvar.t -> t
18+
19+
val eval : t -> (Field.t Cvar.t -> Field.t) -> bool
20+
21+
val log_constraint : t -> (Field.t Cvar.t -> Field.t) -> string
22+
end
23+
24+
module R1CS_constraint_system :
25+
Constraint_system.S
26+
with module Field := Field
27+
with type constraint_ = Constraint.t
28+
29+
module Run_state :
30+
Run_state_intf.S
31+
with type field := Field.t
32+
and type constraint_ := Constraint.t
1133
end

src/base/checked.ml

+25-22
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
open Core_kernel
22

3-
module Make (Field : sig
4-
type t [@@deriving sexp]
5-
6-
val equal : t -> t -> bool
7-
end)
8-
(Types : Types.Types)
9-
(Basic : Checked_intf.Basic with type field = Field.t with module Types := Types)
10-
(As_prover : As_prover_intf.Basic
11-
with type field := Basic.field
12-
with module Types := Types) :
3+
module Make
4+
(Backend : Backend_extended.S)
5+
(Types : Types.Types)
6+
(Basic : Checked_intf.Basic
7+
with type field = Backend.Field.t
8+
and type constraint_ = Backend.Constraint.t
9+
with module Types := Types)
10+
(As_prover : As_prover_intf.Basic
11+
with type field := Basic.field
12+
with module Types := Types) :
1313
Checked_intf.S
1414
with module Types := Types
15-
with type field = Field.t
16-
and type run_state = Basic.run_state = struct
15+
with type field = Backend.Field.t
16+
and type run_state = Basic.run_state
17+
and type constraint_ = Basic.constraint_ = struct
1718
include Basic
1819

1920
let request_witness (typ : ('var, 'value) Types.Typ.t)
@@ -69,23 +70,25 @@ end)
6970
in
7071
handle t (fun request -> (Option.value_exn !handler) request)
7172

72-
let assert_ ?label c = add_constraint (Constraint.override_label c label)
73+
let assert_ c = add_constraint c
7374

74-
let assert_r1cs ?label a b c = assert_ (Constraint.r1cs ?label a b c)
75+
let assert_r1cs a b c = assert_ (Backend.Constraint.r1cs a b c)
7576

76-
let assert_square ?label a c = assert_ (Constraint.square ?label a c)
77+
let assert_square a c = assert_ (Backend.Constraint.square a c)
7778

78-
let assert_all ?label cs =
79+
let assert_all cs =
7980
List.fold_right cs ~init:(return ()) ~f:(fun c (acc : _ t) ->
80-
bind acc ~f:(fun () ->
81-
add_constraint (Constraint.override_label c label) ) )
81+
bind acc ~f:(fun () -> add_constraint c) )
8282

83-
let assert_equal ?label x y =
83+
let assert_equal x y =
8484
match (x, y) with
8585
| Cvar.Constant x, Cvar.Constant y ->
86-
if Field.equal x y then return ()
86+
if Backend.Field.equal x y then return ()
8787
else
88-
failwithf !"assert_equal: %{sexp: Field.t} != %{sexp: Field.t}" x y ()
88+
failwithf
89+
!"assert_equal: %{sexp: Backend.Field.t} != %{sexp: \
90+
Backend.Field.t}"
91+
x y ()
8992
| _ ->
90-
assert_ (Constraint.equal ?label x y)
93+
assert_ (Backend.Constraint.equal x y)
9194
end

src/base/checked_intf.ml

+12-13
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ module type Basic = sig
33

44
type field
55

6+
type constraint_
7+
68
type 'a t = 'a Types.Checked.t
79

810
type run_state
911

1012
include Monad_let.S with type 'a t := 'a t
1113

12-
val add_constraint : (field Cvar.t, field) Constraint.t -> unit t
14+
val add_constraint : constraint_ -> unit t
1315

1416
val as_prover : unit Types.As_prover.t -> unit t
1517

@@ -29,7 +31,7 @@ module type Basic = sig
2931
val direct : (run_state -> run_state * 'a) -> 'a t
3032

3133
val constraint_count :
32-
?weight:((field Cvar.t, field) Constraint.t -> int)
34+
?weight:(constraint_ -> int)
3335
-> ?log:(?start:bool -> string -> int -> unit)
3436
-> (unit -> 'a t)
3537
-> int
@@ -40,6 +42,8 @@ module type S = sig
4042

4143
type field
4244

45+
type constraint_
46+
4347
type run_state
4448

4549
type 'a t = 'a Types.Checked.t
@@ -89,25 +93,20 @@ module type S = sig
8993

9094
val with_label : string -> (unit -> 'a t) -> 'a t
9195

92-
val assert_ :
93-
?label:Base.string -> (field Cvar.t, field) Constraint.t -> unit t
96+
val assert_ : constraint_ -> unit t
9497

95-
val assert_r1cs :
96-
?label:Base.string -> field Cvar.t -> field Cvar.t -> field Cvar.t -> unit t
98+
val assert_r1cs : field Cvar.t -> field Cvar.t -> field Cvar.t -> unit t
9799

98-
val assert_square :
99-
?label:Base.string -> field Cvar.t -> field Cvar.t -> unit t
100+
val assert_square : field Cvar.t -> field Cvar.t -> unit t
100101

101-
val assert_all :
102-
?label:Base.string -> (field Cvar.t, field) Constraint.t list -> unit t
102+
val assert_all : constraint_ list -> unit t
103103

104-
val assert_equal :
105-
?label:Base.string -> field Cvar.t -> field Cvar.t -> unit t
104+
val assert_equal : field Cvar.t -> field Cvar.t -> unit t
106105

107106
val direct : (run_state -> run_state * 'a) -> 'a t
108107

109108
val constraint_count :
110-
?weight:((field Cvar.t, field) Constraint.t -> int)
109+
?weight:(constraint_ -> int)
111110
-> ?log:(?start:bool -> string -> int -> unit)
112111
-> (unit -> 'a t)
113112
-> int

src/base/checked_runner.ml

+16-48
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
open Core_kernel
2-
module Constraint0 = Constraint
32

43
let stack_to_string = String.concat ~sep:"\n"
54

@@ -10,9 +9,7 @@ let eval_constraints_ref = eval_constraints
109
module T (Backend : Backend_extended.S) = struct
1110
type 'a t =
1211
| Pure of 'a
13-
| Function of
14-
( Backend.Field.t Backend.Run_state.t
15-
-> Backend.Field.t Backend.Run_state.t * 'a )
12+
| Function of (Backend.Run_state.t -> Backend.Run_state.t * 'a)
1613
end
1714

1815
module Simple_types (Backend : Backend_extended.S) = Types.Make_types (struct
@@ -40,15 +37,15 @@ module Make_checked
4037
with type field := Backend.Field.t
4138
with module Types := Types) =
4239
struct
43-
type run_state = Backend.Field.t Backend.Run_state.t
40+
type run_state = Backend.Run_state.t
41+
42+
type constraint_ = Backend.Constraint.t
4443

4544
type field = Backend.Field.t
4645

4746
type 'a t = 'a T(Backend).t =
4847
| Pure of 'a
49-
| Function of
50-
( Backend.Field.t Backend.Run_state.t
51-
-> Backend.Field.t Backend.Run_state.t * 'a )
48+
| Function of (Backend.Run_state.t -> Backend.Run_state.t * 'a)
5249

5350
let eval (t : 'a t) : run_state -> run_state * 'a =
5451
match t with Pure a -> fun s -> (s, a) | Function g -> g
@@ -83,7 +80,7 @@ struct
8380

8481
open Backend
8582

86-
let get_value (t : Field.t Run_state.t) : Cvar.t -> Field.t =
83+
let get_value (t : Run_state.t) : Cvar.t -> Field.t =
8784
let get_one i = Run_state.get_variable_value t i in
8885
Cvar.eval (`Return_values_will_be_mutated get_one)
8986

@@ -143,36 +140,10 @@ struct
143140
f ~at_label_boundary:(`End, lab) None ) ;
144141
(Run_state.set_stack s' stack, y) )
145142

146-
let log_constraint ({ basic; _ } : Constraint.t) s =
147-
let open Constraint0 in
148-
match basic with
149-
| Boolean var ->
150-
Format.(asprintf "Boolean %s" (Field.to_string (get_value s var)))
151-
| Equal (var1, var2) ->
152-
Format.(
153-
asprintf "Equal %s %s"
154-
(Field.to_string (get_value s var1))
155-
(Field.to_string (get_value s var2)))
156-
| Square (var1, var2) ->
157-
Format.(
158-
asprintf "Square %s %s"
159-
(Field.to_string (get_value s var1))
160-
(Field.to_string (get_value s var2)))
161-
| R1CS (var1, var2, var3) ->
162-
Format.(
163-
asprintf "R1CS %s %s %s"
164-
(Field.to_string (get_value s var1))
165-
(Field.to_string (get_value s var2))
166-
(Field.to_string (get_value s var3)))
167-
| _ ->
168-
Format.asprintf
169-
!"%{sexp:(Field.t, Field.t) Constraint0.basic}"
170-
(Constraint0.Basic.map basic ~f:(get_value s))
171-
172-
let add_constraint ~stack ({ basic; annotation } : Constraint.t)
173-
(Constraint_system.T ((module C), system) : Field.t Constraint_system.t) =
174-
let label = Option.value annotation ~default:"<unknown>" in
175-
C.add_constraint system basic ~label:(stack_to_string (label :: stack))
143+
let add_constraint (basic : Constraint.t)
144+
(Constraint_system.T ((module C), system) :
145+
(Field.t, Constraint.t) Constraint_system.t ) =
146+
C.add_constraint system basic
176147

177148
let add_constraint c : _ t =
178149
Function
@@ -189,19 +160,18 @@ struct
189160
then
190161
failwithf
191162
"Constraint unsatisfied (unreduced):\n\
192-
%s\n\
193163
%s\n\n\
194164
Constraint:\n\
195165
%s\n\
196166
Data:\n\
197167
%s"
198-
(Constraint.annotation c)
199168
(stack_to_string (Run_state.stack s))
200169
(Sexp.to_string (Constraint.sexp_of_t c))
201-
(log_constraint c s) () ;
170+
(Backend.Constraint.log_constraint c (get_value s))
171+
() ;
202172
if not (Run_state.as_prover s) then
203173
Option.iter (Run_state.system s) ~f:(fun system ->
204-
add_constraint ~stack:(Run_state.stack s) c system ) ;
174+
add_constraint c system ) ;
205175
(s, ()) ) )
206176

207177
let with_handler h t : _ t =
@@ -422,17 +392,15 @@ module type S = sig
422392
module State : sig
423393
val make :
424394
num_inputs:int
425-
-> input:field Run_state.Vector.t
395+
-> input:field Run_state_intf.Vector.t
426396
-> next_auxiliary:int ref
427-
-> aux:field Run_state.Vector.t
397+
-> aux:field Run_state_intf.Vector.t
428398
-> ?system:r1cs
429399
-> ?eval_constraints:bool
430400
-> ?handler:Request.Handler.t
431401
-> with_witness:bool
432402
-> ?log_constraint:
433-
( ?at_label_boundary:[ `End | `Start ] * string
434-
-> (field Cvar.t, field) Constraint.t option
435-
-> unit )
403+
(?at_label_boundary:[ `End | `Start ] * string -> constr -> unit)
436404
-> unit
437405
-> run_state
438406
end

0 commit comments

Comments
 (0)