diff --git a/src/internal_math.rs b/src/internal_math.rs index 9f53a84..34d541a 100644 --- a/src/internal_math.rs +++ b/src/internal_math.rs @@ -59,28 +59,26 @@ impl Barrett { /// /// * `a` `0 <= a < m` /// * `b` `0 <= b < m` -/// * `m` `1 <= m <= 2^31` -/// * `im` = ceil(2^64 / `m`) +/// * `m` `1 <= m < 2^32` +/// * `im` = ceil(2^64 / `m`) = floor((2^64 - 1) / `m`) + 1 #[allow(clippy::many_single_char_names)] pub(crate) fn mul_mod(a: u32, b: u32, m: u32, im: u64) -> u32 { // [1] m = 1 // a = b = im = 0, so okay // [2] m >= 2 - // im = ceil(2^64 / m) + // im = ceil(2^64 / m) = floor((2^64 - 1) / m) + 1 // -> im * m = 2^64 + r (0 <= r < m) // let z = a*b = c*m + d (0 <= c, d < m) // a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im // c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2 // ((ab * im) >> 64) == c or c + 1 - let mut z = a as u64; - z *= b as u64; + let z = (a as u64) * (b as u64); let x = (((z as u128) * (im as u128)) >> 64) as u64; - let mut v = z.wrapping_sub(x.wrapping_mul(m as u64)) as u32; - if m <= v { - v = v.wrapping_add(m); + match z.overflowing_sub(x.wrapping_mul(m as u64)) { + (v, true) => (v as u32).wrapping_add(m), + (v, false) => v as u32, } - v } /// # Parameters @@ -320,6 +318,34 @@ mod tests { let b = Barrett::new(2147483647); assert_eq!(b.umod(), 2147483647); assert_eq!(b.mul(1073741824, 2147483645), 2147483646); + + // test `2^31 < self._m < 2^32` case. + // https://github.com/rust-lang-ja/ac-library-rs/pull/112 + let b = Barrett::new(3221225471); + assert_eq!(b.umod(), 3221225471); + assert_eq!(b.mul(3188445886, 2844002853), 1840468257); + assert_eq!(b.mul(2834869488, 2779159607), 2084027561); + assert_eq!(b.mul(3032263594, 3039996727), 2130247251); + assert_eq!(b.mul(3029175553, 3140869278), 1892378237); + // https://github.com/atcoder/ac-library/issues/149 + // https://github.com/atcoder/ac-library/pull/163 + for m in u32::MAX - 20..=u32::MAX { + let b = Barrett::new(m); + let mut v: Vec = vec![]; + for i in 0..10 { + v.push(i); + v.push(m - i); + v.push(m / 2 + i); + v.push(m / 2 - i); + } + for a in v { + let a2 = u64::from(a); + assert_eq!( + (((a2 * a2) % u64::from(m) * a2) % u64::from(m)) as u32, + b.mul(a, b.mul(a, a)) + ); + } + } } #[test] diff --git a/src/modint.rs b/src/modint.rs index bdca791..d02d25a 100644 --- a/src/modint.rs +++ b/src/modint.rs @@ -793,20 +793,21 @@ trait InternalImplementations: ModIntBase { #[inline] fn add_impl(lhs: Self, rhs: Self) -> Self { let modulus = Self::modulus(); - let mut val = lhs.val() + rhs.val(); - if val >= modulus { - val -= modulus; - } + let v = u64::from(lhs.val()) + u64::from(rhs.val()); + let val = match v.overflowing_sub(u64::from(modulus)) { + (_, true) => v as u32, + (w, false) => w as u32, + }; Self::raw(val) } #[inline] fn sub_impl(lhs: Self, rhs: Self) -> Self { let modulus = Self::modulus(); - let mut val = lhs.val().wrapping_sub(rhs.val()); - if val >= modulus { - val = val.wrapping_add(modulus) - } + let val = match lhs.val().overflowing_sub(rhs.val()) { + (v, true) => v.wrapping_add(modulus), + (v, false) => v, + }; Self::raw(val) } @@ -1171,4 +1172,32 @@ mod tests { let y = ModInt::new(123).pow(0); assert_eq!(y.val(), 0); } + + // test `2^31 < modulus < 2^32` case + // https://github.com/rust-lang-ja/ac-library-rs/issues/111 + // https://github.com/atcoder/ac-library/issues/149 + // https://github.com/atcoder/ac-library/pull/163 + // https://github.com/atcoder/ac-library/issues/164 + #[test] + fn dynamic_modint_m32() { + let m = 3221225471; + ModInt::set_modulus(m); + let f = ModInt::new::; + assert_eq!(f(1398188832) + f(3184083880), f(1361047241)); + assert_eq!(f(3013899062) + f(2238406135), f(2031079726)); + assert_eq!(f(2699997885) + f(2745140255), f(2223912669)); + assert_eq!(f(2824399978) + f(2531872141), f(2135046648)); + assert_eq!(f(36496612) - f(2039504668), f(1218217415)); + assert_eq!(f(266176802) - f(1609833977), f(1877568296)); + assert_eq!(f(713535382) - f(2153383999), f(1781376854)); + assert_eq!(f(1249965147) - f(3144251805), f(1326938813)); + assert_eq!(f(2692223381) * f(2935379475), f(2084179397)); + assert_eq!(f(2800462205) * f(2822998916), f(2089431198)); + assert_eq!(f(3061947734) * f(3210920667), f(1962208034)); + assert_eq!(f(3138997926) * f(2994465129), f(1772479317)); + assert_eq!(f(2947552629) / f(576466398), f(2041593039)); + assert_eq!(f(2914694891) / f(399734126), f(1983162347)); + assert_eq!(f(2202862138) / f(1154428799), f(2139936238)); + assert_eq!(f(3037207894) / f(2865447143), f(1894581230)); + } }