diff --git a/input/circuit.circom b/input/circuit.circom index 9eb9bfe..70b4708 100644 --- a/input/circuit.circom +++ b/input/circuit.circom @@ -2,56 +2,12 @@ pragma circom 2.0.0; -template Switcher() { - signal input sel; - signal input L; - signal input R; - signal output outL; - signal output outR; +template Test() { + signal input a; + signal input b; + signal output c; - signal aux; - - aux <== (R-L)*sel; // We create aux in order to have only one multiplication - outL <== aux + L; - outR <== -aux + R; -} - -template ArgMax (n) { - signal input in[n]; - signal output out; - - // assert (out < n); - signal gts[n]; // store comparators - component switchers[n+1]; // switcher for comparing maxs - component aswitchers[n+1]; // switcher for arg max - - signal maxs[n+1]; - signal amaxs[n+1]; - - maxs[0] <== in[0]; - amaxs[0] <== 0; - for(var i = 0; i < n; i++) { - gts[i] <== in[i] > maxs[i]; // changed to 252 (maximum) for better compatibility - switchers[i+1] = Switcher(); - aswitchers[i+1] = Switcher(); - - switchers[i+1].sel <== gts[i]; - switchers[i+1].L <== maxs[i]; - switchers[i+1].R <== in[i]; - - aswitchers[i+1].sel <== gts[i]; - aswitchers[i+1].L <== amaxs[i]; - aswitchers[i+1].R <== i; - amaxs[i+1] <== aswitchers[i+1].outL; - maxs[i+1] <== switchers[i+1].outL; - } - - out <== amaxs[n]; + c <== a * b; } -component main = ArgMax(2); - -/* INPUT = { - "in": ["2","3","1","5","4"], - "out": "3" -} */ \ No newline at end of file +component main = Test(); \ No newline at end of file diff --git a/src/compiler.rs b/src/compiler.rs index 6ca4c1d..8c78e33 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -9,7 +9,7 @@ use crate::{ use bristol_circuit::{BristolCircuit, CircuitInfo, ConstantInfo, Gate}; use log::debug; use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, HashSet}; use thiserror::Error; /// Represents a signal in the circuit, with a name and an optional value. @@ -106,10 +106,10 @@ impl ArithmeticGate { #[derive(Default, Debug, Serialize, Deserialize)] pub struct Compiler { node_count: u32, - inputs: HashMap, - outputs: HashMap, - signals: HashMap, - nodes: HashMap, + inputs: BTreeMap, + outputs: BTreeMap, + signals: BTreeMap, + nodes: BTreeMap, gates: Vec, value_type: ValueType, } @@ -118,20 +118,20 @@ impl Compiler { pub fn new() -> Compiler { Compiler { node_count: 0, - inputs: HashMap::new(), - outputs: HashMap::new(), - signals: HashMap::new(), - nodes: HashMap::new(), + inputs: BTreeMap::new(), + outputs: BTreeMap::new(), + signals: BTreeMap::new(), + nodes: BTreeMap::new(), gates: Vec::new(), value_type: Default::default(), } } - pub fn add_inputs(&mut self, inputs: HashMap) { + pub fn add_inputs(&mut self, inputs: BTreeMap) { self.inputs.extend(inputs); } - pub fn add_outputs(&mut self, outputs: HashMap) { + pub fn add_outputs(&mut self, outputs: BTreeMap) { self.outputs.extend(outputs); } @@ -160,8 +160,8 @@ impl Compiler { Ok(()) } - pub fn get_signals(&self, filter: String) -> HashMap { - let mut ret = HashMap::new(); + pub fn get_signals(&self, filter: String) -> BTreeMap { + let mut ret = BTreeMap::new(); for (signal_id, signal) in self.signals.iter() { if signal.name.starts_with(filter.as_str()) { ret.insert(*signal_id, signal.name.to_string()); @@ -320,9 +320,9 @@ impl Compiler { pub fn build_circuit(&self) -> Result { // First build up these maps so we can easily see which node id to use - let mut input_to_node_id = HashMap::::new(); - let mut constant_to_node_id_and_value = HashMap::::new(); - let mut output_to_node_id = HashMap::::new(); + let mut input_to_node_id = BTreeMap::::new(); + let mut constant_to_node_id_and_value = BTreeMap::::new(); + let mut output_to_node_id = BTreeMap::::new(); for (node_id, node) in self.nodes.iter() { // Each node has a list of signal ids which all correspond to that node @@ -368,7 +368,7 @@ impl Compiler { let node_id_to_input_name = input_to_node_id .iter() .map(|(name, node_id)| (node_id, name)) - .collect::>(); + .collect::>(); for (output_name, output_node_id) in &output_to_node_id { if let Some(input_name) = node_id_to_input_name.get(output_node_id) { @@ -385,7 +385,7 @@ impl Compiler { // Now node ids are like wire ids, but the compiler generates them in a way that leaves a // lot of gaps. So we assign new wire ids so they'll be sequential instead. We also do this // ensure inputs are at the start and outputs are at the end. - let mut node_id_to_wire_id = HashMap::::new(); + let mut node_id_to_wire_id = BTreeMap::::new(); let mut next_wire_id = 0; // First inputs @@ -398,7 +398,7 @@ impl Compiler { // assigned in the order they are needed. The topological order is also needed to comply // with bristol format and allow for easy evaluation. - let mut node_id_to_required_gate = HashMap::::new(); + let mut node_id_to_required_gate = BTreeMap::::new(); for (gate_id, gate) in self.gates.iter().enumerate() { // the gate.out node depends on this gate @@ -463,7 +463,7 @@ impl Compiler { }); } - let mut constants = HashMap::::new(); + let mut constants = BTreeMap::::new(); for (name, (node_id, value)) in constant_to_node_id_and_value { constants.insert( @@ -482,7 +482,7 @@ impl Compiler { .iter() .map(|(name, node_id)| (name.clone(), node_id_to_wire_id[node_id] as usize)) .collect(), - constants, + constants: constants.into_iter().collect(), output_name_to_wire_index: output_to_node_id .iter() .map(|(name, node_id)| (name.clone(), node_id_to_wire_id[node_id] as usize)) @@ -628,7 +628,7 @@ mod tests { #[test] fn test_compiler_add_inputs() { let mut compiler = Compiler::new(); - let mut inputs = HashMap::new(); + let mut inputs = BTreeMap::new(); inputs.insert(1, String::from("input1")); inputs.insert(2, String::from("input2")); compiler.add_inputs(inputs); @@ -641,7 +641,7 @@ mod tests { #[test] fn test_compiler_add_outputs() { let mut compiler = Compiler::new(); - let mut outputs = HashMap::new(); + let mut outputs = BTreeMap::new(); outputs.insert(3, String::from("output1")); outputs.insert(4, String::from("output2")); compiler.add_outputs(outputs); diff --git a/src/process.rs b/src/process.rs index a75878d..12075dc 100644 --- a/src/process.rs +++ b/src/process.rs @@ -15,7 +15,7 @@ use circom_program_structure::ast::{ }; use circom_program_structure::program_archive::ProgramArchive; use std::cell::RefCell; -use std::collections::HashMap; +use std::collections::BTreeMap; use std::rc::Rc; /// Processes a sequence of statements. @@ -368,7 +368,7 @@ fn handle_call( // Get return values let mut function_return: Option = None; - let mut component_return: HashMap = HashMap::new(); + let mut component_return: BTreeMap = BTreeMap::new(); if is_function { if let Ok(value) = runtime diff --git a/src/runtime.rs b/src/runtime.rs index 5fb2633..7a79bea 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -7,7 +7,7 @@ use circom_program_structure::ast::VariableType; use rand::{thread_rng, Rng}; use std::{ cell::RefCell, - collections::{HashMap, HashSet, VecDeque}, + collections::{BTreeMap, HashSet, VecDeque}, fmt::Write, rc::Rc, }; @@ -130,9 +130,9 @@ impl Runtime { pub struct Context { ctx_name: String, names: HashSet, - variables: HashMap, - signals: HashMap, - components: HashMap, + variables: BTreeMap, + signals: BTreeMap, + components: BTreeMap, } impl Context { @@ -141,9 +141,9 @@ impl Context { Self { ctx_name, names: HashSet::new(), - variables: HashMap::new(), - signals: HashMap::new(), - components: HashMap::new(), + variables: BTreeMap::new(), + signals: BTreeMap::new(), + components: BTreeMap::new(), } } @@ -352,7 +352,7 @@ impl Context { pub fn get_component_map( &self, access: &DataAccess, - ) -> Result, RuntimeError> { + ) -> Result, RuntimeError> { let component = self .components .get(&access.name) @@ -405,7 +405,7 @@ impl Context { pub fn set_component( &mut self, access: &DataAccess, - map: HashMap, + map: BTreeMap, ) -> Result<(), RuntimeError> { let component = self.components @@ -514,13 +514,13 @@ impl Variable { /// Stores a component's input/output signals with their respective identifiers. #[derive(Clone, Debug)] pub struct Component { - signal_map: NestedValue>, + signal_map: NestedValue>, } impl Component { /// Constructs a new Component as a nested structure based on provided dimensions. fn new(dimensions: &[u32]) -> Self { - let mut signal_map = NestedValue::Value(HashMap::new()); + let mut signal_map = NestedValue::Value(BTreeMap::new()); // Construct the nested structure in reverse order to ensure the correct dimensionality. for &dimension in dimensions.iter().rev() { @@ -532,7 +532,7 @@ impl Component { } /// Retrieves the component signal map at the specified index path. - fn get_map(&self, index_path: &[u32]) -> Result, RuntimeError> { + fn get_map(&self, index_path: &[u32]) -> Result, RuntimeError> { let nested_val = get_nested_value(&self.signal_map, index_path)?; match nested_val { @@ -545,7 +545,7 @@ impl Component { fn set_signal_map( &mut self, component_access: &[u32], - map: HashMap, + map: BTreeMap, ) -> Result<(), RuntimeError> { let nested_val = get_mut_nested_value(&mut self.signal_map, component_access)?; @@ -1040,7 +1040,7 @@ mod tests { ) .unwrap(); - let mut signal_map = HashMap::new(); + let mut signal_map = BTreeMap::new(); let signal = Signal::new(&[], next_signal_id.clone()); signal_map.insert("signal1".to_string(), signal); @@ -1285,7 +1285,7 @@ mod tests { #[test] fn test_component_set_and_get_signal_map() { let mut component = Component::new(&[1]); - let mut signal_map = HashMap::new(); + let mut signal_map = BTreeMap::new(); let signal = Signal::new(&[], Rc::new(RefCell::new(0))); signal_map.insert("signal1".to_string(), signal); @@ -1300,7 +1300,7 @@ mod tests { #[test] fn test_component_get_signal_content() { let mut component = Component::new(&[1]); - let mut signal_map = HashMap::new(); + let mut signal_map = BTreeMap::new(); let signal = Signal::new(&[], Rc::new(RefCell::new(0))); signal_map.insert("signal1".to_string(), signal); @@ -1319,7 +1319,7 @@ mod tests { #[test] fn test_component_get_signal_id() { let mut component = Component::new(&[1]); - let mut signal_map = HashMap::new(); + let mut signal_map = BTreeMap::new(); let signal = Signal::new(&[], Rc::new(RefCell::new(0))); signal_map.insert("signal1".to_string(), signal); @@ -1338,7 +1338,7 @@ mod tests { #[test] fn test_component_nested_signal_map() { let mut component = Component::new(&[2]); - let mut signal_map_0 = HashMap::new(); + let mut signal_map_0 = BTreeMap::new(); let signal_0 = Signal::new(&[], Rc::new(RefCell::new(0))); signal_map_0.insert("signal1".to_string(), signal_0); @@ -1346,7 +1346,7 @@ mod tests { .set_signal_map(&[0], signal_map_0) .expect("Setting signal map failed"); - let mut signal_map_1 = HashMap::new(); + let mut signal_map_1 = BTreeMap::new(); let signal_1 = Signal::new(&[], Rc::new(RefCell::new(1))); signal_map_1.insert("signal2".to_string(), signal_1);