Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TST: test binary operators vs. numpy generics #145

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 43 additions & 43 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,15 +233,15 @@ def _check_allowed_dtypes(

return other

def _check_device(self, other: Array | bool | int | float | complex) -> None:
"""Check that other is on a device compatible with the current array"""
if isinstance(other, (bool, int, float, complex)):
return
elif isinstance(other, Array):
def _check_type_device(self, other: Array | bool | int | float | complex) -> None:
"""Check that other is either a Python scalar or an array on a device
compatible with the current array.
"""
if isinstance(other, Array):
if self.device != other.device:
raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
else:
raise TypeError(f"Expected Array | python scalar; got {type(other)}")
elif not isinstance(other, bool | int | float | complex):
raise TypeError(f"Expected Array or Python scalar; got {type(other)}")

# Helper function to match the type promotion rules in the spec
def _promote_scalar(self, scalar: bool | int | float | complex) -> Array:
Expand Down Expand Up @@ -542,7 +542,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array:
"""
Performs the operation __add__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__add__")
if other is NotImplemented:
return other
Expand All @@ -554,7 +554,7 @@ def __and__(self, other: Array | bool | int, /) -> Array:
"""
Performs the operation __and__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__and__")
if other is NotImplemented:
return other
Expand Down Expand Up @@ -651,7 +651,7 @@ def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # ty
"""
Performs the operation __eq__.
"""
self._check_device(other)
self._check_type_device(other)
# Even though "all" dtypes are allowed, we still require them to be
# promotable with each other.
other = self._check_allowed_dtypes(other, "all", "__eq__")
Expand All @@ -677,7 +677,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array:
"""
Performs the operation __floordiv__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
if other is NotImplemented:
return other
Expand All @@ -689,7 +689,7 @@ def __ge__(self, other: Array | int | float, /) -> Array:
"""
Performs the operation __ge__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
if other is NotImplemented:
return other
Expand Down Expand Up @@ -741,7 +741,7 @@ def __gt__(self, other: Array | int | float, /) -> Array:
"""
Performs the operation __gt__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
if other is NotImplemented:
return other
Expand Down Expand Up @@ -796,7 +796,7 @@ def __le__(self, other: Array | int | float, /) -> Array:
"""
Performs the operation __le__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
if other is NotImplemented:
return other
Expand All @@ -808,7 +808,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
"""
Performs the operation __lshift__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer", "__lshift__")
if other is NotImplemented:
return other
Expand All @@ -820,7 +820,7 @@ def __lt__(self, other: Array | int | float, /) -> Array:
"""
Performs the operation __lt__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
if other is NotImplemented:
return other
Expand All @@ -832,7 +832,7 @@ def __matmul__(self, other: Array, /) -> Array:
"""
Performs the operation __matmul__.
"""
self._check_device(other)
self._check_type_device(other)
# matmul is not defined for scalars, but without this, we may get
# the wrong error message from asarray.
other = self._check_allowed_dtypes(other, "numeric", "__matmul__")
Expand All @@ -845,7 +845,7 @@ def __mod__(self, other: Array | int | float, /) -> Array:
"""
Performs the operation __mod__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
if other is NotImplemented:
return other
Expand All @@ -857,7 +857,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array:
"""
Performs the operation __mul__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__mul__")
if other is NotImplemented:
return other
Expand All @@ -869,7 +869,7 @@ def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # ty
"""
Performs the operation __ne__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "all", "__ne__")
if other is NotImplemented:
return other
Expand All @@ -890,7 +890,7 @@ def __or__(self, other: Array | bool | int, /) -> Array:
"""
Performs the operation __or__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__or__")
if other is NotImplemented:
return other
Expand All @@ -913,7 +913,7 @@ def __pow__(self, other: Array | int | float | complex, /) -> Array:
"""
from ._elementwise_functions import pow # type: ignore[attr-defined]

self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__pow__")
if other is NotImplemented:
return other
Expand All @@ -925,7 +925,7 @@ def __rshift__(self, other: Array | int, /) -> Array:
"""
Performs the operation __rshift__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer", "__rshift__")
if other is NotImplemented:
return other
Expand Down Expand Up @@ -961,7 +961,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array:
"""
Performs the operation __sub__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__sub__")
if other is NotImplemented:
return other
Expand All @@ -975,7 +975,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array:
"""
Performs the operation __truediv__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "floating-point", "__truediv__")
if other is NotImplemented:
return other
Expand All @@ -987,7 +987,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array:
"""
Performs the operation __xor__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__")
if other is NotImplemented:
return other
Expand All @@ -999,7 +999,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array:
"""
Performs the operation __iadd__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__iadd__")
if other is NotImplemented:
return other
Expand All @@ -1010,7 +1010,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array:
"""
Performs the operation __radd__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__radd__")
if other is NotImplemented:
return other
Expand All @@ -1022,7 +1022,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array:
"""
Performs the operation __iand__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__")
if other is NotImplemented:
return other
Expand All @@ -1033,7 +1033,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array:
"""
Performs the operation __rand__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__")
if other is NotImplemented:
return other
Expand All @@ -1045,7 +1045,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array:
"""
Performs the operation __ifloordiv__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__")
if other is NotImplemented:
return other
Expand All @@ -1056,7 +1056,7 @@ def __rfloordiv__(self, other: Array | int | float, /) -> Array:
"""
Performs the operation __rfloordiv__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__")
if other is NotImplemented:
return other
Expand All @@ -1068,7 +1068,7 @@ def __ilshift__(self, other: Array | int, /) -> Array:
"""
Performs the operation __ilshift__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer", "__ilshift__")
if other is NotImplemented:
return other
Expand All @@ -1079,7 +1079,7 @@ def __rlshift__(self, other: Array | int, /) -> Array:
"""
Performs the operation __rlshift__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer", "__rlshift__")
if other is NotImplemented:
return other
Expand All @@ -1096,7 +1096,7 @@ def __imatmul__(self, other: Array, /) -> Array:
other = self._check_allowed_dtypes(other, "numeric", "__imatmul__")
if other is NotImplemented:
return other
self._check_device(other)
self._check_type_device(other)
res = self._array.__imatmul__(other._array)
return self.__class__._new(res, device=self.device)

Expand All @@ -1109,7 +1109,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__")
if other is NotImplemented:
return other
self._check_device(other)
self._check_type_device(other)
res = self._array.__rmatmul__(other._array)
return self.__class__._new(res, device=self.device)

Expand All @@ -1130,7 +1130,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array:
other = self._check_allowed_dtypes(other, "real numeric", "__rmod__")
if other is NotImplemented:
return other
self._check_device(other)
self._check_type_device(other)
self, other = self._normalize_two_args(self, other)
res = self._array.__rmod__(other._array)
return self.__class__._new(res, device=self.device)
Expand All @@ -1152,7 +1152,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array:
other = self._check_allowed_dtypes(other, "numeric", "__rmul__")
if other is NotImplemented:
return other
self._check_device(other)
self._check_type_device(other)
self, other = self._normalize_two_args(self, other)
res = self._array.__rmul__(other._array)
return self.__class__._new(res, device=self.device)
Expand All @@ -1171,7 +1171,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array:
"""
Performs the operation __ror__.
"""
self._check_device(other)
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__")
if other is NotImplemented:
return other
Expand Down Expand Up @@ -1219,7 +1219,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
other = self._check_allowed_dtypes(other, "integer", "__rrshift__")
if other is NotImplemented:
return other
self._check_device(other)
self._check_type_device(other)
self, other = self._normalize_two_args(self, other)
res = self._array.__rrshift__(other._array)
return self.__class__._new(res, device=self.device)
Expand All @@ -1241,7 +1241,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array:
other = self._check_allowed_dtypes(other, "numeric", "__rsub__")
if other is NotImplemented:
return other
self._check_device(other)
self._check_type_device(other)
self, other = self._normalize_two_args(self, other)
res = self._array.__rsub__(other._array)
return self.__class__._new(res, device=self.device)
Expand All @@ -1263,7 +1263,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__")
if other is NotImplemented:
return other
self._check_device(other)
self._check_type_device(other)
self, other = self._normalize_two_args(self, other)
res = self._array.__rtruediv__(other._array)
return self.__class__._new(res, device=self.device)
Expand All @@ -1285,7 +1285,7 @@ def __rxor__(self, other: Array | bool | int, /) -> Array:
other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__")
if other is NotImplemented:
return other
self._check_device(other)
self._check_type_device(other)
self, other = self._normalize_two_args(self, other)
res = self._array.__rxor__(other._array)
return self.__class__._new(res, device=self.device)
Expand Down
Loading
Loading