Skip to content

Commit

Permalink
feat: add membership check gadget
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Jan 7, 2025
1 parent 16c5532 commit a4e87a5
Show file tree
Hide file tree
Showing 6 changed files with 344 additions and 2 deletions.
2 changes: 1 addition & 1 deletion crates/proof-of-sql/benches/bench_append_rows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//! ```bash
//! cargo bench --features "test" --bench bench_append_rows
//! ```
#![allow(missing_docs, clippy::missing_docs_in_private_items)]
#![allow(deprecated, missing_docs, clippy::missing_docs_in_private_items)]
use ark_std::test_rng;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use proof_of_sql::{
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/base/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub(super) use column_comparison_operation::{
};

mod column_index_operation;
pub(super) use column_index_operation::apply_column_to_indexes;
pub(crate) use column_index_operation::apply_column_to_indexes;

mod column_repetition_operation;
pub(super) use column_repetition_operation::{ColumnRepeatOp, ElementwiseRepeatOp, RepetitionOp};
Expand Down
13 changes: 13 additions & 0 deletions crates/proof-of-sql/src/base/map.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use core::hash::Hash;
pub(crate) type IndexMap<K, V> =
indexmap::IndexMap<K, V, core::hash::BuildHasherDefault<ahash::AHasher>>;
pub(crate) type IndexSet<T> = indexmap::IndexSet<T, core::hash::BuildHasherDefault<ahash::AHasher>>;
Expand Down Expand Up @@ -40,3 +41,15 @@ macro_rules! indexset_macro {

pub(crate) use indexmap_macro as indexmap;
pub(crate) use indexset_macro as indexset;

/// Create an [`IndexMap`][self::IndexMap] of counts from a list of values
pub(crate) fn counts_as_indexmap<T>(iter: impl IntoIterator<Item = T>) -> IndexMap<T, usize>
where
T: Eq + Hash,
{
let mut map = IndexMap::default();
for val in iter {
*map.entry(val).or_insert(0) += 1;
}
map
}
154 changes: 154 additions & 0 deletions crates/proof-of-sql/src/sql/proof_gadgets/membership_check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
use crate::{
base::{
database::Column, map::counts_as_indexmap, proof::ProofError, scalar::Scalar, slice_ops,
},
sql::{
proof::{
FinalRoundBuilder, FirstRoundBuilder, SumcheckSubpolynomialType, VerificationBuilder,
},
proof_plans::{fold_columns, fold_vals},
},
};
use alloc::{boxed::Box, vec, vec::Vec};
use bumpalo::Bump;
use num_traits::{One, Zero};

/// Perform first round evaluation of the membership check.
#[allow(dead_code)]
pub(crate) fn first_round_evaluate_membership_check<'a, S: Scalar>(
builder: &mut FirstRoundBuilder<'a, S>,
indexes: &[usize],
num_rows: usize,
alloc: &'a Bump,
) {
let multiplicity_map = counts_as_indexmap(indexes.iter());
let multiplicities = (0..num_rows - 1)
.map(|i| multiplicity_map.get(&i).copied().unwrap_or(0) as i128)
.collect::<Vec<_>>();
let alloc_multiplicities = alloc.alloc_slice_copy(&multiplicities);
builder.produce_intermediate_mle(alloc_multiplicities as &[_]);
builder.request_post_result_challenges(2);
}

/// Perform final round evaluation of the membership check.
#[allow(dead_code)]
#[allow(clippy::too_many_arguments)]
pub(crate) fn final_round_evaluate_membership_check<'a, S: Scalar>(
builder: &mut FinalRoundBuilder<'a, S>,
alloc: &'a Bump,
alpha: S,
beta: S,
columns: &[Column<'a, S>],
candidate_subset: &[Column<'a, S>],
indexes: &[usize],
num_rows: usize,
candidate_num_rows: usize,
) {
// 1. Get multiplicity of each index
let multiplicity_map = counts_as_indexmap(indexes.iter());
let multiplicities = (0..num_rows - 1)
.map(|i| multiplicity_map.get(&i).copied().unwrap_or(0) as i128)
.collect::<Vec<_>>();
let alloc_multiplicities = alloc.alloc_slice_copy(&multiplicities);
builder.produce_intermediate_mle(alloc_multiplicities as &[_]);
// 2. Fold the columns
let input_ones = alloc.alloc_slice_fill_copy(num_rows, true);
let candidate_ones = alloc.alloc_slice_fill_copy(candidate_num_rows, true);

let c_fold = alloc.alloc_slice_fill_copy(num_rows, Zero::zero());
fold_columns(c_fold, alpha, beta, columns);
let d_fold = alloc.alloc_slice_fill_copy(candidate_num_rows, Zero::zero());
fold_columns(d_fold, alpha, beta, candidate_subset);

let c_star = alloc.alloc_slice_copy(c_fold);
slice_ops::add_const::<S, S>(c_star, One::one());
slice_ops::batch_inversion(c_star);

let d_star = alloc.alloc_slice_copy(d_fold);
slice_ops::add_const::<S, S>(d_star, One::one());
slice_ops::batch_inversion(d_star);

builder.produce_intermediate_mle(c_star as &[_]);
builder.produce_intermediate_mle(d_star as &[_]);

// sum c_star * multiplicities - d_star = 0
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::ZeroSum,
vec![
(
S::one(),
vec![
Box::new(c_star as &[_]),
Box::new(alloc_multiplicities as &[_]),
],
),
(-S::one(), vec![Box::new(d_star as &[_])]),
],
);

// c_star + c_fold * c_star - input_ones = 0
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![
(S::one(), vec![Box::new(c_star as &[_])]),
(
S::one(),
vec![Box::new(c_star as &[_]), Box::new(c_fold as &[_])],
),
(-S::one(), vec![Box::new(input_ones as &[_])]),
],
);

// d_star + d_fold * d_star - candidate_ones = 0
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![
(S::one(), vec![Box::new(d_star as &[_])]),
(
S::one(),
vec![Box::new(d_star as &[_]), Box::new(d_fold as &[_])],
),
(-S::one(), vec![Box::new(candidate_ones as &[_])]),
],
);
}

