Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/array_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ jobs:
# for hypothesis-driven test case generation
pytest $GITHUB_WORKSPACE/pre_compile_tools/pre_compile_ufuncs.py -s
# only run a subset of the conformance tests to get started
pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor array_api_tests/test_operators_and_elementwise_functions.py::test_exp array_api_tests/test_operators_and_elementwise_functions.py::test_sin array_api_tests/test_operators_and_elementwise_functions.py::test_tan array_api_tests/test_operators_and_elementwise_functions.py::test_tanh array_api_tests/test_creation_functions.py::test_zeros array_api_tests/test_creation_functions.py::test_zeros_like array_api_tests/test_creation_functions.py::test_full_like
pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor array_api_tests/test_operators_and_elementwise_functions.py::test_exp array_api_tests/test_operators_and_elementwise_functions.py::test_sin array_api_tests/test_operators_and_elementwise_functions.py::test_tan array_api_tests/test_operators_and_elementwise_functions.py::test_tanh array_api_tests/test_creation_functions.py::test_zeros array_api_tests/test_creation_functions.py::test_zeros_like array_api_tests/test_creation_functions.py::test_full_like array_api_tests/test_operators_and_elementwise_functions.py::test_positive array_api_tests/test_operators_and_elementwise_functions.py::test_isnan array_api_tests/test_operators_and_elementwise_functions.py::test_equal "array_api_tests/test_has_names.py::test_has_names[array_method-__pos__]"
45 changes: 40 additions & 5 deletions pykokkos/interface/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,10 @@ def __array__(self, dtype=None):
return self.data


def __pos__(self):
return pk.positive(self)


@staticmethod
def _get_dtype_name(type_name: str) -> str:
"""
Expand Down Expand Up @@ -520,11 +524,42 @@ def _get_base_view(self, parent_view: Union[Subview, View]) -> View:
return base_view

def __eq__(self, other):
if isinstance(other, View):
if len(self.data) == 0 and len(other.data) == 0:
return True
result_of_eq = self.data == other.data
return result_of_eq
# avoid circular import with scoped import
from pykokkos.lib.ufuncs import equal
if isinstance(other, float):
new_other = pk.View((), dtype=pk.double)
new_other[:] = other
elif isinstance(other, bool):
new_other = pk.View((), dtype=pk.bool)
new_other[:] = other
elif isinstance(other, int):
if self.ndim == 0:
ret = pk.View((), dtype=pk.bool)
ret[:] = int(self) == other
return ret
if 0 <= other <= 255:
other_dtype = pk.uint8
elif 0 <= other <= 65535:
other_dtype = pk.uint16
elif 0 <= other <= 4294967295:
other_dtype = pk.uint32
elif 0 <= other <= 18446744073709551615:
other_dtype = pk.uint64
elif -128 <= other <= 127:
other_dtype = pk.int8
elif -32768 <= other <= 32767:
other_dtype = pk.int16
elif -2147483648 <= other <= 2147483647:
other_dtype = pk.int32
elif -9223372036854775808 <= other <= 9223372036854775807:
other_dtype = pk.int64
new_other = pk.View((), dtype=other_dtype)
new_other[:] = other
elif isinstance(other, pk.Subview):
new_other = other
else:
raise ValueError("unexpected types!")
return equal(self, new_other)


def __add__(self, other):
Expand Down
42 changes: 21 additions & 21 deletions pykokkos/lib/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ def _typematch_views(view1, view2):
if dtype_1_width >= dtype_2_width:
effective_dtype = dtype1
view2_new = pk.View([*view2.shape], dtype=effective_dtype)
view2_new[:] = view2
view2_new[:] = view2.data
view2 = view2_new
else:
effective_dtype = dtype2
view1_new = pk.View([*view1.shape], dtype=effective_dtype)
view1_new[:] = view1
view1_new[:] = view1.data
view1 = view1_new
return view1, view2, effective_dtype

Expand Down Expand Up @@ -1152,15 +1152,6 @@ def negative(view):
return out


@pk.workunit
def positive_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.double]):
out[tid] = view[tid]


@pk.workunit
def positive_impl_1d_float(tid: int, view: pk.View1D[pk.float], out: pk.View1D[pk.float]):
out[tid] = view[tid]

def positive(view):
"""
Element-wise positive of the view;
Expand All @@ -1177,16 +1168,11 @@ def positive(view):
Output view.

"""
if len(view.shape) > 1:
raise NotImplementedError("only 1D views currently supported for positive() ufunc.")
if str(view.dtype) == "DataType.double":
out = pk.View([view.shape[0]], pk.double)
pk.parallel_for(view.shape[0], positive_impl_1d_double, view=view, out=out)
elif str(view.dtype) == "DataType.float":
out = pk.View([view.shape[0]], pk.float)
pk.parallel_for(view.shape[0], positive_impl_1d_float, view=view, out=out)
if view.shape == ():
out = pk.View((), dtype=view.dtype)
else:
raise NotImplementedError
out = pk.View([*view.shape], dtype=view.dtype)
out[...] = view
return out


Expand Down Expand Up @@ -2442,6 +2428,10 @@ def isnan(view):
tid = 1
else:
tid = view.shape[0]
if view.ndim == 0:
new_view = pk.View([1], dtype=view.dtype)
new_view[0] = view
view = new_view
_ufunc_kernel_dispatcher(tid=tid,
dtype=dtype,
ndims=ndims,
Expand Down Expand Up @@ -2493,7 +2483,9 @@ def equal(view1, view2):
Output view.
"""
if view1.size == 0 and view2.size == 0:
return pk.View((), dtype=pk.bool)
ret = pk.View((), dtype=pk.bool)
ret[...] = 1
return ret
view1, view2 = _broadcast_views(view1, view2)
dtype1 = view1.dtype
dtype2 = view2.dtype
Expand All @@ -2506,6 +2498,14 @@ def equal(view1, view2):
tid = 1
else:
tid = view1.shape[0]
if isinstance(view1, pk.Subview):
new_view = pk.View((), dtype=view1.dtype)
new_view[:] = view1.data
view1 = new_view
if isinstance(view2, pk.Subview):
new_view = pk.View((), dtype=view2.dtype)
new_view[:] = view2.data
view2 = new_view
_ufunc_kernel_dispatcher(tid=tid,
dtype=effective_dtype,
ndims=ndims,
Expand Down