feat(gpu): optimize BLS12-446 field arithmetic for MSM performance #3448
bbarbakadze wants to merge 1 commit into main from
Conversation
This PR should change this line to 32 by default.
@bbarbakadze Something is wrong with benchmarks: https://github.com/zama-ai/tfhe-rs/actions/runs/24038290191
8cfdd1b to a5fd85a (force-push)
pdroalves left a comment:
I added a batch of high-level comments. Let me know when you are done with this PR so I can do a careful review line by line.
// largest field alignment (4 bytes in 32-bit limb mode, 8 bytes in 64-bit).
// Forcing alignas(8) ensures sizeof(G1Affine)==120 in both modes, matching
// the Rust FFI bindings which are always generated from the 64-bit layout.
struct alignas(8) G1Affine {
Can you replace this magic number by a function based on LIMB_BITS_CONFIG?
So this is actually not dependent on LIMB_BITS_CONFIG; it depends on the layout Rust is using. Once 64-bit limbs are used, we need the same alignment. Still, the magic number is now replaced with sizeof.
// across all 14 limbs.
// Operand map: %0..%13 = c[0..13], %14 = carry_out,
// %15..%28 = a[0..13], %29..%42 = b[0..13].
uint32_t carry_out;
@guillermo-oyarzun do you want to double check this PTX? It seems ok to me.
Yup, the PTX looks good!
// Operand map: %0..%13 = c[0..13], %14 = borrow_out,
// %15..%28 = a[0..13], %29..%42 = b[0..13].
uint32_t borrow_out;
asm("sub.cc.u32 %0, %15, %29;\n\t" // c[0] = a[0] - b[0], set BF
Same here, it looks good too!
#endif // LIMB_BITS_CONFIG == 64
#endif // __CUDA_ARCH__

// 32-bit dual MAD-chain Montgomery multiplication (device path)
Do you have a reference for this MAD-chain multiplication? If so, a link as comment would help.
fp_qad_row_32(&wtemp[2 * i], &wide[2 * i + 2], &a32[i + 1], a32[i], n - i);
}

asm("mul.lo.u32 %0, %2, %3; mul.hi.u32 %1, %2, %3;"
I don't like PTX in the middle of a function like this one. Maybe you could move it into a macro and add comments explaining what it does.
p4 = DEVICE_MODULUS.limb[4], p5 = DEVICE_MODULUS.limb[5],
p6 = DEVICE_MODULUS.limb[6];
uint64_t r0, r1, r2, r3, r4, r5, r6, mask64;
asm("sub.cc.u64 %0, %8, %15;\n\t"
This diff is full of PTX. We need to read it all carefully and, if possible, move it out of the function bodies.
a5fd85a to 38e4101 (force-push)
#if defined(__CUDA_ARCH__) && LIMB_BITS_CONFIG == 64
// Device path: fully unrolled PTX with hardware carry flags
fp_mont_mul_cios_ptx(c, a, b);
#ifdef __CUDA_ARCH__
I understand that we now have two versions, for 32- and 64-bit limbs. Can we add a panic in the right place in case someone attempts to use it with 128-bit?
@guillermo-oyarzun you mean if someone tries to set a value other than 32 or 64 for LIMB_BITS_CONFIG?
In that case, maybe use an enum with two values, 32BIT and 64BIT?
Yup, an enum should work. Just trying to be extra safe, because the code shouldn't work with 128-bit limbs, right? We would need to emulate them somehow.
For now limbs can only be 32 or 64. I will rewrite it with an enum; that should be better than a panic.
By the way, there is already a protection implemented for this inside fp.h, line 55:
static_assert(LIMB_BITS == 32 || LIMB_BITS == 64, "LIMB_BITS_CONFIG must be 32 or 64");
So I guess it is fine to leave it as it is.
- Replace 64-bit CIOS Montgomery multiplication with 32-bit MAD chains
  (mad.lo.cc/madc.hi.cc), exploiting the native 2x throughput of 32-bit ops
  on NVIDIA GPUs via even/odd accumulator separation
- Add fp_mont_sqr using a triangular MAD chain (upper triangle computed
  once and doubled, diagonal added separately), saving ~40% of the
  multiplications versus treating squaring as a general multiplication
- Add fp_add_lazy/fp_sub_lazy (and Fp2 variants): skip the final
  conditional subtraction when the result feeds fp_mont_mul, which
  accepts inputs in [0, 2p). Wired into fp2_mont_mul, fp2_mont_square,
  and G1/G2 projective_point_double
- Replace all fp_mont_mul(c, a, a) squaring patterns with fp_mont_sqr
  across curve.cu and fp2.cu (point addition, doubling, inversion)
38e4101 to e716051 (force-push)
PR content/description
Optimize BLS12-446 field arithmetic for MSM performance
- Replace 64-bit CIOS Montgomery multiplication with 32-bit MAD chains
  (mad.lo.cc/madc.hi.cc), exploiting the native 2x throughput of 32-bit ops
  on NVIDIA GPUs via even/odd accumulator separation
- Add fp_mont_sqr using a triangular MAD chain (upper triangle computed
  once and doubled, diagonal added separately), saving ~40% of the
  multiplications versus treating squaring as a general multiplication
- Add fp_add_lazy/fp_sub_lazy (and Fp2 variants): skip the final
  conditional subtraction when the result feeds fp_mont_mul, which
  accepts inputs in [0, 2p). Wired into fp2_mont_mul, fp2_mont_square,
  and G1/G2 projective_point_double
- Replace all fp_mont_mul(c, a, a) squaring patterns with fp_mont_sqr
  across curve.cu and fp2.cu (point addition, doubling, inversion)
Check-list: