
feat(gpu): optimize BLS12-446 field arithmetic for MSM performance#3448

Open
bbarbakadze wants to merge 1 commit into main from bb/zk/32_bit_limbs

Conversation

@bbarbakadze
Contributor

PR content/description

Optimize BLS12-446 field arithmetic for MSM performance

  • Replace 64-bit CIOS Montgomery multiplication with 32-bit MAD chains
    (mad.lo.cc/madc.hi.cc), exploiting native 2x throughput of 32-bit ops
    on NVIDIA GPUs via even/odd accumulator separation

  • Add fp_mont_sqr using a triangular MAD chain (upper triangle computed
    once and doubled, diagonal added separately), saving ~40% of the
    multiplications versus treating squaring as a general multiplication

  • Add fp_add_lazy/fp_sub_lazy (and Fp2 variants): skip the final
    conditional subtraction when the result feeds fp_mont_mul, which
    accepts inputs in [0, 2p). Wired into fp2_mont_mul, fp2_mont_square,
    and G1/G2 projective_point_double

  • Replace all fp_mont_mul(c, a, a) squaring patterns with fp_mont_sqr
    across curve.cu and fp2.cu (point addition, doubling, inversion)

Check-list:

  • Tests for the changes have been added (for bug fixes / features)
  • Docs have been added / updated (for bug fixes / features)
  • Relevant issues are marked as resolved/closed, related issues are linked in the description
  • Check for breaking changes (including serialization changes) and add them to commit message following the conventional commit specification

@bbarbakadze bbarbakadze requested a review from a team as a code owner April 3, 2026 16:58
@cla-bot cla-bot bot added the cla-signed label Apr 3, 2026
@github-actions

github-actions bot commented Apr 3, 2026

⚠️ This PR contains unsigned commits. To get your PR merged, please sign those commits (git rebase --exec 'git commit -S --amend --no-edit -n' @{upstream}) and force push them to this branch (git push --force-with-lease).

If you're new to commit signing, there are different ways to set it up:

Sign commits with gpg

Follow the steps below to set up commit signing with gpg:

  1. Generate a GPG key
  2. Add the GPG key to your GitHub account
  3. Configure git to use your GPG key for commit signing
Sign commits with ssh-agent

Follow the steps below to set up commit signing with ssh-agent:

  1. Generate an SSH key and add it to ssh-agent
  2. Add the SSH key to your GitHub account
  3. Configure git to use your SSH key for commit signing
Sign commits with 1Password

You can also sign commits using 1Password, which lets you sign commits with biometrics without the signing key leaving the local 1Password process.

Learn how to use 1Password to sign your commits.

Contributor

This PR should change this line to 32 by default.

Contributor Author

Done

@pdroalves
Contributor

@bbarbakadze Something is wrong with benchmarks: https://github.com/zama-ai/tfhe-rs/actions/runs/24038290191

Contributor

@pdroalves pdroalves left a comment

I added a batch of high-level comments. Let me know when you are done with this PR so I can do a careful review line by line.

// largest field alignment (4 bytes in 32-bit limb mode, 8 bytes in 64-bit).
// Forcing alignas(8) ensures sizeof(G1Affine)==120 in both modes, matching
// the Rust FFI bindings which are always generated from the 64-bit layout.
struct alignas(8) G1Affine {
Contributor

Can you replace this magic number by a function based on LIMB_BITS_CONFIG?

Contributor Author

This is actually not dependent on LIMB_BITS_CONFIG; it depends on the layout Rust is using. The Rust FFI bindings are generated from the 64-bit layout, so we need the same alignment in both modes. Still, the magic number has been replaced with sizeof.
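To make the alignment discussion concrete, here is a hypothetical host-side mock. The limb count matches BLS12-446's 446-bit field (14x32-bit limbs, 56 bytes per element) and the sizes match the 120-byte figure in the quoted comment, but the real G1Affine may hold different members.

```cpp
#include <cstdint>

// Hypothetical mock of the layout question: a 14x32-bit limb field
// element (56 bytes) plus an infinity flag. The real G1Affine lives in
// the CUDA sources and may differ in members.
struct Fp32 { uint32_t limb[14]; };            // alignof == 4

struct G1AffineNatural {                       // natural 4-byte alignment
    Fp32 x, y;
    bool infinity;
};                                             // 113 bytes -> padded to 116

struct alignas(8) G1AffineForced {             // forced 8-byte alignment,
    Fp32 x, y;                                 // as in the quoted snippet
    bool infinity;
};                                             // 113 bytes -> padded to 120

static_assert(sizeof(G1AffineNatural) == 116, "natural 32-bit layout");
static_assert(sizeof(G1AffineForced) == 120, "matches the 64-bit FFI layout");
```

Without the forced alignment, the 32-bit build would shrink the struct and break any FFI caller compiled against the 64-bit layout.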

// across all 14 limbs.
// Operand map: %0..%13 = c[0..13], %14 = carry_out,
// %15..%28 = a[0..13], %29..%42 = b[0..13].
uint32_t carry_out;
Contributor

@guillermo-oyarzun do you want to double check this PTX? It seems ok to me.

Member

yup the ptx looks good!

// Operand map: %0..%13 = c[0..13], %14 = borrow_out,
// %15..%28 = a[0..13], %29..%42 = b[0..13].
uint32_t borrow_out;
asm("sub.cc.u32 %0, %15, %29;\n\t" // c[0] = a[0] - b[0], set BF
Contributor

@guillermo-oyarzun here too.

Member

same here, it looks good too!
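For readers unfamiliar with the PTX, this is a host-side C++ sketch of what the unrolled add.cc.u32/addc.cc.u32 and sub.cc.u32/subc.cc.u32 chains compute: a carry or borrow flag propagated limb to limb across all 14 limbs, returned as the %14 operand. Illustrative only, not the repo's code.

```cpp
#include <cstdint>

// Host-side sketch of the 14-limb carry/borrow chains from the PTX
// above. The flag lives in a 64-bit temporary instead of the hardware
// carry flag, but the arithmetic is the same.
constexpr int N = 14;

uint32_t add_n(uint32_t c[N], const uint32_t a[N], const uint32_t b[N]) {
    uint64_t carry = 0;
    for (int i = 0; i < N; ++i) {       // add.cc, then addc.cc chain
        uint64_t s = (uint64_t)a[i] + b[i] + carry;
        c[i] = (uint32_t)s;
        carry = s >> 32;                // hardware carry flag analogue
    }
    return (uint32_t)carry;             // carry_out (%14)
}

uint32_t sub_n(uint32_t c[N], const uint32_t a[N], const uint32_t b[N]) {
    uint64_t borrow = 0;
    for (int i = 0; i < N; ++i) {       // sub.cc, then subc.cc chain
        uint64_t d = (uint64_t)a[i] - b[i] - borrow;
        c[i] = (uint32_t)d;
        borrow = (d >> 32) & 1;         // hardware borrow flag analogue
    }
    return (uint32_t)borrow;            // borrow_out (%14)
}
```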

#endif // LIMB_BITS_CONFIG == 64
#endif // __CUDA_ARCH__

// 32-bit dual MAD-chain Montgomery multiplication (device path)
Contributor

Do you have a reference for this MAD-chain multiplication? If so, a link as comment would help.
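For context while a reference link is pending: a one-limb Montgomery REDC sketch in host C++ (toy modulus, hypothetical helper names). The 14-limb MAD chain in the diff computes the same a*b*R^-1 mod p, just with the partial products accumulated in separate even/odd registers via mad.lo.cc/madc.hi.cc.

```cpp
#include <cstdint>

// One-limb Montgomery multiplication sketch, R = 2^32. Toy modulus;
// not the BLS12-446 field and not the PR's implementation.
static const uint32_t TOY_P = 0xFFFFFFFBu;     // 2^32 - 5, prime

// n' = -p^-1 mod 2^32 via Newton iteration; each step doubles the
// number of correct low bits, so 5 steps cover 32 bits.
static uint32_t neg_inv32(uint32_t p) {
    uint32_t x = p;                            // correct mod 2^3 for odd p
    for (int i = 0; i < 5; ++i) x *= 2u - p * x;
    return 0u - x;
}

// On Montgomery-form inputs aR, bR this returns abR mod p; on raw
// inputs it returns a*b*R^-1 mod p.
static uint32_t mont_mul(uint32_t a, uint32_t b) {
    uint64_t t = (uint64_t)a * b;
    uint32_t m = (uint32_t)t * neg_inv32(TOY_P); // m = t * (-p^-1) mod 2^32
    uint64_t u = (t + (uint64_t)m * TOY_P) >> 32; // t + m*p divisible by 2^32
    return u >= TOY_P ? (uint32_t)(u - TOY_P) : (uint32_t)u;
}
```

Here R mod p = 5, so with a = 3 and b = 4 in Montgomery form (15 and 20), the product comes back as 12 in Montgomery form (60).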

fp_qad_row_32(&wtemp[2 * i], &wide[2 * i + 2], &a32[i + 1], a32[i], n - i);
}

asm("mul.lo.u32 %0, %2, %3; mul.hi.u32 %1, %2, %3;"
Contributor

I don't like PTX in the middle of a function like this one. Maybe you could move it to a macro and add comments explaining what it is.

Contributor Author

will do it
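One possible shape for that macro, sketched with a hypothetical name: the device branch wraps the quoted PTX pair, while the host branch (used here for testing) computes the same two halves from a 64-bit product.

```cpp
#include <cstdint>

// Hypothetical macro illustrating the suggestion: hide the 32x32->64
// split (mul.lo.u32 / mul.hi.u32 in PTX) behind one named operation so
// raw asm no longer sits in the middle of a function body.
#if defined(__CUDA_ARCH__)
#define MUL_LO_HI_U32(lo, hi, a, b)                                   \
  asm("mul.lo.u32 %0, %2, %3; mul.hi.u32 %1, %2, %3;"                 \
      : "=r"(lo), "=r"(hi)                                            \
      : "r"(a), "r"(b))
#else
#define MUL_LO_HI_U32(lo, hi, a, b)                                   \
  do {                                                                \
    uint64_t _p = (uint64_t)(a) * (uint64_t)(b);                      \
    (lo) = (uint32_t)_p;                                              \
    (hi) = (uint32_t)(_p >> 32);                                      \
  } while (0)
#endif
```

A side benefit of the host branch is that the surrounding function becomes unit-testable off-device.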

p4 = DEVICE_MODULUS.limb[4], p5 = DEVICE_MODULUS.limb[5],
p6 = DEVICE_MODULUS.limb[6];
uint64_t r0, r1, r2, r3, r4, r5, r6, mask64;
asm("sub.cc.u64 %0, %8, %15;\n\t"
Contributor

This diff is full of PTX. We need to read it carefully and, if possible, move it out of the function bodies.
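For orientation, the fragment above follows the standard branchless conditional-subtraction pattern: subtract p with a borrow chain, then add p back masked by the borrow. A host-side C++ sketch (7x64-bit limbs as in the quoted code, toy values, not the real modulus):

```cpp
#include <cstdint>

// Host sketch of the branchless final reduction the sub.cc.u64 chain
// implements: r = a - p; if that borrowed, add p back via an all-ones
// mask. Computes a >= p ? a - p : a over 7 little-endian limbs.
constexpr int NL = 7;

void cond_sub(uint64_t r[NL], const uint64_t a[NL], const uint64_t p[NL]) {
    uint64_t borrow = 0;
    for (int i = 0; i < NL; ++i) {             // sub.cc / subc.cc chain
        uint64_t ai = a[i], pi = p[i];
        r[i] = ai - pi - borrow;
        borrow = (ai < pi) || (ai == pi && borrow);
    }
    uint64_t mask = 0 - borrow;                // all-ones iff a < p
    uint64_t carry = 0;
    for (int i = 0; i < NL; ++i) {             // add p back, masked
        __uint128_t s = (__uint128_t)r[i] + (p[i] & mask) + carry;
        r[i] = (uint64_t)s;
        carry = (uint64_t)(s >> 64);
    }
}
```

The mask trick avoids a divergent branch, which matters on GPU warps.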

#if defined(__CUDA_ARCH__) && LIMB_BITS_CONFIG == 64
// Device path: fully unrolled PTX with hardware carry flags
fp_mont_mul_cios_ptx(c, a, b);
#ifdef __CUDA_ARCH__
Member

I understand that we now have 2 versions, for 32- and 64-bit limbs. Can we add a panic in the right place in case someone attempts to use it with 128-bit?

Contributor Author

@guillermo-oyarzun you mean if someone tries to set LIMB_BITS_CONFIG to a value other than 32 or 64?

Contributor Author

In that case, maybe use an enum with two values, 32BIT and 64BIT?

Member

Yup, an enum should work. I'm just trying to be extra safe, because the code shouldn't work with 128-bit, right? We would need to emulate them somehow.

Contributor Author

For now limbs can only be 32 or 64. I will rewrite it with an enum; that should be better than a panic.

Contributor Author

Btw, there is already protection implemented for this inside fp.h at line 55:

static_assert(LIMB_BITS == 32 || LIMB_BITS == 64, "LIMB_BITS_CONFIG must be 32 or 64");

So I guess it is fine to leave it as it is.
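For completeness, one hypothetical way the enum idea could complement that static_assert is a traits template keyed on the enum, which makes unsupported widths unrepresentable at the type level (sketch only, hypothetical names, not the repo's code):

```cpp
#include <cstdint>

// Hypothetical enum-based variant of the LIMB_BITS_CONFIG guard
// discussed above: invalid widths fail to compile because the traits
// template has no general definition, only 32- and 64-bit versions.
enum class LimbWidth : uint32_t { Bits32 = 32, Bits64 = 64 };

template <LimbWidth W> struct LimbTraits;      // intentionally undefined
template <> struct LimbTraits<LimbWidth::Bits32> { using limb_t = uint32_t; };
template <> struct LimbTraits<LimbWidth::Bits64> { using limb_t = uint64_t; };

static_assert(sizeof(LimbTraits<LimbWidth::Bits32>::limb_t) == 4, "");
static_assert(sizeof(LimbTraits<LimbWidth::Bits64>::limb_t) == 8, "");
```

Any `LimbTraits<W>` use with another width refers to an incomplete type, so the error fires at the point of use rather than at a single assert.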
