@@ -84,8 +84,8 @@ impl Poly2 {
8484}
8585
8686/// Raises `x` to the power `n` using binary exponentiation,
87- /// with (1 to 2)*lg(n) scalar multiplications.
88- /// TODO: a consttime version of this would be awfully similar to a Montgomery ladder.
87+ /// with ` (1 to 2)*lg(n)` scalar multiplications.
88+ /// TODO: a consttime version of this would be similar to a Montgomery ladder.
8989pub fn scalar_exp_vartime ( x : & Scalar , mut n : u64 ) -> Scalar {
9090 let mut result = Scalar :: one ( ) ;
9191 let mut aux = * x; // x, x^2, x^4, x^8, ...
@@ -95,38 +95,85 @@ pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar {
9595 result = result * aux;
9696 }
9797 n = n >> 1 ;
98- aux = aux * aux; // FIXME: one unnecessary mult at the last step here!
98+ if n > 0 {
99+ aux = aux * aux;
100+ }
99101 }
100102 result
101103}
102104
103- /// Takes the sum of all the powers of `x`, up to `n`
104- /// If `n` is a power of 2, it uses the efficient algorithm with `2*lg n` multiplcations and additions.
105- /// If `n` is not a power of 2, it uses the slow algorithm with `n` multiplications and additions.
106- /// In the Bulletproofs case, all calls to `sum_of_powers` should have `n` as a power of 2.
107- pub fn sum_of_powers ( x : & Scalar , n : usize ) -> Scalar {
108- if !n. is_power_of_two ( ) {
109- return sum_of_powers_slow ( x, n) ;
110- }
111- if n == 0 || n == 1 {
112- return Scalar :: from_u64 ( n as u64 ) ;
113- }
114- let mut m = n;
115- let mut result = Scalar :: one ( ) + x;
116- let mut factor = * x;
117- while m > 2 {
118- factor = factor * factor;
119- result = result + factor * result;
120- m = m / 2 ;
105+ /// Computes the sum of all the powers of \\(x\\) \\(S(n) = (x^0 + \dots + x^{n-1})\\)
106+ /// using \\(O(\lg n)\\) multiplications and additions. Length \\(n\\) is not considered secret
107+ /// and algorithm is fastest when \\(n\\) is the power of two.
108+ ///
109+ /// ### Algorithm overview
110+ ///
111+ /// First, let \\(n\\) be a power of two.
112+ /// Then, we can divide the polynomial in two halves like so:
113+ /// \\[
114+ /// \begin{aligned}
115+ /// S(n) &= (1+\dots+x^{n-1}) \\\\
116+ /// &= (1+\dots+x^{n/2-1}) + x^{n/2} (1+\dots+x^{n/2-1}) \\\\
117+ /// &= s + x^{n/2} s.
118+ /// \end{aligned}
119+ /// \\]
120+ /// We can divide each \\(s\\) in half until we arrive to a degree-1 polynomial \\((1+x\cdot 1)\\).
121+ /// Recursively, the total sum can be defined as:
122+ /// \\[
123+ /// \begin{aligned}
124+ /// S(0) &= 0 \\\\
125+ /// S(n) &= s_{\lg n} \\\\
126+ /// s_0 &= 1 \\\\
127+ /// s_i &= s_{i-1} + x^{2^{i-1}} s_{i-1}
128+ /// \end{aligned}
129+ /// \\]
130+ /// This representation allows us to square \\(x\\) only \\(\lg n\\) times.
131+ ///
132+ /// Lets apply this to \\(n\\) which is not a power of two (\\(2^{k-1} < n < 2^k\\)) which can be represented in binary using
133+ /// bits \\(b_i\\) in \\(\\{0,1\\}\\):
134+ /// \\[
135+ /// n = b_0 2^0 + \dots + b_{k-1} 2^{k-1}
136+ /// \\]
137+ /// If we scan the bits of \\(n\\) from low to high (\\(i \in [0,k)\\)),
138+ /// we can conditionally (if \\(b_i = 1\\)) add to a resulting scalar
139+ /// an intermediate polynomial with \\(2^i\\) terms using the above algorithm,
140+ /// provided we offset the polynomial by \\(x^{n_i}\\), the next power of \\(x\\)
141+ /// for the existing sum, where \\(n_i = \sum_{j=0}^{i-1} b_j 2^j\\).
142+ ///
143+ /// The full algorithm becomes:
144+ /// \\[
145+ /// \begin{aligned}
146+ /// S(0) &= 0 \\\\
147+ /// S(1) &= 1 \\\\
148+ /// S(i) &= S(i-1) + x^{n_i} s_i b_i\\\\
149+ /// &= S(i-1) + x^{n_{i-1}} (x^{2^{i-1}})^{b_{i-1}} s_i b_i
150+ /// \end{aligned}
151+ /// \\]
152+ pub fn sum_of_powers ( x : & Scalar , mut n : usize ) -> Scalar {
153+ let mut result = Scalar :: zero ( ) ;
154+ let mut f = Scalar :: one ( ) ; // power of x to offset subsequent polynomials based on lower bits of n.
155+ let mut s = Scalar :: one ( ) ; // power-of-two polynomial: 1, 1+x, 1+x+x^2+x^3, ...
156+ let mut p = * x; // x, x^2, x^4, ..., x^{2^i}
157+ while n > 0 {
158+ // take a bit from n
159+ let bit = n & 1 ;
160+ n = n >> 1 ;
161+
162+ if bit == 1 {
163+ // bits of `n` are not secret, so it's okay to be vartime because of `n` value.
164+ result += f * s;
165+ if n > 0 { // avoid multiplication if no bits left
166+ f = f * p;
167+ }
168+ }
169+ if n > 0 { // avoid multiplication if no bits left
170+ s = s + p * s;
171+ p = p * p;
172+ }
121173 }
122174 result
123175}
124176
125- // takes the sum of all of the powers of x, up to n
126- fn sum_of_powers_slow ( x : & Scalar , n : usize ) -> Scalar {
127- exp_iter ( * x) . take ( n) . fold ( Scalar :: zero ( ) , |acc, x| acc + x)
128- }
129-
130177#[ cfg( test) ]
131178mod tests {
132179 use super :: * ;
@@ -185,9 +232,14 @@ mod tests {
185232 ) ;
186233 }
187234
235+ // takes the sum of all of the powers of x, up to n
236+ fn sum_of_powers_slow ( x : & Scalar , n : usize ) -> Scalar {
237+ exp_iter ( * x) . take ( n) . fold ( Scalar :: zero ( ) , |acc, x| acc + x)
238+ }
239+
188240 #[ test]
189- fn test_sum_of_powers ( ) {
190- let x = Scalar :: from_u64 ( 10 ) ;
241+ fn test_sum_of_powers_pow2 ( ) {
242+ let x = Scalar :: from_u64 ( 1337133713371337 ) ;
191243 assert_eq ! ( sum_of_powers_slow( & x, 0 ) , sum_of_powers( & x, 0 ) ) ;
192244 assert_eq ! ( sum_of_powers_slow( & x, 1 ) , sum_of_powers( & x, 1 ) ) ;
193245 assert_eq ! ( sum_of_powers_slow( & x, 2 ) , sum_of_powers( & x, 2 ) ) ;
@@ -199,14 +251,16 @@ mod tests {
199251 }
200252
201253 #[ test]
202- fn test_sum_of_powers_slow ( ) {
254+ fn test_sum_of_powers_non_pow2 ( ) {
203255 let x = Scalar :: from_u64 ( 10 ) ;
204- assert_eq ! ( sum_of_powers_slow( & x, 0 ) , Scalar :: zero( ) ) ;
205- assert_eq ! ( sum_of_powers_slow( & x, 1 ) , Scalar :: one( ) ) ;
206- assert_eq ! ( sum_of_powers_slow( & x, 2 ) , Scalar :: from_u64( 11 ) ) ;
207- assert_eq ! ( sum_of_powers_slow( & x, 3 ) , Scalar :: from_u64( 111 ) ) ;
208- assert_eq ! ( sum_of_powers_slow( & x, 4 ) , Scalar :: from_u64( 1111 ) ) ;
209- assert_eq ! ( sum_of_powers_slow( & x, 5 ) , Scalar :: from_u64( 11111 ) ) ;
210- assert_eq ! ( sum_of_powers_slow( & x, 6 ) , Scalar :: from_u64( 111111 ) ) ;
256+ assert_eq ! ( sum_of_powers( & x, 0 ) , Scalar :: zero( ) ) ;
257+ assert_eq ! ( sum_of_powers( & x, 1 ) , Scalar :: one( ) ) ;
258+ assert_eq ! ( sum_of_powers( & x, 2 ) , Scalar :: from_u64( 11 ) ) ;
259+ assert_eq ! ( sum_of_powers( & x, 3 ) , Scalar :: from_u64( 111 ) ) ;
260+ assert_eq ! ( sum_of_powers( & x, 4 ) , Scalar :: from_u64( 1111 ) ) ;
261+ assert_eq ! ( sum_of_powers( & x, 5 ) , Scalar :: from_u64( 11111 ) ) ;
262+ assert_eq ! ( sum_of_powers( & x, 6 ) , Scalar :: from_u64( 111111 ) ) ;
263+ assert_eq ! ( sum_of_powers( & x, 7 ) , Scalar :: from_u64( 1111111 ) ) ;
264+ assert_eq ! ( sum_of_powers( & x, 8 ) , Scalar :: from_u64( 11111111 ) ) ;
211265 }
212266}
0 commit comments