#[allow(dead_code)]
pub(crate) fn verify_membership_check<S: Scalar>(
builder: &mut VerificationBuilder<S>,
alpha: S,
beta: S,
input_one_eval: S,
candidate_one_eval: S,
column_evals: &[S],
candidate_evals: &[S],
) -> Result<(), ProofError> {
let multiplicity_eval = builder.try_consume_first_round_mle_evaluation()?;
let c_fold_eval = alpha * fold_vals(beta, column_evals);
let d_fold_eval = alpha * fold_vals(beta, candidate_evals);
let c_star_eval = builder.try_consume_final_round_mle_evaluation()?;
let d_star_eval = builder.try_consume_final_round_mle_evaluation()?;

// sum c_star * multiplicities - d_star = 0
builder.try_produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::ZeroSum,
c_star_eval * multiplicity_eval - d_star_eval,
2,
)?;

// c_star + c_fold * c_star - input_ones = 0
builder.try_produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::Identity,
c_star_eval + c_fold_eval * c_star_eval - input_one_eval,
2,
)?;

// d_star + d_fold * d_star - candidate_ones = 0
builder.try_produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::Identity,
d_star_eval + d_fold_eval * d_star_eval - candidate_one_eval,
2,
)?;

Ok(())
}
167 changes: 167 additions & 0 deletions crates/proof-of-sql/src/sql/proof_gadgets/membership_check_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
//! This module contains the implementation of the `MembershipCheckTestPlan` struct. This struct
//! is used to check whether the membership check gadgets work correctly.
use super::membership_check::{
final_round_evaluate_membership_check, first_round_evaluate_membership_check,
verify_membership_check,
};
use crate::{
base::{
database::{
apply_column_to_indexes, Column, ColumnField, ColumnRef, OwnedTable, Table,
TableEvaluation, TableOptions, TableRef,
},
map::{IndexMap, IndexSet},
proof::ProofError,
scalar::Scalar,
},
sql::proof::{
FinalRoundBuilder, FirstRoundBuilder, ProofPlan, ProverEvaluate, VerificationBuilder,
},
};
use bumpalo::{
collections::{CollectIn, Vec as BumpVec},
Bump,
};
use serde::Serialize;

#[derive(Debug, Serialize)]
pub struct MembershipCheckTestPlan {
pub columns: [[i64; 2]; 2],
pub indexes: Vec<usize>,
}

impl ProverEvaluate for MembershipCheckTestPlan {
#[doc = "Evaluate the query, modify `FirstRoundBuilder` and return the result."]
fn first_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FirstRoundBuilder<'a, S>,
alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> Table<'a, S> {
// Produce one evaluation lengths
builder.produce_one_evaluation_length(2);
builder.produce_one_evaluation_length(self.indexes.len());
// Evaluate the first round
first_round_evaluate_membership_check(builder, &self.indexes, 2, alloc);
// This is just a dummy table, the actual data is not used
Table::try_new_with_options(IndexMap::default(), TableOptions { row_count: Some(0) })
.unwrap()
}

