Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 6 additions & 50 deletions input/circuit.circom
Original file line number Diff line number Diff line change
Expand Up @@ -2,56 +2,12 @@

pragma circom 2.0.0;

template Switcher() {
signal input sel;
signal input L;
signal input R;
signal output outL;
signal output outR;
template Test() {
signal input a;
signal input b;
signal output c;

signal aux;

aux <== (R-L)*sel; // We create aux in order to have only one multiplication
outL <== aux + L;
outR <== -aux + R;
}

template ArgMax (n) {
signal input in[n];
signal output out;

// assert (out < n);
signal gts[n]; // store comparators
component switchers[n+1]; // switcher for comparing maxs
component aswitchers[n+1]; // switcher for arg max

signal maxs[n+1];
signal amaxs[n+1];

maxs[0] <== in[0];
amaxs[0] <== 0;
for(var i = 0; i < n; i++) {
gts[i] <== in[i] > maxs[i]; // changed to 252 (maximum) for better compatibility
switchers[i+1] = Switcher();
aswitchers[i+1] = Switcher();

switchers[i+1].sel <== gts[i];
switchers[i+1].L <== maxs[i];
switchers[i+1].R <== in[i];

aswitchers[i+1].sel <== gts[i];
aswitchers[i+1].L <== amaxs[i];
aswitchers[i+1].R <== i;
amaxs[i+1] <== aswitchers[i+1].outL;
maxs[i+1] <== switchers[i+1].outL;
}

out <== amaxs[n];
c <== a * b;
}

component main = ArgMax(2);

/* INPUT = {
"in": ["2","3","1","5","4"],
"out": "3"
} */
component main = Test();
46 changes: 23 additions & 23 deletions src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
use bristol_circuit::{BristolCircuit, CircuitInfo, ConstantInfo, Gate};
use log::debug;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::collections::{BTreeMap, HashSet};
use thiserror::Error;

/// Represents a signal in the circuit, with a name and an optional value.
Expand Down Expand Up @@ -106,10 +106,10 @@ impl ArithmeticGate {
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct Compiler {
node_count: u32,
inputs: HashMap<u32, String>,
outputs: HashMap<u32, String>,
signals: HashMap<u32, Signal>,
nodes: HashMap<u32, Node>,
inputs: BTreeMap<u32, String>,
outputs: BTreeMap<u32, String>,
signals: BTreeMap<u32, Signal>,
nodes: BTreeMap<u32, Node>,
gates: Vec<ArithmeticGate>,
value_type: ValueType,
}
Expand All @@ -118,20 +118,20 @@ impl Compiler {
pub fn new() -> Compiler {
Compiler {
node_count: 0,
inputs: HashMap::new(),
outputs: HashMap::new(),
signals: HashMap::new(),
nodes: HashMap::new(),
inputs: BTreeMap::new(),
outputs: BTreeMap::new(),
signals: BTreeMap::new(),
nodes: BTreeMap::new(),
gates: Vec::new(),
value_type: Default::default(),
}
}

pub fn add_inputs(&mut self, inputs: HashMap<u32, String>) {
pub fn add_inputs(&mut self, inputs: BTreeMap<u32, String>) {
self.inputs.extend(inputs);
}

pub fn add_outputs(&mut self, outputs: HashMap<u32, String>) {
pub fn add_outputs(&mut self, outputs: BTreeMap<u32, String>) {
self.outputs.extend(outputs);
}

Expand Down Expand Up @@ -160,8 +160,8 @@ impl Compiler {
Ok(())
}

pub fn get_signals(&self, filter: String) -> HashMap<u32, String> {
let mut ret = HashMap::new();
pub fn get_signals(&self, filter: String) -> BTreeMap<u32, String> {
let mut ret = BTreeMap::new();
for (signal_id, signal) in self.signals.iter() {
if signal.name.starts_with(filter.as_str()) {
ret.insert(*signal_id, signal.name.to_string());
Expand Down Expand Up @@ -320,9 +320,9 @@ impl Compiler {

pub fn build_circuit(&self) -> Result<BristolCircuit, CircuitError> {
// First build up these maps so we can easily see which node id to use
let mut input_to_node_id = HashMap::<String, u32>::new();
let mut constant_to_node_id_and_value = HashMap::<String, (u32, String)>::new();
let mut output_to_node_id = HashMap::<String, u32>::new();
let mut input_to_node_id = BTreeMap::<String, u32>::new();
let mut constant_to_node_id_and_value = BTreeMap::<String, (u32, String)>::new();
let mut output_to_node_id = BTreeMap::<String, u32>::new();

for (node_id, node) in self.nodes.iter() {
// Each node has a list of signal ids which all correspond to that node
Expand Down Expand Up @@ -368,7 +368,7 @@ impl Compiler {
let node_id_to_input_name = input_to_node_id
.iter()
.map(|(name, node_id)| (node_id, name))
.collect::<HashMap<_, _>>();
.collect::<BTreeMap<_, _>>();

for (output_name, output_node_id) in &output_to_node_id {
if let Some(input_name) = node_id_to_input_name.get(output_node_id) {
Expand All @@ -385,7 +385,7 @@ impl Compiler {
// Now node ids are like wire ids, but the compiler generates them in a way that leaves a
// lot of gaps. So we assign new wire ids so they'll be sequential instead. We also do this
// ensure inputs are at the start and outputs are at the end.
let mut node_id_to_wire_id = HashMap::<u32, u32>::new();
let mut node_id_to_wire_id = BTreeMap::<u32, u32>::new();
let mut next_wire_id = 0;

// First inputs
Expand All @@ -398,7 +398,7 @@ impl Compiler {
// assigned in the order they are needed. The topological order is also needed to comply
// with bristol format and allow for easy evaluation.

let mut node_id_to_required_gate = HashMap::<u32, usize>::new();
let mut node_id_to_required_gate = BTreeMap::<u32, usize>::new();

for (gate_id, gate) in self.gates.iter().enumerate() {
// the gate.out node depends on this gate
Expand Down Expand Up @@ -463,7 +463,7 @@ impl Compiler {
});
}

let mut constants = HashMap::<String, ConstantInfo>::new();
let mut constants = BTreeMap::<String, ConstantInfo>::new();

for (name, (node_id, value)) in constant_to_node_id_and_value {
constants.insert(
Expand All @@ -482,7 +482,7 @@ impl Compiler {
.iter()
.map(|(name, node_id)| (name.clone(), node_id_to_wire_id[node_id] as usize))
.collect(),
constants,
constants: constants.into_iter().collect(),
output_name_to_wire_index: output_to_node_id
.iter()
.map(|(name, node_id)| (name.clone(), node_id_to_wire_id[node_id] as usize))
Expand Down Expand Up @@ -628,7 +628,7 @@ mod tests {
#[test]
fn test_compiler_add_inputs() {
let mut compiler = Compiler::new();
let mut inputs = HashMap::new();
let mut inputs = BTreeMap::new();
inputs.insert(1, String::from("input1"));
inputs.insert(2, String::from("input2"));
compiler.add_inputs(inputs);
Expand All @@ -641,7 +641,7 @@ mod tests {
#[test]
fn test_compiler_add_outputs() {
let mut compiler = Compiler::new();
let mut outputs = HashMap::new();
let mut outputs = BTreeMap::new();
outputs.insert(3, String::from("output1"));
outputs.insert(4, String::from("output2"));
compiler.add_outputs(outputs);
Expand Down
4 changes: 2 additions & 2 deletions src/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use circom_program_structure::ast::{
};
use circom_program_structure::program_archive::ProgramArchive;
use std::cell::RefCell;
use std::collections::HashMap;
use std::collections::BTreeMap;
use std::rc::Rc;

/// Processes a sequence of statements.
Expand Down Expand Up @@ -368,7 +368,7 @@ fn handle_call(

// Get return values
let mut function_return: Option<u32> = None;
let mut component_return: HashMap<String, Signal> = HashMap::new();
let mut component_return: BTreeMap<String, Signal> = BTreeMap::new();

if is_function {
if let Ok(value) = runtime
Expand Down
38 changes: 19 additions & 19 deletions src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use circom_program_structure::ast::VariableType;
use rand::{thread_rng, Rng};
use std::{
cell::RefCell,
collections::{HashMap, HashSet, VecDeque},
collections::{BTreeMap, HashSet, VecDeque},
fmt::Write,
rc::Rc,
};
Expand Down Expand Up @@ -130,9 +130,9 @@ impl Runtime {
pub struct Context {
ctx_name: String,
names: HashSet<String>,
variables: HashMap<String, Variable>,
signals: HashMap<String, Signal>,
components: HashMap<String, Component>,
variables: BTreeMap<String, Variable>,
signals: BTreeMap<String, Signal>,
components: BTreeMap<String, Component>,
}

impl Context {
Expand All @@ -141,9 +141,9 @@ impl Context {
Self {
ctx_name,
names: HashSet::new(),
variables: HashMap::new(),
signals: HashMap::new(),
components: HashMap::new(),
variables: BTreeMap::new(),
signals: BTreeMap::new(),
components: BTreeMap::new(),
}
}

Expand Down Expand Up @@ -352,7 +352,7 @@ impl Context {
pub fn get_component_map(
&self,
access: &DataAccess,
) -> Result<HashMap<String, Signal>, RuntimeError> {
) -> Result<BTreeMap<String, Signal>, RuntimeError> {
let component = self
.components
.get(&access.name)
Expand Down Expand Up @@ -405,7 +405,7 @@ impl Context {
pub fn set_component(
&mut self,
access: &DataAccess,
map: HashMap<String, Signal>,
map: BTreeMap<String, Signal>,
) -> Result<(), RuntimeError> {
let component =
self.components
Expand Down Expand Up @@ -514,13 +514,13 @@ impl Variable {
/// Stores a component's input/output signals with their respective identifiers.
#[derive(Clone, Debug)]
pub struct Component {
signal_map: NestedValue<HashMap<String, Signal>>,
signal_map: NestedValue<BTreeMap<String, Signal>>,
}

impl Component {
/// Constructs a new Component as a nested structure based on provided dimensions.
fn new(dimensions: &[u32]) -> Self {
let mut signal_map = NestedValue::Value(HashMap::new());
let mut signal_map = NestedValue::Value(BTreeMap::new());

// Construct the nested structure in reverse order to ensure the correct dimensionality.
for &dimension in dimensions.iter().rev() {
Expand All @@ -532,7 +532,7 @@ impl Component {
}

/// Retrieves the component signal map at the specified index path.
fn get_map(&self, index_path: &[u32]) -> Result<HashMap<String, Signal>, RuntimeError> {
fn get_map(&self, index_path: &[u32]) -> Result<BTreeMap<String, Signal>, RuntimeError> {
let nested_val = get_nested_value(&self.signal_map, index_path)?;

match nested_val {
Expand All @@ -545,7 +545,7 @@ impl Component {
fn set_signal_map(
&mut self,
component_access: &[u32],
map: HashMap<String, Signal>,
map: BTreeMap<String, Signal>,
) -> Result<(), RuntimeError> {
let nested_val = get_mut_nested_value(&mut self.signal_map, component_access)?;

Expand Down Expand Up @@ -1040,7 +1040,7 @@ mod tests {
)
.unwrap();

let mut signal_map = HashMap::new();
let mut signal_map = BTreeMap::new();
let signal = Signal::new(&[], next_signal_id.clone());
signal_map.insert("signal1".to_string(), signal);

Expand Down Expand Up @@ -1285,7 +1285,7 @@ mod tests {
#[test]
fn test_component_set_and_get_signal_map() {
let mut component = Component::new(&[1]);
let mut signal_map = HashMap::new();
let mut signal_map = BTreeMap::new();
let signal = Signal::new(&[], Rc::new(RefCell::new(0)));
signal_map.insert("signal1".to_string(), signal);

Expand All @@ -1300,7 +1300,7 @@ mod tests {
#[test]
fn test_component_get_signal_content() {
let mut component = Component::new(&[1]);
let mut signal_map = HashMap::new();
let mut signal_map = BTreeMap::new();
let signal = Signal::new(&[], Rc::new(RefCell::new(0)));
signal_map.insert("signal1".to_string(), signal);

Expand All @@ -1319,7 +1319,7 @@ mod tests {
#[test]
fn test_component_get_signal_id() {
let mut component = Component::new(&[1]);
let mut signal_map = HashMap::new();
let mut signal_map = BTreeMap::new();
let signal = Signal::new(&[], Rc::new(RefCell::new(0)));
signal_map.insert("signal1".to_string(), signal);

Expand All @@ -1338,15 +1338,15 @@ mod tests {
#[test]
fn test_component_nested_signal_map() {
let mut component = Component::new(&[2]);
let mut signal_map_0 = HashMap::new();
let mut signal_map_0 = BTreeMap::new();
let signal_0 = Signal::new(&[], Rc::new(RefCell::new(0)));
signal_map_0.insert("signal1".to_string(), signal_0);

component
.set_signal_map(&[0], signal_map_0)
.expect("Setting signal map failed");

let mut signal_map_1 = HashMap::new();
let mut signal_map_1 = BTreeMap::new();
let signal_1 = Signal::new(&[], Rc::new(RefCell::new(1)));
signal_map_1.insert("signal2".to_string(), signal_1);

Expand Down