@@ -26,19 +26,24 @@ mod tests {
26
26
27
27
#[ test]
28
28
fn test_polymul_ntt_square_modulus ( ) {
29
- let moduli = [ 17 * 17 , 12289 * 12289 ] ; // Different moduli to test
30
- let n: usize = 8 ; // Length of the NTT (must be a power of 2)
31
-
32
- for & modulus in & moduli {
33
- let omega = omega ( modulus, n) ; // n-th root of unity
34
- let mut a = vec ! [ 1 , 2 , 3 , 4 ] ;
35
- let mut b = vec ! [ 5 , 6 , 7 , 8 ] ;
36
- a. resize ( n, 0 ) ;
37
- b. resize ( n, 0 ) ;
38
- let c_std = polymul ( & a, & b, n as i64 , modulus) ;
39
- let c_fast = polymul_ntt ( & a, & b, n, modulus, omega) ;
40
- assert_eq ! ( c_std, c_fast, "The results of polymul and polymul_ntt do not match" ) ;
29
+ let cases = [
30
+ ( 17 * 17 , 4 ) , // small square modulus
31
+ ( 12289 * 12289 , 512 ) // large square modulus
32
+ ] ;
33
+
34
+ for & ( modulus, n) in & cases {
35
+ let omega = omega ( modulus, 2 * n) ; // n-th root of unity
36
+ let mut a: Vec < i64 > = ( 0 ..n) . map ( |x| x as i64 ) . collect ( ) ;
37
+ let mut b: Vec < i64 > = ( 0 ..n) . map ( |x| x as i64 ) . collect ( ) ;
38
+ a. resize ( 2 * n, 0 ) ;
39
+ b. resize ( 2 * n, 0 ) ;
40
+
41
+ let c_std = polymul ( & a, & b, 2 * n as i64 , modulus) ;
42
+ let c_fast = polymul_ntt ( & a, & b, 2 * n, modulus, omega) ;
43
+
44
+ assert_eq ! ( c_std, c_fast, "The results of polymul and polymul_ntt do not match for modulus {} and n {}" , modulus, n) ;
41
45
}
46
+
42
47
}
43
48
44
49
#[ test]
0 commit comments