fn final_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FinalRoundBuilder<'a, S>,
alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> Table<'a, S> {
let alpha = builder.consume_post_result_challenge();
let beta = builder.consume_post_result_challenge();
// Build the Columns
let final_columns = self
.columns
.iter()
.map(|raw| {
let col = Column::BigInt(alloc.alloc_slice_copy(raw));
builder.produce_intermediate_mle(col);
col
})
.collect_in::<BumpVec<_>>(alloc);
// Build the candidate subset
let candidate_subset = self
.columns
.iter()
.map(|raw| {
let col = Column::BigInt(alloc.alloc_slice_copy(raw));
let col_with_indexes = apply_column_to_indexes(&col, alloc, &self.indexes).unwrap();
builder.produce_intermediate_mle(col_with_indexes);
col_with_indexes
})
.collect_in::<BumpVec<_>>(alloc);
// Perform final membership check
final_round_evaluate_membership_check(
builder,
alloc,
alpha,
beta,
&final_columns,
&candidate_subset,
&self.indexes,
2,
self.indexes.len(),
);

// Return a dummy table
Table::try_new_with_options(IndexMap::default(), TableOptions { row_count: Some(0) })
.unwrap()
}
}

impl ProofPlan for MembershipCheckTestPlan {
fn get_column_result_fields(&self) -> Vec<ColumnField> {
vec![]
}

fn get_column_references(&self) -> IndexSet<ColumnRef> {
IndexSet::default()
}

#[doc = "Return all the tables referenced in the Query"]
fn get_table_references(&self) -> IndexSet<TableRef> {
IndexSet::default()
}

#[doc = "Form components needed to verify and proof store into `VerificationBuilder`"]
fn verifier_evaluate<S: Scalar>(
&self,
builder: &mut VerificationBuilder<S>,
_accessor: &IndexMap<ColumnRef, S>,
_result: Option<&OwnedTable<S>>,
_one_eval_map: &IndexMap<TableRef, S>,
) -> Result<TableEvaluation<S>, ProofError> {
// Get the challenges from the builder
let alpha = builder.try_consume_post_result_challenge()?;
let beta = builder.try_consume_post_result_challenge()?;
let num_columns = self.columns.len();
// Get the columns
let column_evals = builder.try_consume_final_round_mle_evaluations(num_columns)?;
// Get the target columns
let candidate_subset_evals =
builder.try_consume_final_round_mle_evaluations(num_columns)?;
// Get the one evaluations
let one_eval = builder.try_consume_one_evaluation()?;
let candidate_subset_one_eval = builder.try_consume_one_evaluation()?;
// Evaluate the verifier
verify_membership_check(
builder,
alpha,
beta,
one_eval,
candidate_subset_one_eval,
&column_evals,
&candidate_subset_evals,
)?;
Ok(TableEvaluation::new(vec![], S::zero()))
}
}

#[cfg(all(test, feature = "blitzar"))]
mod tests {
use super::*;
use crate::{
base::database::{TableTestAccessor, TestAccessor},
sql::proof::VerifiableQueryResult,
};
use blitzar::proof::InnerProductProof;

#[test]
fn we_can_do_membership_check() {
let indexes = vec![0, 0, 1, 1, 0, 1, 1, 0];
let plan = MembershipCheckTestPlan {
columns: [[1, 2], [3, 4]],
indexes,
};
let accessor = TableTestAccessor::<InnerProductProof>::new_empty();
let verifiable_res = VerifiableQueryResult::<InnerProductProof>::new(&plan, &accessor, &());
assert!(verifiable_res.verify(&plan, &accessor, &()).is_ok());
}
}
8 changes: 8 additions & 0 deletions crates/proof-of-sql/src/sql/proof_gadgets/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ mod bitwise_verification;
use bitwise_verification::{verify_constant_abs_decomposition, verify_constant_sign_decomposition};
#[cfg(test)]
mod bitwise_verification_test;
mod membership_check;
#[allow(unused_imports)]
use membership_check::{
final_round_evaluate_membership_check, first_round_evaluate_membership_check,
verify_membership_check,
};
#[cfg(test)]
mod membership_check_test;
mod sign_expr;
pub(crate) use sign_expr::{prover_evaluate_sign, result_evaluate_sign, verifier_evaluate_sign};
pub mod range_check;
Expand Down

0 comments on commit a4e87a5

Please sign in to comment.