Skip to content

Commit 0a59f53

Browse files
Merge pull request #322 from remyoudompheng/mullow
Extend mul_low and pow_trunc methods to fmpz_poly, fmpq_poly, nmod_poly
2 parents 955f1df + eebd8c2 commit 0a59f53

File tree

11 files changed

+256
-12
lines changed

11 files changed

+256
-12
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ Contributors (0.9.0):
170170

171171
Changes (0.9.0):
172172

173+
- [gh-322](https://github.com/flintlib/python-flint/pull/322),
174+
Add `mul_low` and `pow_trunc` methods to `fmpz_poly`, `fmpq_poly` and
175+
`nmod_poly`. (RO)
173176
- [gh-318](https://github.com/flintlib/python-flint/pull/318),
174177
Add emscripten build in CI. Polynomial factors and roots are
175178
now sorted into a consistent order for `nmod_poly` and

src/flint/test/test_all.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2964,6 +2964,20 @@ def setbad(obj, i, val):
29642964
assert raises(lambda: P([1, 1]) ** -1, DomainError)
29652965
assert raises(lambda: P([1, 1]) ** None, TypeError) # type: ignore
29662966

2967+
# Truncated operations
2968+
assert P([1, 2, 3]).mul_low(P([4, 5, 6]), 3) == P([4, 13, 28])
2969+
assert raises(lambda: P([1, 2, 3]).mul_low(None, 3), TypeError) # type: ignore
2970+
assert raises(lambda: P([1, 2, 3]).mul_low(P([4, 5, 6]), None), TypeError) # type: ignore
2971+
2972+
p = P([1, 2, 3])
2973+
assert p.pow_trunc(1234, 3) == P([1, 2468, 3046746])
2974+
assert raises(lambda: p.pow_trunc(None, 3), TypeError) # type: ignore
2975+
assert raises(lambda: p.pow_trunc(3, "A"), TypeError) # type: ignore
2976+
assert raises(lambda: p.pow_trunc(P([4, 5, 6]), 3), TypeError) # type: ignore
2977+
# Large exponents are allowed
2978+
assert p.pow_trunc(2**100, 2) == P([1, 2**101])
2979+
assert p.pow_trunc(6**60, 3) == p.pow_trunc(2**60, 3).pow_trunc(3**60, 3)
2980+
29672981
# XXX: Not sure what this should do in general:
29682982
p = P([1, 1])
29692983
mod = P([1, 1])

src/flint/types/fmpq_poly.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ class fmpq_poly(flint_poly[fmpq]):
7171
def left_shift(self, n: int, /) -> fmpq_poly: ...
7272
def right_shift(self, n: int, /) -> fmpq_poly: ...
7373
def truncate(self, n: int, /) -> fmpq_poly: ...
74+
def mul_low(self, other: fmpq_poly, n: int) -> fmpq_poly: ...
75+
def pow_trunc(self, e: int, n: int) -> fmpq_poly: ...
7476

7577
def gcd(self, other: ifmpq_poly, /) -> fmpq_poly: ...
7678
def discriminant(self) -> fmpq: ...

src/flint/types/fmpq_poly.pyx

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,66 @@ cdef class fmpq_poly(flint_poly):
502502
fmpq_poly_pow(res.val, self.val, <ulong>exp)
503503
return res
504504

505+
def mul_low(self, other, slong n):
506+
r"""
507+
Returns the lowest ``n`` coefficients of the multiplication of ``self`` with ``other``
508+
509+
Equivalent to computing `f(x) \cdot g(x) \mod x^n`
510+
511+
>>> f = fmpq_poly([2,3,5,7,11])
512+
>>> g = fmpq_poly([1,2,4,8,16])
513+
>>> f.mul_low(g, 5)
514+
101*x^4 + 45*x^3 + 19*x^2 + 7*x + 2
515+
>>> f.mul_low(g, 3)
516+
19*x^2 + 7*x + 2
517+
>>> f.mul_low(g, 1)
518+
2
519+
"""
520+
# Only allow multiplication with other fmpq_poly
521+
if not typecheck(other, fmpq_poly):
522+
raise TypeError("other polynomial must be of type fmpq_poly")
523+
524+
cdef fmpq_poly res
525+
res = fmpq_poly.__new__(fmpq_poly)
526+
fmpq_poly_mullow(res.val, self.val, (<fmpq_poly>other).val, n)
527+
return res
528+
529+
def pow_trunc(self, e, slong n):
530+
r"""
531+
Returns ``self`` raised to the power ``e`` modulo `x^n`:
532+
:math:`f^e \mod x^n`/
533+
534+
>>> f = fmpq_poly([1, 2, 3])
535+
>>> x = fmpq_poly([0, 1])
536+
>>> f.pow_trunc(2**20, 4)
537+
1537230871828889600*x^3 + 2199024304128*x^2 + 2097152*x + 1
538+
>>> f.pow_trunc(5**25, 3)
539+
177635683940025046765804290771484375*x^2 + 596046447753906250*x + 1
540+
"""
541+
if e < 0:
542+
raise ValueError("Exponent must be non-negative")
543+
544+
cdef slong e_c
545+
cdef fmpq_poly res, tmp
546+
547+
try:
548+
e_c = e
549+
except OverflowError:
550+
# Exponent does not fit slong
551+
res = fmpq_poly.__new__(fmpq_poly)
552+
tmp = fmpq_poly.__new__(fmpq_poly)
553+
ebytes = e.to_bytes((e.bit_length() + 15) // 16 * 2, "big")
554+
fmpq_poly_pow_trunc(res.val, self.val, ebytes[0] * 256 + ebytes[1], n)
555+
for i in range(2, len(ebytes), 2):
556+
fmpq_poly_pow_trunc(res.val, res.val, 1 << 16, n)
557+
fmpq_poly_pow_trunc(tmp.val, self.val, ebytes[i] * 256 + ebytes[i+1], n)
558+
fmpq_poly_mullow(res.val, res.val, tmp.val, n)
559+
return res
560+
561+
res = fmpq_poly.__new__(fmpq_poly)
562+
fmpq_poly_pow_trunc(res.val, self.val, e_c, n)
563+
return res
564+
505565
def gcd(self, other):
506566
"""
507567
Returns the greatest common divisor of *self* and *other*.

src/flint/types/fmpz_mod_poly.pyx

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,28 +1668,43 @@ cdef class fmpz_mod_poly(flint_poly):
16681668
)
16691669
return res
16701670

1671-
def pow_trunc(self, slong e, slong n):
1671+
def pow_trunc(self, e, slong n):
16721672
r"""
16731673
Returns ``self`` raised to the power ``e`` modulo `x^n`:
16741674
:math:`f^e \mod x^n`/
16751675
1676-
Note: For exponents larger that 2^31 (which do not fit inside a ulong) use the
1677-
method :meth:`~.pow_mod` with the explicit modulus `x^n`.
1678-
16791676
>>> R = fmpz_mod_poly_ctx(163)
16801677
>>> x = R.gen()
16811678
>>> f = 30*x**6 + 104*x**5 + 76*x**4 + 33*x**3 + 70*x**2 + 44*x + 65
16821679
>>> f.pow_trunc(2**20, 30) == pow(f, 2**20, x**30)
16831680
True
16841681
>>> f.pow_trunc(2**20, 5)
16851682
132*x^4 + 113*x^3 + 36*x^2 + 48*x + 6
1683+
>>> f.pow_trunc(5**25, 5)
1684+
147*x^4 + 98*x^3 + 95*x^2 + 33*x + 126
16861685
"""
16871686
if e < 0:
16881687
raise ValueError("Exponent must be non-negative")
16891688

1690-
cdef fmpz_mod_poly res
1689+
cdef fmpz_mod_poly res, tmp
1690+
cdef slong e_c
1691+
1692+
try:
1693+
e_c = e
1694+
except OverflowError:
1695+
# Exponent does not fit slong
1696+
res = self.ctx.new_ctype_poly()
1697+
tmp = self.ctx.new_ctype_poly()
1698+
ebytes = e.to_bytes((e.bit_length() + 15) // 16 * 2, "big")
1699+
fmpz_mod_poly_pow_trunc(res.val, self.val, ebytes[0] * 256 + ebytes[1], n, res.ctx.mod.val)
1700+
for i in range(2, len(ebytes), 2):
1701+
fmpz_mod_poly_pow_trunc(res.val, res.val, 1 << 16, n, res.ctx.mod.val)
1702+
fmpz_mod_poly_pow_trunc(tmp.val, self.val, ebytes[i] * 256 + ebytes[i+1], n, res.ctx.mod.val)
1703+
fmpz_mod_poly_mullow(res.val, res.val, tmp.val, n, res.ctx.mod.val)
1704+
return res
1705+
16911706
res = self.ctx.new_ctype_poly()
1692-
fmpz_mod_poly_pow_trunc(res.val, self.val, e, n, res.ctx.mod.val)
1707+
fmpz_mod_poly_pow_trunc(res.val, self.val, e_c, n, res.ctx.mod.val)
16931708
return res
16941709

16951710
def inflate(self, ulong n):

src/flint/types/fmpz_poly.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class fmpz_poly(flint_poly[fmpz]):
5959
def left_shift(self, n: int, /) -> fmpz_poly: ...
6060
def right_shift(self, n: int, /) -> fmpz_poly: ...
6161
def truncate(self, n: int, /) -> fmpz_poly: ...
62+
def mul_low(self, other: fmpz_poly, n: int) -> fmpz_poly: ...
63+
def pow_trunc(self, e: int, n: int) -> fmpz_poly: ...
6264

6365
def gcd(self, other: ifmpz_poly, /) -> fmpz_poly: ...
6466
def content(self) -> fmpz: ...

src/flint/types/fmpz_poly.pyx

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,66 @@ cdef class fmpz_poly(flint_poly):
483483
fmpz_poly_pow(res.val, self.val, <ulong>exp)
484484
return res
485485

486+
def mul_low(self, other, slong n):
487+
r"""
488+
Returns the lowest ``n`` coefficients of the multiplication of ``self`` with ``other``
489+
490+
Equivalent to computing `f(x) \cdot g(x) \mod x^n`
491+
492+
>>> f = fmpz_poly([2,3,5,7,11])
493+
>>> g = fmpz_poly([1,2,4,8,16])
494+
>>> f.mul_low(g, 5)
495+
101*x^4 + 45*x^3 + 19*x^2 + 7*x + 2
496+
>>> f.mul_low(g, 3)
497+
19*x^2 + 7*x + 2
498+
>>> f.mul_low(g, 1)
499+
2
500+
"""
501+
# Only allow multiplication with other fmpz_poly
502+
if not typecheck(other, fmpz_poly):
503+
raise TypeError("other polynomial must be of type fmpz_poly")
504+
505+
cdef fmpz_poly res
506+
res = fmpz_poly.__new__(fmpz_poly)
507+
fmpz_poly_mullow(res.val, self.val, (<fmpz_poly>other).val, n)
508+
return res
509+
510+
def pow_trunc(self, e, slong n):
511+
r"""
512+
Returns ``self`` raised to the power ``e`` modulo `x^n`:
513+
:math:`f^e \mod x^n`/
514+
515+
>>> f = fmpz_poly([1, 2, 3])
516+
>>> x = fmpz_poly([0, 1])
517+
>>> f.pow_trunc(2**20, 4)
518+
1537230871828889600*x^3 + 2199024304128*x^2 + 2097152*x + 1
519+
>>> f.pow_trunc(5**25, 3)
520+
177635683940025046765804290771484375*x^2 + 596046447753906250*x + 1
521+
"""
522+
if e < 0:
523+
raise ValueError("Exponent must be non-negative")
524+
525+
cdef slong e_c
526+
cdef fmpz_poly res, tmp
527+
528+
try:
529+
e_c = e
530+
except OverflowError:
531+
# Exponent does not fit slong
532+
res = fmpz_poly.__new__(fmpz_poly)
533+
tmp = fmpz_poly.__new__(fmpz_poly)
534+
ebytes = e.to_bytes((e.bit_length() + 15) // 16 * 2, "big")
535+
fmpz_poly_pow_trunc(res.val, self.val, ebytes[0] * 256 + ebytes[1], n)
536+
for i in range(2, len(ebytes), 2):
537+
fmpz_poly_pow_trunc(res.val, res.val, 1 << 16, n)
538+
fmpz_poly_pow_trunc(tmp.val, self.val, ebytes[i] * 256 + ebytes[i+1], n)
539+
fmpz_poly_mullow(res.val, res.val, tmp.val, n)
540+
return res
541+
542+
res = fmpz_poly.__new__(fmpz_poly)
543+
fmpz_poly_pow_trunc(res.val, self.val, e_c, n)
544+
return res
545+
486546
def gcd(self, other):
487547
"""
488548
Returns the greatest common divisor of self and other.

src/flint/types/fq_default_poly.pyx

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,28 +1190,43 @@ cdef class fq_default_poly(flint_poly):
11901190
)
11911191
return res
11921192

1193-
def pow_trunc(self, slong e, slong n):
1193+
def pow_trunc(self, e, slong n):
11941194
r"""
11951195
Returns ``self`` raised to the power ``e`` modulo `x^n`:
11961196
:math:`f^e \mod x^n`/
11971197
1198-
Note: For exponents larger that 2^31 (which do not fit inside a ulong) use the
1199-
method :meth:`~.pow_mod` with the explicit modulus `x^n`.
1200-
12011198
>>> R = fq_default_poly_ctx(163)
12021199
>>> x = R.gen()
12031200
>>> f = 30*x**6 + 104*x**5 + 76*x**4 + 33*x**3 + 70*x**2 + 44*x + 65
12041201
>>> f.pow_trunc(2**20, 30) == pow(f, 2**20, x**30)
12051202
True
12061203
>>> f.pow_trunc(2**20, 5)
12071204
132*x^4 + 113*x^3 + 36*x^2 + 48*x + 6
1205+
>>> f.pow_trunc(5**25, 5)
1206+
147*x^4 + 98*x^3 + 95*x^2 + 33*x + 126
12081207
"""
12091208
if e < 0:
12101209
raise ValueError("Exponent must be non-negative")
12111210

1212-
cdef fq_default_poly res
1211+
cdef slong e_c
1212+
cdef fq_default_poly res, tmp
1213+
1214+
try:
1215+
e_c = e
1216+
except OverflowError:
1217+
# Exponent does not fit slong
1218+
res = self.ctx.new_ctype_poly()
1219+
tmp = self.ctx.new_ctype_poly()
1220+
ebytes = e.to_bytes((e.bit_length() + 15) // 16 * 2, "big")
1221+
fq_default_poly_pow_trunc(res.val, self.val, ebytes[0] * 256 + ebytes[1], n, res.ctx.field.val)
1222+
for i in range(2, len(ebytes), 2):
1223+
fq_default_poly_pow_trunc(res.val, res.val, 1 << 16, n, res.ctx.field.val)
1224+
fq_default_poly_pow_trunc(tmp.val, self.val, ebytes[i] * 256 + ebytes[i+1], n, res.ctx.field.val)
1225+
fq_default_poly_mullow(res.val, res.val, tmp.val, n, res.ctx.field.val)
1226+
return res
1227+
12131228
res = self.ctx.new_ctype_poly()
1214-
fq_default_poly_pow_trunc(res.val, self.val, e, n, res.ctx.field.val)
1229+
fq_default_poly_pow_trunc(res.val, self.val, e_c, n, res.ctx.field.val)
12151230
return res
12161231

12171232
def sqrt_trunc(self, slong n):

src/flint/types/nmod_poly.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ class nmod_poly(flint_poly[nmod]):
5656
def __rdivmod__(self, other: inmod_poly) -> tuple[nmod_poly, nmod_poly]: ...
5757
def left_shift(self, n: int) -> nmod_poly: ...
5858
def right_shift(self, n: int) -> nmod_poly: ...
59+
def mul_low(self, other: nmod_poly, n: int) -> nmod_poly: ...
60+
def pow_trunc(self, e: int, n: int) -> nmod_poly: ...
5961
def __pow__(self, other: int, mod: inmod_poly | None = None) -> nmod_poly: ...
6062
def pow_mod(
6163
self, e: int, modulus: inmod_poly, mod_rev_inv: inmod_poly | None = None

src/flint/types/nmod_poly.pyx

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,75 @@ cdef class nmod_poly(flint_poly):
724724
)
725725
return res
726726

727+
def mul_low(self, other, slong n):
728+
r"""
729+
Returns the lowest ``n`` coefficients of the multiplication of ``self`` with ``other``
730+
731+
Equivalent to computing `f(x) \cdot g(x) \mod x^n`
732+
733+
>>> f = nmod_poly([2,3,5,7,11], 163)
734+
>>> g = nmod_poly([1,2,4,8,16], 163)
735+
>>> f.mul_low(g, 5)
736+
101*x^4 + 45*x^3 + 19*x^2 + 7*x + 2
737+
>>> f.mul_low(g, 3)
738+
19*x^2 + 7*x + 2
739+
>>> f.mul_low(g, 1)
740+
2
741+
"""
742+
# Only allow multiplication with other nmod_poly
743+
if not typecheck(other, nmod_poly):
744+
raise TypeError("other polynomial must be of type nmod_poly")
745+
746+
if (<nmod_poly>self).val.mod.n != (<nmod_poly>other).val.mod.n:
747+
raise ValueError("cannot multiply nmod_polys with different moduli")
748+
749+
cdef nmod_poly res = nmod_poly.__new__(nmod_poly)
750+
res = nmod_poly.__new__(nmod_poly)
751+
nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv)
752+
nmod_poly_mullow(res.val, self.val, (<nmod_poly>other).val, n)
753+
return res
754+
755+
def pow_trunc(self, e, slong n):
756+
r"""
757+
Returns ``self`` raised to the power ``e`` modulo `x^n`:
758+
:math:`f^e \mod x^n`/
759+
760+
>>> f = nmod_poly([65, 44, 70, 33, 76, 104, 30], 163)
761+
>>> x = nmod_poly([0, 1], 163)
762+
>>> f.pow_trunc(2**20, 30) == pow(f, 2**20, x**30)
763+
True
764+
>>> f.pow_trunc(2**20, 5)
765+
132*x^4 + 113*x^3 + 36*x^2 + 48*x + 6
766+
>>> f.pow_trunc(5**25, 5)
767+
147*x^4 + 98*x^3 + 95*x^2 + 33*x + 126
768+
"""
769+
if e < 0:
770+
raise ValueError("Exponent must be non-negative")
771+
772+
cdef nmod_poly res, tmp
773+
cdef slong e_c
774+
775+
try:
776+
e_c = e
777+
except OverflowError:
778+
# Exponent does not fit slong
779+
res = nmod_poly.__new__(nmod_poly)
780+
tmp = nmod_poly.__new__(nmod_poly)
781+
nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv)
782+
nmod_poly_init_preinv(tmp.val, self.val.mod.n, self.val.mod.ninv)
783+
ebytes = e.to_bytes((e.bit_length() + 15) // 16 * 2, "big")
784+
nmod_poly_pow_trunc(res.val, self.val, ebytes[0] * 256 + ebytes[1], n)
785+
for i in range(2, len(ebytes), 2):
786+
nmod_poly_pow_trunc(res.val, res.val, 1 << 16, n)
787+
nmod_poly_pow_trunc(tmp.val, self.val, ebytes[i] * 256 + ebytes[i+1], n)
788+
nmod_poly_mullow(res.val, res.val, tmp.val, n)
789+
return res
790+
791+
res = nmod_poly.__new__(nmod_poly)
792+
nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv)
793+
nmod_poly_pow_trunc(res.val, self.val, e_c, n)
794+
return res
795+
727796
def gcd(self, other):
728797
"""
729798
Returns the monic greatest common divisor of self and other.

0 commit comments

Comments
 (0)