Skip to content

Commit a543ede

Browse files
committed
Add intel simd
1 parent ba13656 commit a543ede

File tree

4 files changed

+189
-13
lines changed

4 files changed

+189
-13
lines changed

src/field_5x52_impl.h

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414

1515
#include "field_5x52_int128_impl.h"
1616

17+
#ifdef X86
18+
# include <immintrin.h>
19+
#endif
20+
1721
#ifdef VERIFY
1822
static void secp256k1_fe_impl_verify(const secp256k1_fe *a) {
1923
const uint64_t *d = a->n;
@@ -37,10 +41,15 @@ static void secp256k1_fe_impl_get_bounds(secp256k1_fe *r, int m) {
3741
const uint64_t bound1 = 0xFFFFFFFFFFFFFULL * two_m;
3842
const uint64_t bound2 = 0x0FFFFFFFFFFFFULL * two_m;
3943

44+
#ifdef __AVX__
45+
__m256i vec = _mm256_set1_epi64x(bound1);
46+
_mm256_storeu_si256((__m256i *)r->n, vec);
47+
#else
4048
r->n[0] = bound1;
4149
r->n[1] = bound1;
4250
r->n[2] = bound1;
4351
r->n[3] = bound1;
52+
#endif
4453
r->n[4] = bound2;
4554
}
4655

@@ -239,6 +248,8 @@ static void secp256k1_fe_impl_set_b32_mod(secp256k1_fe *r, const unsigned char *
239248
limbs[3] = BYTESWAP_64(limbs[3]);
240249
#endif
241250

251+
/* TODO: parallelize avx2 */
252+
242253
r->n[0] = (limbs[3] & 0xFFFFFFFFFFFFFULL);
243254
r->n[1] = (limbs[3] >> 52) | ((limbs[2] & 0xFFFFFFFFFFULL) << 12);
244255
r->n[2] = (limbs[2] >> 40) | ((limbs[1] & 0xFFFFFFFULL) << 24);
@@ -291,6 +302,10 @@ static void secp256k1_fe_impl_get_b32(unsigned char *r, const secp256k1_fe *a) {
291302
}
292303

293304
SECP256K1_INLINE static void secp256k1_fe_impl_negate_unchecked(secp256k1_fe *r, const secp256k1_fe *a, int m) {
305+
#if defined(__AVX__) && defined(__AVX2__)
306+
/* load here to mitigate load latency */
307+
__m256i vec_a = _mm256_loadu_si256((__m256i *)a->n);
308+
#endif
294309
const uint32_t two_m1 = 2 * (m + 1);
295310
const uint64_t bound1 = 0xFFFFEFFFFFC2FULL * two_m1;
296311
const uint64_t bound2 = 0xFFFFFFFFFFFFFULL * two_m1;
@@ -303,10 +318,18 @@ SECP256K1_INLINE static void secp256k1_fe_impl_negate_unchecked(secp256k1_fe *r,
303318

304319
/* Due to the properties above, the left hand in the subtractions below is never less than
305320
* the right hand. */
321+
#if defined(__AVX__) && defined(__AVX2__)
322+
{
323+
__m256i vec_bounds = _mm256_setr_epi64x(bound1, bound2, bound2, bound2);
324+
__m256i out = _mm256_sub_epi64(vec_bounds, vec_a);
325+
_mm256_storeu_si256((__m256i *)r->n, out);
326+
}
327+
#else
306328
r->n[0] = bound1 - a->n[0];
307329
r->n[1] = bound2 - a->n[1];
308330
r->n[2] = bound2 - a->n[2];
309331
r->n[3] = bound2 - a->n[3];
332+
#endif
310333
r->n[4] = bound3 - a->n[4];
311334
}
312335

@@ -339,15 +362,32 @@ SECP256K1_INLINE static void secp256k1_fe_impl_sqr(secp256k1_fe *r, const secp25
339362
}
340363

341364
SECP256K1_INLINE static void secp256k1_fe_impl_cmov(secp256k1_fe *r, const secp256k1_fe *a, int flag) {
365+
#if defined(__AVX__) && defined(__AVX2__)
366+
/* load here to mitigate load latency */
367+
__m256i vec_r = _mm256_loadu_si256((__m256i *)(r->n));
368+
__m256i vec_a = _mm256_loadu_si256((__m256i *)(a->n));
369+
#endif
370+
342371
uint64_t mask0, mask1;
343372
volatile int vflag = flag;
344373
SECP256K1_CHECKMEM_CHECK_VERIFY(r->n, sizeof(r->n));
345374
mask0 = vflag + ~((uint64_t)0);
346375
mask1 = ~mask0;
376+
377+
#if defined(__AVX__) && defined(__AVX2__)
378+
{
379+
__m256i vec_mask0 = _mm256_set1_epi64x(mask0);
380+
__m256i vec_mask1 = _mm256_set1_epi64x(mask1);
381+
vec_r = _mm256_and_si256(vec_r, vec_mask0);
382+
vec_a = _mm256_and_si256(vec_a, vec_mask1);
383+
_mm256_storeu_si256((__m256i *)r->n, _mm256_or_si256(vec_r, vec_a));
384+
}
385+
#else
347386
r->n[0] = (r->n[0] & mask0) | (a->n[0] & mask1);
348387
r->n[1] = (r->n[1] & mask0) | (a->n[1] & mask1);
349388
r->n[2] = (r->n[2] & mask0) | (a->n[2] & mask1);
350389
r->n[3] = (r->n[3] & mask0) | (a->n[3] & mask1);
390+
#endif
351391
r->n[4] = (r->n[4] & mask0) | (a->n[4] & mask1);
352392
}
353393

@@ -418,19 +458,42 @@ static SECP256K1_INLINE void secp256k1_fe_storage_cmov(secp256k1_fe_storage *r,
418458
}
419459

420460
static void secp256k1_fe_impl_to_storage(secp256k1_fe_storage *r, const secp256k1_fe *a) {
461+
#if defined(__AVX__) && defined(__AVX2__)
462+
__m256i limbs_0123 = _mm256_loadu_si256((__m256i *)a->n);
463+
__m256i limbs_1234 = _mm256_loadu_si256((__m256i *)(a->n + 1));
464+
const __m256i shift_lhs = _mm256_setr_epi64x(0, 12, 24, 36); /* TODO: precompute */
465+
const __m256i shift_rhs = _mm256_setr_epi64x(52, 40, 28, 16); /* TODO: precompute */
466+
__m256i rhs = _mm256_sllv_epi64(limbs_1234, shift_rhs);
467+
__m256i lhs = _mm256_srlv_epi64(limbs_0123, shift_lhs);
468+
_mm256_storeu_si256((__m256i *)r->n, _mm256_or_si256(lhs, rhs));
469+
#else
421470
r->n[0] = a->n[0] | a->n[1] << 52;
422471
r->n[1] = a->n[1] >> 12 | a->n[2] << 40;
423472
r->n[2] = a->n[2] >> 24 | a->n[3] << 28;
424473
r->n[3] = a->n[3] >> 36 | a->n[4] << 16;
474+
#endif
425475
}
426476

427477
static SECP256K1_INLINE void secp256k1_fe_impl_from_storage(secp256k1_fe *r, const secp256k1_fe_storage *a) {
428478
const uint64_t a0 = a->n[0], a1 = a->n[1], a2 = a->n[2], a3 = a->n[3];
429479

480+
#if defined(__AVX__) && defined(__AVX2__)
481+
{
482+
__m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
483+
__m256i limbs_0012 = _mm256_setr_epi64x(a0, a0, a1, a2);
484+
const __m256i shift_lhs = _mm256_setr_epi64x(64, 52, 40, 28); /* TODO: precompute */
485+
const __m256i shift_rhs = _mm256_setr_epi64x(0, 12, 24, 36); /* TODO: precompute */
486+
const __m256i mask52 = _mm256_set1_epi64x(0xFFFFFFFFFFFFFULL); /* TODO: precompute */
487+
__m256i rhs = _mm256_and_si256(_mm256_sllv_epi64(limbs_0123, shift_rhs), mask52);
488+
__m256i lhs = _mm256_srlv_epi64(limbs_0012, shift_lhs);
489+
_mm256_storeu_si256((__m256i*)r->n, _mm256_or_si256(lhs, rhs));
490+
}
491+
#else
430492
r->n[0] = a0 & 0xFFFFFFFFFFFFFULL;
431493
r->n[1] = a0 >> 52 | ((a1 << 12) & 0xFFFFFFFFFFFFFULL);
432494
r->n[2] = a1 >> 40 | ((a2 << 24) & 0xFFFFFFFFFFFFFULL);
433495
r->n[3] = a2 >> 28 | ((a3 << 36) & 0xFFFFFFFFFFFFFULL);
496+
#endif
434497
r->n[4] = a3 >> 16;
435498
}
436499

@@ -447,21 +510,49 @@ static void secp256k1_fe_from_signed62(secp256k1_fe *r, const secp256k1_modinv64
447510
VERIFY_CHECK(a3 >> 62 == 0);
448511
VERIFY_CHECK(a4 >> 8 == 0);
449512

513+
#if defined(__AVX__) && defined(__AVX2__)
514+
{
515+
__m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
516+
__m256i limbs_0012 = _mm256_setr_epi64x(a0, a0, a1, a2);
517+
const __m256i shift_lhs = _mm256_setr_epi64x(64, 52, 42, 32); /*TODO: precompute */
518+
const __m256i shift_rhs = _mm256_setr_epi64x(0, 10, 20, 30); /*TODO: precompute */
519+
const __m256i mask52 = _mm256_set1_epi64x(M52); /*TODO: precompute */
520+
__m256i rhs = _mm256_sllv_epi64(limbs_0123, shift_rhs);
521+
__m256i lhs = _mm256_srlv_epi64(limbs_0012, shift_lhs);
522+
__m256i out = _mm256_or_si256(lhs, rhs);
523+
_mm256_storeu_si256((__m256i*)r->n, _mm256_and_si256(out, mask52));
524+
}
525+
#else
450526
r->n[0] = a0 & M52;
451527
r->n[1] = (a0 >> 52 | a1 << 10) & M52;
452528
r->n[2] = (a1 >> 42 | a2 << 20) & M52;
453529
r->n[3] = (a2 >> 32 | a3 << 30) & M52;
530+
#endif
454531
r->n[4] = (a3 >> 22 | a4 << 40);
455532
}
456533

457534
static void secp256k1_fe_to_signed62(secp256k1_modinv64_signed62 *r, const secp256k1_fe *a) {
458535
const uint64_t M62 = UINT64_MAX >> 2;
459536
const uint64_t a0 = a->n[0], a1 = a->n[1], a2 = a->n[2], a3 = a->n[3], a4 = a->n[4];
460537

538+
#if defined(__AVX__) && defined(__AVX2__)
539+
{
540+
__m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
541+
__m256i limbs_1234 = _mm256_setr_epi64x(a1, a2, a3, a4);
542+
const __m256i shift_lhs = _mm256_setr_epi64x(0, 10, 20, 30); /*TODO: precompute */
543+
const __m256i shift_rhs = _mm256_setr_epi64x(52, 42, 32, 22); /*TODO: precompute */
544+
const __m256i mask62 = _mm256_set1_epi64x(M62); /*TODO: precompute */
545+
__m256i lhs = _mm256_srlv_epi64(limbs_0123, shift_lhs);
546+
__m256i rhs = _mm256_sllv_epi64(limbs_1234, shift_rhs);
547+
__m256i out = _mm256_or_si256(lhs, rhs);
548+
_mm256_storeu_si256((__m256i *)r->v, _mm256_and_si256(out, mask62));
549+
}
550+
#else
461551
r->v[0] = (a0 | a1 << 52) & M62;
462552
r->v[1] = (a1 >> 10 | a2 << 42) & M62;
463553
r->v[2] = (a2 >> 20 | a3 << 32) & M62;
464554
r->v[3] = (a3 >> 30 | a4 << 22) & M62;
555+
#endif
465556
r->v[4] = a4 >> 40;
466557
}
467558

src/hash_impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
#include <stdint.h>
1515
#include <string.h>
1616

17+
#ifdef X86
18+
# include <immintrin.h>
19+
#endif
20+
1721
#define Ch(x,y,z) ((z) ^ ((x) & ((y) ^ (z))))
1822
#define Maj(x,y,z) (((x) & (y)) | ((z) & ((x) | (y))))
1923
#define Sigma0(x) (((x) >> 2 | (x) << 30) ^ ((x) >> 13 | (x) << 19) ^ ((x) >> 22 | (x) << 10))

src/scalar_4x64_impl.h

Lines changed: 89 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
#include "modinv64_impl.h"
1313
#include "util.h"
1414

15+
#ifdef X86
16+
# include <immintrin.h>
17+
#endif
18+
1519
/* Limbs of the secp256k1 order. */
1620
#define SECP256K1_N_0 ((uint64_t)0xBFD25E8CD0364141ULL)
1721
#define SECP256K1_N_1 ((uint64_t)0xBAAEDCE6AF48A03BULL)
@@ -143,10 +147,25 @@ static void secp256k1_scalar_cadd_bit(secp256k1_scalar *r, unsigned int bit, int
143147

144148
static void secp256k1_scalar_set_b32(secp256k1_scalar *r, const unsigned char *b32, int *overflow) {
145149
int over;
150+
151+
#if defined(__AVX__) && defined(__AVX2__)
152+
{
153+
__m256i input = _mm256_loadu_si256((const __m256i*)b32);
154+
input = _mm256_permute4x64_epi64(input, _MM_SHUFFLE(0,1,2,3));
155+
const __m256i bswap_mask = _mm256_setr_epi8( /* TODO: precompute */
156+
7,6,5,4,3,2,1,0,
157+
15,14,13,12,11,10,9,8,
158+
23,22,21,20,19,18,17,16,
159+
31,30,29,28,27,26,25,24);
160+
__m256i output = _mm256_shuffle_epi8(input, bswap_mask);
161+
_mm256_storeu_si256((__m256i*)r->d, output);
162+
}
163+
#else
146164
r->d[0] = secp256k1_read_be64(&b32[24]);
147165
r->d[1] = secp256k1_read_be64(&b32[16]);
148166
r->d[2] = secp256k1_read_be64(&b32[8]);
149167
r->d[3] = secp256k1_read_be64(&b32[0]);
168+
#endif
150169
over = secp256k1_scalar_reduce(r, secp256k1_scalar_check_overflow(r));
151170
if (overflow) {
152171
*overflow = over;
@@ -158,6 +177,8 @@ static void secp256k1_scalar_set_b32(secp256k1_scalar *r, const unsigned char *b
158177
static void secp256k1_scalar_get_b32(unsigned char *bin, const secp256k1_scalar* a) {
159178
SECP256K1_SCALAR_VERIFY(a);
160179

180+
/* TODO: parallelize */
181+
161182
secp256k1_write_be64(&bin[0], a->d[3]);
162183
secp256k1_write_be64(&bin[8], a->d[2]);
163184
secp256k1_write_be64(&bin[16], a->d[1]);
@@ -166,7 +187,6 @@ static void secp256k1_scalar_get_b32(unsigned char *bin, const secp256k1_scalar*
166187

167188
SECP256K1_INLINE static int secp256k1_scalar_is_zero(const secp256k1_scalar *a) {
168189
SECP256K1_SCALAR_VERIFY(a);
169-
170190
return (a->d[0] | a->d[1] | a->d[2] | a->d[3]) == 0;
171191
}
172192

@@ -882,8 +902,16 @@ static void secp256k1_scalar_split_128(secp256k1_scalar *r1, secp256k1_scalar *r
882902
SECP256K1_INLINE static int secp256k1_scalar_eq(const secp256k1_scalar *a, const secp256k1_scalar *b) {
883903
SECP256K1_SCALAR_VERIFY(a);
884904
SECP256K1_SCALAR_VERIFY(b);
885-
905+
#if defined(__AVX__) && defined(__AVX2__)
906+
{
907+
__m256i vec_a = _mm256_loadu_si256((__m256i *)a->d);
908+
__m256i vec_b = _mm256_loadu_si256((__m256i *)b->d);
909+
__m256i vec_xor = _mm256_xor_si256(vec_a, vec_b);
910+
return _mm256_testz_si256(vec_xor, vec_xor);
911+
}
912+
#else
886913
return ((a->d[0] ^ b->d[0]) | (a->d[1] ^ b->d[1]) | (a->d[2] ^ b->d[2]) | (a->d[3] ^ b->d[3])) == 0;
914+
#endif
887915
}
888916

889917
SECP256K1_INLINE static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r, const secp256k1_scalar *a, const secp256k1_scalar *b, unsigned int shift) {
@@ -899,6 +927,9 @@ SECP256K1_INLINE static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r,
899927
shiftlimbs = shift >> 6;
900928
shiftlow = shift & 0x3F;
901929
shifthigh = 64 - shiftlow;
930+
931+
/* TODO: parallelize */
932+
902933
r->d[0] = shift < 512 ? (l[0 + shiftlimbs] >> shiftlow | (shift < 448 && shiftlow ? (l[1 + shiftlimbs] << shifthigh) : 0)) : 0;
903934
r->d[1] = shift < 448 ? (l[1 + shiftlimbs] >> shiftlow | (shift < 384 && shiftlow ? (l[2 + shiftlimbs] << shifthigh) : 0)) : 0;
904935
r->d[2] = shift < 384 ? (l[2 + shiftlimbs] >> shiftlow | (shift < 320 && shiftlow ? (l[3 + shiftlimbs] << shifthigh) : 0)) : 0;
@@ -909,37 +940,68 @@ SECP256K1_INLINE static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r,
909940
}
910941

911942
static SECP256K1_INLINE void secp256k1_scalar_cmov(secp256k1_scalar *r, const secp256k1_scalar *a, int flag) {
943+
#if defined(__AVX__) && defined(__AVX2__)
944+
/* load here to mitigate load latency */
945+
__m256i vec_r = _mm256_loadu_si256((__m256i *)(r->d));
946+
__m256i vec_a = _mm256_loadu_si256((__m256i *)(a->d));
947+
#endif
948+
912949
uint64_t mask0, mask1;
913950
volatile int vflag = flag;
914951
SECP256K1_SCALAR_VERIFY(a);
915952
SECP256K1_CHECKMEM_CHECK_VERIFY(r->d, sizeof(r->d));
916953

917954
mask0 = vflag + ~((uint64_t)0);
918955
mask1 = ~mask0;
956+
957+
#if defined(__AVX__) && defined(__AVX2__)
958+
{
959+
__m256i vec_mask0 = _mm256_set1_epi64x(mask0);
960+
__m256i vec_mask1 = _mm256_set1_epi64x(mask1);
961+
vec_r = _mm256_and_si256(vec_r, vec_mask0);
962+
vec_a = _mm256_and_si256(vec_a, vec_mask1);
963+
_mm256_storeu_si256((__m256i *)(r->d), _mm256_or_si256(vec_r, vec_a));
964+
}
965+
#else
919966
r->d[0] = (r->d[0] & mask0) | (a->d[0] & mask1);
920967
r->d[1] = (r->d[1] & mask0) | (a->d[1] & mask1);
921968
r->d[2] = (r->d[2] & mask0) | (a->d[2] & mask1);
922969
r->d[3] = (r->d[3] & mask0) | (a->d[3] & mask1);
970+
#endif
923971

924972
SECP256K1_SCALAR_VERIFY(r);
925973
}
926974

927975
static void secp256k1_scalar_from_signed62(secp256k1_scalar *r, const secp256k1_modinv64_signed62 *a) {
928-
const uint64_t a0 = a->v[0], a1 = a->v[1], a2 = a->v[2], a3 = a->v[3], a4 = a->v[4];
929-
930976
/* The output from secp256k1_modinv64{_var} should be normalized to range [0,modulus), and
931977
* have limbs in [0,2^62). The modulus is < 2^256, so the top limb must be below 2^(256-62*4).
932978
*/
933-
VERIFY_CHECK(a0 >> 62 == 0);
934-
VERIFY_CHECK(a1 >> 62 == 0);
935-
VERIFY_CHECK(a2 >> 62 == 0);
936-
VERIFY_CHECK(a3 >> 62 == 0);
937-
VERIFY_CHECK(a4 >> 8 == 0);
979+
VERIFY_CHECK(a->v[0] >> 62 == 0);
980+
VERIFY_CHECK(a->v[1] >> 62 == 0);
981+
VERIFY_CHECK(a->v[2] >> 62 == 0);
982+
VERIFY_CHECK(a->v[3] >> 62 == 0);
983+
VERIFY_CHECK(a->v[4] >> 8 == 0);
984+
985+
#if defined(__AVX__) && defined(__AVX2__)
986+
{
987+
__m256i limbs_0123 = _mm256_loadu_si256((__m256i *)a->v);
988+
__m256i limbs_1234 = _mm256_loadu_si256((__m256i *)(a->v + 1));
989+
const __m256i shift_lhs = _mm256_setr_epi64x(0, 2, 4, 6); /* TODO: precompute */
990+
const __m256i shift_rhs = _mm256_setr_epi64x(62, 60, 58, 56); /* TODO: precompute */
991+
__m256i lhs = _mm256_srlv_epi64(limbs_0123, shift_lhs);
992+
__m256i rhs = _mm256_sllv_epi64(limbs_1234, shift_rhs);
993+
_mm256_storeu_si256((__m256i *)(r->d), _mm256_or_si256(lhs, rhs));
994+
}
995+
#else
996+
{
997+
const uint64_t a0 = a->v[0], a1 = a->v[1], a2 = a->v[2], a3 = a->v[3], a4 = a->v[4];
938998

939-
r->d[0] = a0 | a1 << 62;
940-
r->d[1] = a1 >> 2 | a2 << 60;
941-
r->d[2] = a2 >> 4 | a3 << 58;
942-
r->d[3] = a3 >> 6 | a4 << 56;
999+
r->d[0] = a0 | a1 << 62;
1000+
r->d[1] = a1 >> 2 | a2 << 60;
1001+
r->d[2] = a2 >> 4 | a3 << 58;
1002+
r->d[3] = a3 >> 6 | a4 << 56;
1003+
}
1004+
#endif
9431005

9441006
SECP256K1_SCALAR_VERIFY(r);
9451007
}
@@ -949,10 +1011,24 @@ static void secp256k1_scalar_to_signed62(secp256k1_modinv64_signed62 *r, const s
9491011
const uint64_t a0 = a->d[0], a1 = a->d[1], a2 = a->d[2], a3 = a->d[3];
9501012
SECP256K1_SCALAR_VERIFY(a);
9511013

1014+
#if defined(__AVX__) && defined(__AVX2__)
1015+
{
1016+
__m256i limbs_0012 = _mm256_setr_epi64x(a0, a0, a1, a2);
1017+
__m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
1018+
const __m256i shift_lhs = _mm256_setr_epi64x(0, 62, 60, 58); /*TODO: precompute */
1019+
const __m256i shift_rhs = _mm256_setr_epi64x(64, 2, 4, 6); /*TODO: precompute */
1020+
const __m256i mask62 = _mm256_set1_epi64x(M62); /*TODO: precompute */
1021+
__m256i lhs = _mm256_srlv_epi64(limbs_0012, shift_lhs);
1022+
__m256i rhs = _mm256_sllv_epi64(limbs_0123, shift_rhs);
1023+
__m256i out = _mm256_or_si256(lhs, rhs);
1024+
_mm256_storeu_si256((__m256i *)r->v, _mm256_and_si256(out, mask62));
1025+
}
1026+
#else
9521027
r->v[0] = a0 & M62;
9531028
r->v[1] = (a0 >> 62 | a1 << 2) & M62;
9541029
r->v[2] = (a1 >> 60 | a2 << 4) & M62;
9551030
r->v[3] = (a2 >> 58 | a3 << 6) & M62;
1031+
#endif
9561032
r->v[4] = a3 >> 56;
9571033
}
9581034

0 commit comments

Comments
 (0)