Skip to content

Commit

Permalink
Merge pull request #257 from Tom-Hubrecht/array-api
Browse files Browse the repository at this point in the history
Add some methods defined in the Array-API
  • Loading branch information
oscarbenjamin authored Jan 31, 2025
2 parents e50cda6 + 2bcdaa9 commit 9b00a9d
Showing 1 changed file with 72 additions and 22 deletions.
94 changes: 72 additions & 22 deletions src/flint/types/_gr.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,47 @@ cdef class gr_ctx(flint_ctx):
def max(self, x, y) -> gr:
return self._max(self(x), self(y))
###
# Array-API wrappers
def divide(self, x, y) -> gr:
return self.div(x, y)
def greater(self, x, y):
return self.gt(x, y)
def greater_equal(self, x, y):
return self.ge(x, y)
def less(self, x, y):
return self.lt(x, y)
def less_equal(self, x, y):
return self.le(x, y)
def imag(self, x):
return self.im(x)
def real(self, x):
return self.re(x)
def maximum(self, x, y):
return self.max(x, y)
def minimum(self, x, y):
return self.min(x, y)
def multiply(self, x, y):
return self.mul(x, y)
def negative(self, x):
return self.neg(x)
def not_equal(self, x, y):
eq = self.equal(x, y)
if eq is None:
return None
return not eq
cdef class gr_scalar_ctx(gr_ctx):
"""Base class for all scalar contexts."""
Expand Down Expand Up @@ -1661,6 +1702,18 @@ cdef class gr(flint_scalar):
def __repr__(self):
return self.ctx.to_str(self)

def parent(self) -> gr_ctx:
"""
Return the parent context.

>>> from flint.types._gr import gr_complex_acb_ctx
>>> acb = gr_complex_acb_ctx.new(53)
>>> x = acb("pi")
>>> x.parent()
gr_complex_acb_ctx(53)
"""
return self.ctx

def is_zero(self):
"""Return whether the element is zero (may return ``None``).
Expand Down Expand Up @@ -1838,10 +1891,7 @@ cdef class gr(flint_scalar):
return NotImplemented

def __pow__(self, other) -> gr:
if isinstance(other, int):
return self._pow_si(other)
else:
return NotImplemented
return self.ctx.pow(self, other)

def is_square(self):
"""Return whether the element is a square (may return ``None``).
Expand All @@ -1854,7 +1904,7 @@ cdef class gr(flint_scalar):
>>> Q(4).sqrt()
2
"""
return truth_to_py(self._is_square())
return truth_to_py(self.ctx.is_square(self))

def sqrt(self):
"""Return the square root of the element if it exists.
Expand All @@ -1863,7 +1913,7 @@ cdef class gr(flint_scalar):
>>> Z(4).sqrt()
2
"""
return self._sqrt()
return self.ctx.sqrt(self)

def rsqrt(self):
"""Return the reciprocal square root of the element if it exists.
Expand All @@ -1872,7 +1922,7 @@ cdef class gr(flint_scalar):
>>> Q(4).rsqrt()
1/2
"""
return self._rsqrt()
return self.ctx.rsqrt(self)

def gcd(self, other):
"""Return the greatest common divisor of two elements.
Expand Down Expand Up @@ -1902,7 +1952,7 @@ cdef class gr(flint_scalar):
other_gr = other
if not self.ctx == other_gr.ctx:
raise TypeError("gcd of gr with different contexts.")
return self._lcm(other_gr)
return self.ctx.lcm(self, other_gr)

def factor(self):
"""Return the factorization of the element.
Expand All @@ -1911,7 +1961,7 @@ cdef class gr(flint_scalar):
>>> Z(12).factor()
(1, [(2, 2), (3, 1)])
"""
return self._factor()
return self.ctx.factor(self)

def numer(self) -> gr:
"""Return the numerator of the element.
Expand All @@ -1925,7 +1975,7 @@ cdef class gr(flint_scalar):

See also :meth:`denom`.
"""
return self._numerator()
return self.ctx.numerator(self)

def denom(self) -> gr:
"""Return the denominator of the element.
Expand All @@ -1939,21 +1989,21 @@ cdef class gr(flint_scalar):

See also :meth:`numer`.
"""
return self._denominator()
return self.ctx.denominator(self)

def __floor__(self) -> gr:
return self._floor()
return self.ctx.floor(self)

def __ceil__(self) -> gr:
return self._ceil()
return self.ctx.ceil(self)

def __trunc__(self) -> gr:
return self._trunc()
return self.ctx.trunc(self)

def __round__(self, ndigits: int = 0) -> gr:
if ndigits != 0:
raise NotImplementedError("Rounding to a specific number of digits is not supported")
return self._nint()
return self.ctx.nint(self)

# def __int__(self) -> int:
# return self._floor().to_int()
Expand All @@ -1962,7 +2012,7 @@ cdef class gr(flint_scalar):
# return ...

def __abs__(self) -> gr:
return self._abs()
return self.ctx.abs(self)

def conjugate(self) -> gr:
"""Return complex conjugate of the element.
Expand All @@ -1972,7 +2022,7 @@ cdef class gr(flint_scalar):
>>> (1 + I).conjugate()
(1-I)
"""
return self._conj()
return self.ctx.conj(self)

@property
def real(self) -> gr:
Expand All @@ -1983,7 +2033,7 @@ cdef class gr(flint_scalar):
>>> (1 + I).real
1
"""
return self._re()
return self.ctx.re(self)

@property
def imag(self) -> gr:
Expand All @@ -1994,7 +2044,7 @@ cdef class gr(flint_scalar):
>>> (1 + I).imag
1
"""
return self._im()
return self.ctx.im(self)

# XXX: Return -1, 0, 1 as int?
def sgn(self) -> gr:
Expand All @@ -2008,7 +2058,7 @@ cdef class gr(flint_scalar):
>>> Q(0).sgn()
0
"""
return self._sgn()
return self.ctx.sgn(self)

def csgn(self) -> gr:
"""Return the complex sign of the element.
Expand All @@ -2018,7 +2068,7 @@ cdef class gr(flint_scalar):
>>> (1 + C.i()).csgn() # doctest: +SKIP
1
"""
return self._csgn()
return self.ctx.csgn(self)

def arg(self) -> gr:
"""Return the argument of the element.
Expand All @@ -2028,4 +2078,4 @@ cdef class gr(flint_scalar):
>>> (1 + C.i()).arg()
[0.785 +/- 6.45e-4]
"""
return self._arg()
return self.ctx.arg(self)

0 comments on commit 9b00a9d

Please sign in to comment.