@@ -2,15 +2,25 @@ use reikna::totient::totient;
2
2
use reikna:: factor:: quick_factorize;
3
3
use std:: collections:: HashMap ;
4
4
5
- // Modular arithmetic functions using i64
5
+ /// Modular arithmetic functions using i64
6
6
fn mod_add ( a : i64 , b : i64 , p : i64 ) -> i64 {
7
7
( a + b) % p
8
8
}
9
9
10
+ /// Modular multiplication
10
11
fn mod_mul ( a : i64 , b : i64 , p : i64 ) -> i64 {
11
12
( a * b) % p
12
13
}
13
14
15
+ /// Modular exponentiation
16
+ /// # Arguments
17
+ ///
18
+ /// * `base` - Base of the exponentiation.
19
+ /// * `exp` - Exponent.
20
+ /// * `p` - Prime modulus for the operations.
21
+ ///
22
+ /// # Returns
23
+ /// The result of the exponentiation modulo `p`.
14
24
pub fn mod_exp ( mut base : i64 , mut exp : i64 , p : i64 ) -> i64 {
15
25
let mut result = 1 ;
16
26
base %= p;
@@ -24,6 +34,14 @@ pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 {
24
34
result
25
35
}
26
36
37
+ /// Extended Euclidean algorithm
38
+ /// # Arguments
39
+ ///
40
+ /// * `a` - First number.
41
+ /// * `b` - Second number.
42
+ ///
43
+ /// # Returns
44
+ /// A tuple with the greatest common divisor and the Bézout coefficients.
27
45
fn extended_gcd ( a : i64 , b : i64 ) -> ( i64 , i64 , i64 ) {
28
46
if b == 0 {
29
47
( a, 1 , 0 ) // gcd, x, y
@@ -33,15 +51,38 @@ fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
33
51
}
34
52
}
35
53
36
- pub fn mod_inv ( a : i64 , modulus : i64 ) -> i64 {
54
+ /// Compute the modular inverse of a modulo modulus
55
+ fn mod_inv ( a : i64 , modulus : i64 ) -> i64 {
37
56
let ( gcd, x, _) = extended_gcd ( a, modulus) ;
38
57
if gcd != 1 {
39
58
panic ! ( "{} and {} are not coprime, no inverse exists" , a, modulus) ;
40
59
}
41
60
( x % modulus + modulus) % modulus // Ensure a positive result
42
61
}
43
62
44
- // Compute n-th root of unity (omega) for p not necessarily prime
63
+ /// Compute n-th root of unity (omega) for p not necessarily prime
64
+ /// # Arguments
65
+ ///
66
+ /// * `modulus` - Modulus. n must divide each prime power factor.
67
+ /// * `n` - Order of the root of unity.
68
+ ///
69
+ /// # Returns
70
+ /// The n-th root of unity modulo `modulus`.
71
+ ///
72
+ /// # Examples
73
+ ///
74
+ /// ```
75
+ /// // For modulus = 17^2 = 289, we compute and verify an 8th root of unity.
76
+ /// let modulus = 17 * 17;
77
+ /// let n = 8;
78
+ /// let omega = ntt::omega(modulus, n);
79
+ /// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus));
80
+ ///
81
+ /// // For modulus = 17*41*73, we compute and verify an 8th root of unity.
82
+ /// let modulus = 17*41*73;
83
+ /// let omega = ntt::omega(modulus, n);
84
+ /// assert!(ntt::verify_root_of_unity(omega,n.try_into().unwrap(),modulus));
85
+ /// ```
45
86
pub fn omega ( modulus : i64 , n : usize ) -> i64 {
46
87
let factors = factorize ( modulus as i64 ) ;
47
88
if factors. len ( ) == 1 {
@@ -56,7 +97,29 @@ pub fn omega(modulus: i64, n: usize) -> i64 {
56
97
}
57
98
}
58
99
59
- // Forward transform using NTT, output bit-reversed
100
+ /// Forward transform using NTT, output bit-reversed
101
+ /// # Arguments
102
+ ///
103
+ /// * `a` - Input vector.
104
+ /// * `omega` - Primitive root of unity modulo `p`.
105
+ /// * `n` - Length of the input vector and the result.
106
+ /// * `p` - Prime modulus for the operations.
107
+ ///
108
+ /// # Returns
109
+ /// A vector representing the NTT of the input vector.
110
+ ///
111
+ /// # Examples
112
+ ///
113
+ /// ```
114
+ /// let modulus: i64 = 17; // modulus, n must divide phi(p^k) for each prime factor p
115
+ /// let n: usize = 8; // Length of the NTT (must be a power of 2)
116
+ /// let omega = ntt::omega(modulus, n); // n-th root of unity
117
+ /// let mut a = vec![1, 2, 3, 4];
118
+ /// a.resize(n, 0);
119
+ /// // Perform the forward NTT
120
+ /// let a_ntt = ntt::ntt(&a, omega, n, modulus);
121
+ /// let a_ntt_expected = vec![10, 15, 6, 7, 16, 13, 11, 15];
122
+ /// assert_eq!(a_ntt, a_ntt_expected);
60
123
pub fn ntt ( a : & [ i64 ] , omega : i64 , n : usize , p : i64 ) -> Vec < i64 > {
61
124
let mut result = a. to_vec ( ) ;
62
125
let mut step = n/2 ;
@@ -77,7 +140,16 @@ pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
77
140
result
78
141
}
79
142
80
- // Inverse transform using INTT, input bit-reversed
143
+ /// Inverse transform using INTT, input bit-reversed
144
+ /// # Arguments
145
+ ///
146
+ /// * `a` - Input vector (bit-reversed).
147
+ /// * `omega` - Primitive root of unity modulo `p`.
148
+ /// * `n` - Length of the input vector and the result.
149
+ /// * `p` - Prime modulus for the operations.
150
+ ///
151
+ /// # Returns
152
+ /// A vector representing the inverse NTT of the input vector.
81
153
pub fn intt ( a : & [ i64 ] , omega : i64 , n : usize , p : i64 ) -> Vec < i64 > {
82
154
let omega_inv = mod_inv ( omega, p) ;
83
155
let n_inv = mod_inv ( n as i64 , p) ;
@@ -103,7 +175,16 @@ pub fn intt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
103
175
. collect ( )
104
176
}
105
177
106
- // Naive polynomial multiplication
178
+ /// Naive polynomial multiplication
179
+ /// # Arguments
180
+ ///
181
+ /// * `a` - First polynomial (as a vector of coefficients).
182
+ /// * `b` - Second polynomial (as a vector of coefficients).
183
+ /// * `n` - Length of the polynomials and the result.
184
+ /// * `p` - Prime modulus for the operations.
185
+ ///
186
+ /// # Returns
187
+ /// A vector representing the polynomial product modulo `p`.
107
188
pub fn polymul ( a : & Vec < i64 > , b : & Vec < i64 > , n : i64 , p : i64 ) -> Vec < i64 > {
108
189
let mut result = vec ! [ 0 ; n as usize ] ;
109
190
for i in 0 ..a. len ( ) {
@@ -145,7 +226,14 @@ pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, omega: i64) -> Vec<i6
145
226
c
146
227
}
147
228
148
- /// Compute the prime factorization of `n` (with multiplicities).
229
+ /// Compute the prime factorization of `n` (with multiplicities)
230
+ /// Uses reikna::quick_factorize internally
231
+ /// # Arguments
232
+ ///
233
+ /// * `n` - Number to factorize.
234
+ ///
235
+ /// # Returns
236
+ /// A HashMap with the prime factors of `n` as keys and their multiplicities as values.
149
237
fn factorize ( n : i64 ) -> HashMap < i64 , u32 > {
150
238
let mut factors = HashMap :: new ( ) ;
151
239
for factor in quick_factorize ( n as u64 ) {
@@ -155,6 +243,23 @@ fn factorize(n: i64) -> HashMap<i64, u32> {
155
243
}
156
244
157
245
/// Fast computation of a primitive root mod p^e
246
+ /// Computes a primitive root mod p and lifts it to p^e by adding successive powers of p
247
+ /// # Arguments
248
+ ///
249
+ /// * `p` - Prime modulus.
250
+ /// * `e` - Exponent.
251
+ ///
252
+ /// # Returns
253
+ /// A primitive root modulo `p^e`.
254
+ ///
255
+ /// # Examples
256
+ ///
257
+ /// ```
258
+ /// // For p = 17 and e = 2, we compute a primitive root modulo 289.
259
+ /// let p = 17;
260
+ /// let e = 2;
261
+ /// let g = ntt::primitive_root(p, e);
262
+ /// assert_eq!(ntt::mod_exp(g, p*(p-1), p*p), 1);
158
263
pub fn primitive_root ( p : i64 , e : u32 ) -> i64 {
159
264
let g = primitive_root_mod_p ( p) ;
160
265
let mut g_lifted = g; // Lift it to p^e
@@ -167,6 +272,12 @@ pub fn primitive_root(p: i64, e: u32) -> i64 {
167
272
}
168
273
169
274
/// Finds a primitive root modulo a prime p
275
+ /// # Arguments
276
+ ///
277
+ /// * `p` - Prime modulus.
278
+ ///
279
+ /// # Returns
280
+ /// A primitive root modulo `p`.
170
281
fn primitive_root_mod_p ( p : i64 ) -> i64 {
171
282
let phi = p - 1 ;
172
283
let factors = factorize ( phi) ; // Reusing factorize to get both prime factors and multiplicities
@@ -179,7 +290,16 @@ fn primitive_root_mod_p(p: i64) -> i64 {
179
290
0 // Should never happen
180
291
}
181
292
182
- // the Chinese remainder theorem for two moduli
293
+ /// the Chinese remainder theorem for two moduli
294
+ /// # Arguments
295
+ ///
296
+ /// * `a1` - First residue.
297
+ /// * `n1` - First modulus.
298
+ /// * `a2` - Second residue.
299
+ /// * `n2` - Second modulus.
300
+ ///
301
+ /// # Returns
302
+ /// The solution to the system of congruences x = a1 (mod n1) and x = a2 (mod n2).
183
303
pub fn crt ( a1 : i64 , n1 : i64 , a2 : i64 , n2 : i64 ) -> i64 {
184
304
let n = n1 * n2;
185
305
let m1 = mod_inv ( n1, n2) ; // Inverse of n1 mod n2
@@ -188,10 +308,17 @@ pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 {
188
308
if x < 0 { x + n } else { x }
189
309
}
190
310
191
- // computes an n^th root of unity modulo a composite modulus
192
- // note we require that an n^th root of unity exists for each multiplicative group modulo p^e
193
- // use the CRT isomorphism to pull back each n^th root of unity to the composite modulus
194
- // for the NTT, we require than a 2n^th root of unity exists
311
+ /// computes an n^th root of unity modulo a composite modulus
312
+ /// note we require that an n^th root of unity exists for each multiplicative group modulo p^e
313
+ /// use the CRT isomorphism to pull back the list of n^th roots of unity to the composite modulus
314
+ /// for the NTT, we require than a 2n^th root of unity exists
315
+ /// # Arguments
316
+ ///
317
+ /// * `modulus` - Modulus. n must divide each prime power factor.
318
+ /// * `n` - Order of the root of unity.
319
+ ///
320
+ /// # Returns
321
+ /// The n-th root of unity modulo `modulus`.
195
322
pub fn root_of_unity ( modulus : i64 , n : i64 ) -> i64 {
196
323
let factors = factorize ( modulus) ;
197
324
let mut result = 1 ;
@@ -202,10 +329,17 @@ pub fn root_of_unity(modulus: i64, n: i64) -> i64 {
202
329
result
203
330
}
204
331
205
- //ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n
332
+ /// ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n
333
+ /// # Arguments
334
+ ///
335
+ /// * `omega` - n-th root of unity.
336
+ /// * `n` - Order of the root of unity.
337
+ /// * `modulus` - Modulus.
338
+ ///
339
+ /// # Returns
340
+ /// True if the root of unity satisfies the condition.
206
341
pub fn verify_root_of_unity ( omega : i64 , n : i64 , modulus : i64 ) -> bool {
207
342
assert ! ( mod_exp( omega, n, modulus as i64 ) == 1 , "omega is not an n-th root of unity" ) ;
208
343
assert ! ( mod_exp( omega, n/2 , modulus as i64 ) == modulus-1 , "omgea^(n/2) != -1 (mod modulus)" ) ;
209
344
true
210
- }
211
-
345
+ }
0 commit comments