 
 #include "field_5x52_int128_impl.h"
 
+#ifdef __AVX__
+# include <immintrin.h>
+#endif
+
 #ifdef VERIFY
 static void secp256k1_fe_impl_verify(const secp256k1_fe *a) {
     const uint64_t *d = a->n;
@@ -37,10 +41,15 @@ static void secp256k1_fe_impl_get_bounds(secp256k1_fe *r, int m) {
     const uint64_t bound1 = 0xFFFFFFFFFFFFFULL * two_m;
     const uint64_t bound2 = 0x0FFFFFFFFFFFFULL * two_m;
 
+#ifdef __AVX__
+    __m256i vec = _mm256_set1_epi64x(bound1);
+    _mm256_storeu_si256((__m256i *)r->n, vec);
+#else
     r->n[0] = bound1;
     r->n[1] = bound1;
     r->n[2] = bound1;
     r->n[3] = bound1;
+#endif
     r->n[4] = bound2;
 }
 
@@ -239,6 +248,8 @@ static void secp256k1_fe_impl_set_b32_mod(secp256k1_fe *r, const unsigned char *
     limbs[3] = BYTESWAP_64(limbs[3]);
 #endif
 
+    /* TODO: parallelize avx2 */
+
     r->n[0] = (limbs[3] & 0xFFFFFFFFFFFFFULL);
     r->n[1] = (limbs[3] >> 52) | ((limbs[2] & 0xFFFFFFFFFFULL) << 12);
     r->n[2] = (limbs[2] >> 40) | ((limbs[1] & 0xFFFFFFFULL) << 24);
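One way the TODO above could be acted on, sketched here only (not part of this patch): the four low limbs can be assembled with the same variable-shift pattern this patch uses in secp256k1_fe_impl_to_storage, dropping into the body of secp256k1_fe_impl_set_b32_mod and reusing its limbs[] array. The locals hi, lo, shr, shl and mask52 are illustrative names; a left-shift count of 64 in lane 0 makes _mm256_sllv_epi64 contribute zero there, and a single 52-bit mask replaces the per-limb masks of the scalar code. The top limb r->n[4] would stay on the existing scalar path.

#if defined(__AVX__) && defined(__AVX2__)
    /* Sketch only: r->n[i] = (limbs[3-i] >> shr[i]) | (limbs[2-i] << shl[i]), masked to 52 bits. */
    __m256i hi = _mm256_setr_epi64x(limbs[3], limbs[3], limbs[2], limbs[1]);
    __m256i lo = _mm256_setr_epi64x(limbs[3], limbs[2], limbs[1], limbs[0]); /* lane 0 is zeroed by the shift of 64 */
    const __m256i shr = _mm256_setr_epi64x(0, 52, 40, 28);
    const __m256i shl = _mm256_setr_epi64x(64, 12, 24, 36);
    const __m256i mask52 = _mm256_set1_epi64x(0xFFFFFFFFFFFFFULL);
    __m256i out = _mm256_or_si256(_mm256_srlv_epi64(hi, shr), _mm256_sllv_epi64(lo, shl));
    _mm256_storeu_si256((__m256i *)r->n, _mm256_and_si256(out, mask52));
#endif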
@@ -291,6 +302,10 @@ static void secp256k1_fe_impl_get_b32(unsigned char *r, const secp256k1_fe *a) {
 }
 
 SECP256K1_INLINE static void secp256k1_fe_impl_negate_unchecked(secp256k1_fe *r, const secp256k1_fe *a, int m) {
+#if defined(__AVX__) && defined(__AVX2__)
+    /* Load early to hide the load latency. */
+    __m256i vec_a = _mm256_loadu_si256((const __m256i *)a->n);
+#endif
     const uint32_t two_m1 = 2 * (m + 1);
     const uint64_t bound1 = 0xFFFFEFFFFFC2FULL * two_m1;
     const uint64_t bound2 = 0xFFFFFFFFFFFFFULL * two_m1;
@@ -303,10 +318,18 @@ SECP256K1_INLINE static void secp256k1_fe_impl_negate_unchecked(secp256k1_fe *r,
 
     /* Due to the properties above, the left hand in the subtractions below is never less than
      * the right hand. */
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i vec_bounds = _mm256_setr_epi64x(bound1, bound2, bound2, bound2);
+        __m256i out = _mm256_sub_epi64(vec_bounds, vec_a);
+        _mm256_storeu_si256((__m256i *)r->n, out);
+    }
+#else
     r->n[0] = bound1 - a->n[0];
     r->n[1] = bound2 - a->n[1];
     r->n[2] = bound2 - a->n[2];
     r->n[3] = bound2 - a->n[3];
+#endif
     r->n[4] = bound3 - a->n[4];
 }
 
@@ -339,15 +362,32 @@ SECP256K1_INLINE static void secp256k1_fe_impl_sqr(secp256k1_fe *r, const secp25
 }
 
 SECP256K1_INLINE static void secp256k1_fe_impl_cmov(secp256k1_fe *r, const secp256k1_fe *a, int flag) {
+#if defined(__AVX__) && defined(__AVX2__)
+    /* Load early to hide the load latency. */
+    __m256i vec_r = _mm256_loadu_si256((const __m256i *)r->n);
+    __m256i vec_a = _mm256_loadu_si256((const __m256i *)a->n);
+#endif
+
     uint64_t mask0, mask1;
     volatile int vflag = flag;
     SECP256K1_CHECKMEM_CHECK_VERIFY(r->n, sizeof(r->n));
     mask0 = vflag + ~((uint64_t)0);
     mask1 = ~mask0;
+
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i vec_mask0 = _mm256_set1_epi64x(mask0);
+        __m256i vec_mask1 = _mm256_set1_epi64x(mask1);
+        vec_r = _mm256_and_si256(vec_r, vec_mask0);
+        vec_a = _mm256_and_si256(vec_a, vec_mask1);
+        _mm256_storeu_si256((__m256i *)r->n, _mm256_or_si256(vec_r, vec_a));
+    }
+#else
     r->n[0] = (r->n[0] & mask0) | (a->n[0] & mask1);
     r->n[1] = (r->n[1] & mask0) | (a->n[1] & mask1);
     r->n[2] = (r->n[2] & mask0) | (a->n[2] & mask1);
     r->n[3] = (r->n[3] & mask0) | (a->n[3] & mask1);
+#endif
     r->n[4] = (r->n[4] & mask0) | (a->n[4] & mask1);
 }
 
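A possible alternative for the vector select in secp256k1_fe_impl_cmov, sketched here for comparison and not part of the patch, reusing the patch's vec_r, vec_a and mask1: since every byte of the broadcast mask is either 0x00 or 0xFF, the two ANDs and the OR can be collapsed into one byte-wise blend. The form used in the patch has the advantage of mirroring the scalar mask arithmetic exactly; both variants are branch-free.

#if defined(__AVX__) && defined(__AVX2__)
    /* Sketch only: blendv picks vec_a wherever the mask byte's top bit is set,
     * i.e. vec_a when flag != 0 and vec_r otherwise. */
    __m256i vec_mask1 = _mm256_set1_epi64x(mask1);
    _mm256_storeu_si256((__m256i *)r->n, _mm256_blendv_epi8(vec_r, vec_a, vec_mask1));
#endif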
@@ -418,19 +458,42 @@ static SECP256K1_INLINE void secp256k1_fe_storage_cmov(secp256k1_fe_storage *r,
 }
 
 static void secp256k1_fe_impl_to_storage(secp256k1_fe_storage *r, const secp256k1_fe *a) {
+#if defined(__AVX__) && defined(__AVX2__)
+    __m256i limbs_0123 = _mm256_loadu_si256((const __m256i *)a->n);
+    __m256i limbs_1234 = _mm256_loadu_si256((const __m256i *)(a->n + 1));
+    const __m256i shift_lhs = _mm256_setr_epi64x(0, 12, 24, 36);  /* TODO: precompute */
+    const __m256i shift_rhs = _mm256_setr_epi64x(52, 40, 28, 16); /* TODO: precompute */
+    __m256i rhs = _mm256_sllv_epi64(limbs_1234, shift_rhs);
+    __m256i lhs = _mm256_srlv_epi64(limbs_0123, shift_lhs);
+    _mm256_storeu_si256((__m256i *)r->n, _mm256_or_si256(lhs, rhs));
+#else
     r->n[0] = a->n[0] | a->n[1] << 52;
     r->n[1] = a->n[1] >> 12 | a->n[2] << 40;
     r->n[2] = a->n[2] >> 24 | a->n[3] << 28;
     r->n[3] = a->n[3] >> 36 | a->n[4] << 16;
+#endif
 }
 
 static SECP256K1_INLINE void secp256k1_fe_impl_from_storage(secp256k1_fe *r, const secp256k1_fe_storage *a) {
     const uint64_t a0 = a->n[0], a1 = a->n[1], a2 = a->n[2], a3 = a->n[3];
 
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
+        __m256i limbs_0012 = _mm256_setr_epi64x(a0, a0, a1, a2);
+        const __m256i shift_lhs = _mm256_setr_epi64x(64, 52, 40, 28);  /* TODO: precompute */
+        const __m256i shift_rhs = _mm256_setr_epi64x(0, 12, 24, 36);   /* TODO: precompute */
+        const __m256i mask52 = _mm256_set1_epi64x(0xFFFFFFFFFFFFFULL); /* TODO: precompute */
+        __m256i rhs = _mm256_and_si256(_mm256_sllv_epi64(limbs_0123, shift_rhs), mask52);
+        __m256i lhs = _mm256_srlv_epi64(limbs_0012, shift_lhs);
+        _mm256_storeu_si256((__m256i *)r->n, _mm256_or_si256(lhs, rhs));
+    }
+#else
     r->n[0] = a0 & 0xFFFFFFFFFFFFFULL;
     r->n[1] = a0 >> 52 | ((a1 << 12) & 0xFFFFFFFFFFFFFULL);
     r->n[2] = a1 >> 40 | ((a2 << 24) & 0xFFFFFFFFFFFFFULL);
     r->n[3] = a2 >> 28 | ((a3 << 36) & 0xFFFFFFFFFFFFFULL);
+#endif
     r->n[4] = a3 >> 16;
 }
 
@@ -447,21 +510,49 @@ static void secp256k1_fe_from_signed62(secp256k1_fe *r, const secp256k1_modinv64
     VERIFY_CHECK(a3 >> 62 == 0);
     VERIFY_CHECK(a4 >> 8 == 0);
 
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
+        __m256i limbs_0012 = _mm256_setr_epi64x(a0, a0, a1, a2);
+        const __m256i shift_lhs = _mm256_setr_epi64x(64, 52, 42, 32); /* TODO: precompute */
+        const __m256i shift_rhs = _mm256_setr_epi64x(0, 10, 20, 30);  /* TODO: precompute */
+        const __m256i mask52 = _mm256_set1_epi64x(M52);               /* TODO: precompute */
+        __m256i rhs = _mm256_sllv_epi64(limbs_0123, shift_rhs);
+        __m256i lhs = _mm256_srlv_epi64(limbs_0012, shift_lhs);
+        __m256i out = _mm256_or_si256(lhs, rhs);
+        _mm256_storeu_si256((__m256i *)r->n, _mm256_and_si256(out, mask52));
+    }
+#else
     r->n[0] = a0 & M52;
     r->n[1] = (a0 >> 52 | a1 << 10) & M52;
     r->n[2] = (a1 >> 42 | a2 << 20) & M52;
     r->n[3] = (a2 >> 32 | a3 << 30) & M52;
+#endif
     r->n[4] = (a3 >> 22 | a4 << 40);
 }
 
 static void secp256k1_fe_to_signed62(secp256k1_modinv64_signed62 *r, const secp256k1_fe *a) {
     const uint64_t M62 = UINT64_MAX >> 2;
     const uint64_t a0 = a->n[0], a1 = a->n[1], a2 = a->n[2], a3 = a->n[3], a4 = a->n[4];
 
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
+        __m256i limbs_1234 = _mm256_setr_epi64x(a1, a2, a3, a4);
+        const __m256i shift_lhs = _mm256_setr_epi64x(0, 10, 20, 30);  /* TODO: precompute */
+        const __m256i shift_rhs = _mm256_setr_epi64x(52, 42, 32, 22); /* TODO: precompute */
+        const __m256i mask62 = _mm256_set1_epi64x(M62);               /* TODO: precompute */
+        __m256i lhs = _mm256_srlv_epi64(limbs_0123, shift_lhs);
+        __m256i rhs = _mm256_sllv_epi64(limbs_1234, shift_rhs);
+        __m256i out = _mm256_or_si256(lhs, rhs);
+        _mm256_storeu_si256((__m256i *)r->v, _mm256_and_si256(out, mask62));
+    }
+#else
     r->v[0] = (a0 | a1 << 52) & M62;
     r->v[1] = (a1 >> 10 | a2 << 42) & M62;
     r->v[2] = (a2 >> 20 | a3 << 32) & M62;
     r->v[3] = (a3 >> 30 | a4 << 22) & M62;
+#endif
     r->v[4] = a4 >> 40;
 }
 
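The repeated /* TODO: precompute */ notes all concern the same thing: the shift-count and mask vectors are compile-time constants. An optimizing compiler typically already materializes _mm256_setr_epi64x with constant arguments as a single load from the constant pool, so the gain from hoisting them is likely small; if it is still wanted, one option (a sketch with hypothetical names, not part of the patch) is to keep the counts in static tables and load them where needed, shown here for the to_storage pair:

#if defined(__AVX__) && defined(__AVX2__)
/* Hypothetical file-scope tables; _mm256_loadu_si256 imposes no alignment requirement. */
static const uint64_t fe_to_storage_shift_lhs[4] = { 0, 12, 24, 36 };
static const uint64_t fe_to_storage_shift_rhs[4] = { 52, 40, 28, 16 };

static SECP256K1_INLINE void secp256k1_fe_to_storage_shifts(__m256i *shift_lhs, __m256i *shift_rhs) {
    *shift_lhs = _mm256_loadu_si256((const __m256i *)fe_to_storage_shift_lhs);
    *shift_rhs = _mm256_loadu_si256((const __m256i *)fe_to_storage_shift_rhs);
}
#endif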