Skip to content

add documentation #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 17, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 149 additions & 15 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,25 @@ use reikna::totient::totient;
use reikna::factor::quick_factorize;
use std::collections::HashMap;

// Modular arithmetic functions using i64
/// Modular arithmetic functions using i64
fn mod_add(a: i64, b: i64, p: i64) -> i64 {
(a + b) % p
}

/// Modular multiplication
fn mod_mul(a: i64, b: i64, p: i64) -> i64 {
(a * b) % p
}

/// Modular exponentiation
/// # Arguments
///
/// * `base` - Base of the exponentiation.
/// * `exp` - Exponent.
/// * `p` - Prime modulus for the operations.
///
/// # Returns
/// The result of the exponentiation modulo `p`.
pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 {
let mut result = 1;
base %= p;
Expand All @@ -24,6 +34,14 @@ pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 {
result
}

/// Extended Euclidean algorithm
/// # Arguments
///
/// * `a` - First number.
/// * `b` - Second number.
///
/// # Returns
/// A tuple with the greatest common divisor and the Bézout coefficients.
fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
if b == 0 {
(a, 1, 0) // gcd, x, y
Expand All @@ -33,15 +51,38 @@ fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
}
}

pub fn mod_inv(a: i64, modulus: i64) -> i64 {
/// Compute the modular inverse of a modulo modulus
fn mod_inv(a: i64, modulus: i64) -> i64 {
let (gcd, x, _) = extended_gcd(a, modulus);
if gcd != 1 {
panic!("{} and {} are not coprime, no inverse exists", a, modulus);
}
(x % modulus + modulus) % modulus // Ensure a positive result
}

// Compute n-th root of unity (omega) for p not necessarily prime
/// Compute n-th root of unity (omega) for p not necessarily prime
/// # Arguments
///
/// * `modulus` - Modulus. n must divide each prime power factor.
/// * `n` - Order of the root of unity.
///
/// # Returns
/// The n-th root of unity modulo `modulus`.
///
/// # Examples
///
/// ```
/// // For modulus = 17^2 = 289, we compute and verify an 8th root of unity.
/// let modulus = 17 * 17;
/// let n = 8;
/// let omega = ntt::omega(modulus, n);
/// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus));
///
/// // For modulus = 17*41*73, we compute and verify an 8th root of unity.
/// let modulus = 17*41*73;
/// let omega = ntt::omega(modulus, n);
/// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus));
/// ```
pub fn omega(modulus: i64, n: usize) -> i64 {
let factors = factorize(modulus as i64);
if factors.len() == 1 {
Expand All @@ -56,7 +97,29 @@ pub fn omega(modulus: i64, n: usize) -> i64 {
}
}

// Forward transform using NTT, output bit-reversed
/// Forward transform using NTT, output bit-reversed
/// # Arguments
///
/// * `a` - Input vector.
/// * `omega` - Primitive root of unity modulo `p`.
/// * `n` - Length of the input vector and the result.
/// * `p` - Prime modulus for the operations.
///
/// # Returns
/// A vector representing the NTT of the input vector.
///
/// # Examples
///
/// ```
/// let modulus: i64 = 17; // modulus, n must divide phi(p^k) for each prime factor p
/// let n: usize = 8; // Length of the NTT (must be a power of 2)
/// let omega = ntt::omega(modulus, n); // n-th root of unity
/// let mut a = vec![1, 2, 3, 4];
/// a.resize(n, 0);
/// // Perform the forward NTT
/// let a_ntt = ntt::ntt(&a, omega, n, modulus);
/// let a_ntt_expected = vec![10, 15, 6, 7, 16, 13, 11, 15];
/// assert_eq!(a_ntt, a_ntt_expected);
pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
let mut result = a.to_vec();
let mut step = n/2;
Expand All @@ -77,7 +140,16 @@ pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
result
}

// Inverse transform using INTT, input bit-reversed
/// Inverse transform using INTT, input bit-reversed
/// # Arguments
///
/// * `a` - Input vector (bit-reversed).
/// * `omega` - Primitive root of unity modulo `p`.
/// * `n` - Length of the input vector and the result.
/// * `p` - Prime modulus for the operations.
///
/// # Returns
/// A vector representing the inverse NTT of the input vector.
pub fn intt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
let omega_inv = mod_inv(omega, p);
let n_inv = mod_inv(n as i64, p);
Expand All @@ -103,7 +175,16 @@ pub fn intt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
.collect()
}

