Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some methods defined in the Array-API #257

Merged
merged 3 commits into from
Jan 31, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 72 additions & 22 deletions src/flint/types/_gr.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,47 @@ cdef class gr_ctx(flint_ctx):
def max(self, x, y) -> gr:
return self._max(self(x), self(y))

###
# Array-API wrappers

def divide(self, x, y) -> gr:
return self.div(x, y)

def greater(self, x, y):
return self.gt(x, y)

def greater_equal(self, x, y):
return self.ge(x, y)

def less(self, x, y):
return self.lt(x, y)

def less_equal(self, x, y):
return self.le(x, y)

def imag(self, x):
return self.im(x)

def real(self, x):
return self.re(x)

def maximum(self, x, y):
return self.max(x, y)

def minimum(self, x, y):
return self.min(x, y)

def multiply(self, x, y):
return self.mul(x, y)

def negative(self, x):
return self.neg(x)

def not_equal(self, x, y):
eq = self.equal(x, y)
if eq is None:
return None
return not eq

cdef class gr_scalar_ctx(gr_ctx):
"""Base class for all scalar contexts."""
Expand Down Expand Up @@ -1661,6 +1702,18 @@ cdef class gr(flint_scalar):
def __repr__(self):
return self.ctx.to_str(self)

def parent(self) -> gr_ctx:
"""
Return the parent context.

>>> from flint.types._gr import gr_complex_acb_ctx
>>> acb = gr_complex_acb_ctx.new(53)
>>> x = acb("pi")
>>> x.parent()
gr_complex_acb_ctx(53)
"""
return self.ctx

def is_zero(self):
"""Return whether the element is zero (may return ``None``).

Expand Down Expand Up @@ -1838,10 +1891,7 @@ cdef class gr(flint_scalar):
return NotImplemented

def __pow__(self, other) -> gr:
if isinstance(other, int):
return self._pow_si(other)
else:
return NotImplemented
return self.ctx.pow(self, other)

def is_square(self):
"""Return whether the element is a square (may return ``None``).
Expand All @@ -1854,7 +1904,7 @@ cdef class gr(flint_scalar):
>>> Q(4).sqrt()
2
"""
return truth_to_py(self._is_square())
return truth_to_py(self.ctx.is_square(self))

def sqrt(self):
"""Return the square root of the element if it exists.
Expand All @@ -1863,7 +1913,7 @@ cdef class gr(flint_scalar):
>>> Z(4).sqrt()
2
"""
return self._sqrt()
return self.ctx.sqrt(self)

def rsqrt(self):
"""Return the reciprocal square root of the element if it exists.
Expand All @@ -1872,7 +1922,7 @@ cdef class gr(flint_scalar):
>>> Q(4).rsqrt()
1/2
"""
return self._rsqrt()
return self.ctx.rsqrt(self)

def gcd(self, other):
"""Return the greatest common divisor of two elements.
Expand Down Expand Up @@ -1902,7 +1952,7 @@ cdef class gr(flint_scalar):
other_gr = other
if not self.ctx == other_gr.ctx:
raise TypeError("gcd of gr with different contexts.")
return self._lcm(other_gr)
return self.ctx.lcm(self, other_gr)

def factor(self):
"""Return the factorization of the element.
Expand All @@ -1911,7 +1961,7 @@ cdef class gr(flint_scalar):
>>> Z(12).factor()
(1, [(2, 2), (3, 1)])
"""
return self._factor()
return self.ctx.factor(self)

def numer(self) -> gr:
"""Return the numerator of the element.
Expand All @@ -1925,7 +1975,7 @@ cdef class gr(flint_scalar):

See also :meth:`denom`.
"""
return self._numerator()
return self.ctx.numerator(self)

def denom(self) -> gr:
"""Return the denominator of the element.
Expand All @@ -1939,21 +1989,21 @@ cdef class gr(flint_scalar):

See also :meth:`numer`.
"""
return self._denominator()
return self.ctx.denominator(self)

def __floor__(self) -> gr:
return self._floor()
return self.ctx.floor(self)

def __ceil__(self) -> gr:
return self._ceil()
return self.ctx.ceil(self)

def __trunc__(self) -> gr:
return self._trunc()
return self.ctx.trunc(self)

def __round__(self, ndigits: int = 0) -> gr:
if ndigits != 0:
raise NotImplementedError("Rounding to a specific number of digits is not supported")
return self._nint()
return self.ctx.nint(self)

# def __int__(self) -> int:
# return self._floor().to_int()
Expand All @@ -1962,7 +2012,7 @@ cdef class gr(flint_scalar):
# return ...

def __abs__(self) -> gr:
return self._abs()
return self.ctx.abs(self)

def conjugate(self) -> gr:
"""Return complex conjugate of the element.
Expand All @@ -1972,7 +2022,7 @@ cdef class gr(flint_scalar):
>>> (1 + I).conjugate()
(1-I)
"""
return self._conj()
return self.ctx.conj(self)

@property
def real(self) -> gr:
Expand All @@ -1983,7 +2033,7 @@ cdef class gr(flint_scalar):
>>> (1 + I).real
1
"""
return self._re()
return self.ctx.re(self)

@property
def imag(self) -> gr:
Expand All @@ -1994,7 +2044,7 @@ cdef class gr(flint_scalar):
>>> (1 + I).imag
1
"""
return self._im()
return self.ctx.im(self)

# XXX: Return -1, 0, 1 as int?
def sgn(self) -> gr:
Expand All @@ -2008,7 +2058,7 @@ cdef class gr(flint_scalar):
>>> Q(0).sgn()
0
"""
return self._sgn()
return self.ctx.sgn(self)

def csgn(self) -> gr:
"""Return the complex sign of the element.
Expand All @@ -2018,7 +2068,7 @@ cdef class gr(flint_scalar):
>>> (1 + C.i()).csgn() # doctest: +SKIP
1
"""
return self._csgn()
return self.ctx.csgn(self)

def arg(self) -> gr:
"""Return the argument of the element.
Expand All @@ -2028,4 +2078,4 @@ cdef class gr(flint_scalar):
>>> (1 + C.i()).arg()
[0.785 +/- 6.45e-4]
"""
return self._arg()
return self.ctx.arg(self)
Loading