diff --git a/src/field/crypto_bigint_const_monty.rs b/src/field/crypto_bigint_const_monty.rs index 118f79b..e5002ee 100644 --- a/src/field/crypto_bigint_const_monty.rs +++ b/src/field/crypto_bigint_const_monty.rs @@ -9,7 +9,7 @@ use core::{ str::FromStr, }; use crypto_bigint::{ - Limb, + Limb, NonZeroUint, Uint as CBUint, modular::{ConstMontyForm, ConstMontyParams as Params, Retrieve}, subtle::{Choice, ConditionallySelectable, ConstantTimeEq}, }; @@ -505,25 +505,28 @@ impl, const LIMBS: usize> Semiring for ConstMontyField, const LIMBS: usize> Ring for ConstMontyField {} impl, const LIMBS: usize> Field for ConstMontyField { - type Inner = ConstMontyForm; + type Inner = Uint; #[inline(always)] fn inner(&self) -> &Self::Inner { - &self.0 + Uint::new_ref(self.0.as_montgomery()) } } impl, const LIMBS: usize> ConstPrimeField for ConstMontyField { - const MODULUS: Self::Inner = ConstMontyForm::::new(Mod::PARAMS.modulus().as_ref()); + const MODULUS: Self::Inner = *Uint::new_ref(Mod::PARAMS.modulus().as_ref()); const MODULUS_MINUS_ONE_DIV_TWO: Self::Inner = { - let m_minus_one = ConstMontyForm::sub(&Self::MODULUS, &ConstMontyForm::ONE); - m_minus_one.div_by_2() + let m_minus_one = CBUint::wrapping_sub(Self::MODULUS.inner(), &CBUint::ONE); + let two = CBUint::::wrapping_add(&CBUint::ONE, &CBUint::ONE); + Uint::new(CBUint::wrapping_div( + &m_minus_one, + &NonZeroUint::new_unwrap(two), + )) }; #[inline(always)] fn new_unchecked(inner: Self::Inner) -> Self { - // Inner value is a ConstMontyForm so it's guaranteed to be valid - Self(inner) + Self(ConstMontyForm::from_montgomery(inner.into_inner())) } } @@ -1295,6 +1298,14 @@ mod prop_tests { ); type F = ConstMontyField; + #[test] + fn modulus_minus_one_div_two_correct() { + assert_eq!( + F::MODULUS_MINUS_ONE_DIV_TWO, + Uint::from_be_hex("006E54A6C50F6671DB743AAEC4CCBC3E82926C650F53AAF3D7C27DB237D18F93") + ) + } + fn any_f() -> impl Strategy { any::().prop_map(F::from) } diff --git a/src/semiring/crypto_bigint_uint.rs b/src/semiring/crypto_bigint_uint.rs index 4f3aa5c..482027f 100644 --- a/src/semiring/crypto_bigint_uint.rs +++ b/src/semiring/crypto_bigint_uint.rs @@ -6,7 +6,7 @@ use core::{ hash::{Hash, Hasher}, iter::{Product, Sum}, ops::{ - Add, AddAssign, Mul, MulAssign, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, + Add, AddAssign, Div, Mul, MulAssign, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, }, str::FromStr, @@ -269,6 +269,24 @@ impl_basic_op!(Add, add); impl_basic_op!(Sub, sub); impl_basic_op!(Mul, mul); +impl Div for Uint { + type Output = Self; + + #[inline(always)] + fn div(self, rhs: Self) -> Self::Output { + self.div(&rhs) + } +} + +impl<'a, const LIMBS: usize> Div<&'a Self> for Uint { + type Output = Self; + + fn div(self, rhs: &'a Self) -> Self::Output { + let non_zero = crypto_bigint::NonZero::new(rhs.0).expect("division by zero"); + Self(self.0.div(&non_zero)) + } +} + impl Rem for Uint { type Output = Self;