// Naive polynomial multiplication
/// Naive polynomial multiplication
/// # Arguments
///
/// * `a` - First polynomial (as a vector of coefficients).
/// * `b` - Second polynomial (as a vector of coefficients).
/// * `n` - Length of the polynomials and the result.
/// * `p` - Prime modulus for the operations.
///
/// # Returns
/// A vector representing the polynomial product modulo `p`.
pub fn polymul(a: &Vec<i64>, b: &Vec<i64>, n: i64, p: i64) -> Vec<i64> {
let mut result = vec![0; n as usize];
for i in 0..a.len() {
Expand Down Expand Up @@ -145,7 +226,14 @@ pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, omega: i64) -> Vec<i6
c
}

/// Compute the prime factorization of `n` (with multiplicities).
/// Compute the prime factorization of `n` (with multiplicities)
/// Uses reikna::quick_factorize internally
/// # Arguments
///
/// * `n` - Number to factorize.
///
/// # Returns
/// A HashMap with the prime factors of `n` as keys and their multiplicities as values.
fn factorize(n: i64) -> HashMap<i64, u32> {
let mut factors = HashMap::new();
for factor in quick_factorize(n as u64) {
Expand All @@ -155,6 +243,23 @@ fn factorize(n: i64) -> HashMap<i64, u32> {
}

/// Fast computation of a primitive root mod p^e
/// Computes a primitive root mod p and lifts it to p^e by adding successive powers of p
/// # Arguments
///
/// * `p` - Prime modulus.
/// * `e` - Exponent.
///
/// # Returns
/// A primitive root modulo `p^e`.
///
/// # Examples
///
/// ```
/// // For p = 17 and e = 2, we compute a primitive root modulo 289.
/// let p = 17;
/// let e = 2;
/// let g = ntt::primitive_root(p, e);
/// assert_eq!(ntt::mod_exp(g, p*(p-1), p*p), 1);
pub fn primitive_root(p: i64, e: u32) -> i64 {
let g = primitive_root_mod_p(p);
let mut g_lifted = g; // Lift it to p^e
Expand All @@ -167,6 +272,12 @@ pub fn primitive_root(p: i64, e: u32) -> i64 {
}

/// Finds a primitive root modulo a prime p
/// # Arguments
///
/// * `p` - Prime modulus.
///
/// # Returns
/// A primitive root modulo `p`.
fn primitive_root_mod_p(p: i64) -> i64 {
let phi = p - 1;
let factors = factorize(phi); // Reusing factorize to get both prime factors and multiplicities
Expand All @@ -179,7 +290,16 @@ fn primitive_root_mod_p(p: i64) -> i64 {
0 // Should never happen
}

// the Chinese remainder theorem for two moduli
/// the Chinese remainder theorem for two moduli
/// # Arguments
///
/// * `a1` - First residue.
/// * `n1` - First modulus.
/// * `a2` - Second residue.
/// * `n2` - Second modulus.
///
/// # Returns
/// The solution to the system of congruences x = a1 (mod n1) and x = a2 (mod n2).
pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 {
let n = n1 * n2;
let m1 = mod_inv(n1, n2); // Inverse of n1 mod n2
Expand All @@ -188,10 +308,17 @@ pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 {
if x < 0 { x + n } else { x }
}

// computes an n^th root of unity modulo a composite modulus
// note we require that an n^th root of unity exists for each multiplicative group modulo p^e
// use the CRT isomorphism to pull back each n^th root of unity to the composite modulus
// for the NTT, we require than a 2n^th root of unity exists
/// computes an n^th root of unity modulo a composite modulus
/// note we require that an n^th root of unity exists for each multiplicative group modulo p^e
/// use the CRT isomorphism to pull back the list of n^th roots of unity to the composite modulus
/// for the NTT, we require than a 2n^th root of unity exists
/// # Arguments
///
/// * `modulus` - Modulus. n must divide each prime power factor.
/// * `n` - Order of the root of unity.
///
/// # Returns
/// The n-th root of unity modulo `modulus`.
pub fn root_of_unity(modulus: i64, n: i64) -> i64 {
let factors = factorize(modulus);
let mut result = 1;
Expand All @@ -202,10 +329,17 @@ pub fn root_of_unity(modulus: i64, n: i64) -> i64 {
result
}

//ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n
/// ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n
/// # Arguments
///
/// * `omega` - n-th root of unity.
/// * `n` - Order of the root of unity.
/// * `modulus` - Modulus.
///
/// # Returns
/// True if the root of unity satisfies the condition.
pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool {
assert!(mod_exp(omega, n, modulus as i64) == 1, "omega is not an n-th root of unity");
assert!(mod_exp(omega, n/2, modulus as i64) == modulus-1, "omgea^(n/2) != -1 (mod modulus)");
true
}

}