Skip to content

Commit fcd3e52

Browse files
authored
Merge pull request #179 from tylerjereddy/treddy_positive_api_std
ENH: positive() to API std
2 parents 4432889 + a46fe03 commit fcd3e52

File tree

3 files changed

+62
-27
lines changed

3 files changed

+62
-27
lines changed

.github/workflows/array_api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,4 @@ jobs:
4949
# for hypothesis-driven test case generation
5050
pytest $GITHUB_WORKSPACE/pre_compile_tools/pre_compile_ufuncs.py -s
5151
# only run a subset of the conformance tests to get started
52-
pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor array_api_tests/test_operators_and_elementwise_functions.py::test_exp array_api_tests/test_operators_and_elementwise_functions.py::test_sin array_api_tests/test_operators_and_elementwise_functions.py::test_tan array_api_tests/test_operators_and_elementwise_functions.py::test_tanh array_api_tests/test_creation_functions.py::test_zeros array_api_tests/test_creation_functions.py::test_zeros_like array_api_tests/test_creation_functions.py::test_full_like
52+
pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor array_api_tests/test_operators_and_elementwise_functions.py::test_exp array_api_tests/test_operators_and_elementwise_functions.py::test_sin array_api_tests/test_operators_and_elementwise_functions.py::test_tan array_api_tests/test_operators_and_elementwise_functions.py::test_tanh array_api_tests/test_creation_functions.py::test_zeros array_api_tests/test_creation_functions.py::test_zeros_like array_api_tests/test_creation_functions.py::test_full_like array_api_tests/test_operators_and_elementwise_functions.py::test_positive array_api_tests/test_operators_and_elementwise_functions.py::test_isnan array_api_tests/test_operators_and_elementwise_functions.py::test_equal "array_api_tests/test_has_names.py::test_has_names[array_method-__pos__]"

pykokkos/interface/views.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,10 @@ def __array__(self, dtype=None):
422422
return self.data
423423

424424

425+
def __pos__(self):
426+
return pk.positive(self)
427+
428+
425429
@staticmethod
426430
def _get_dtype_name(type_name: str) -> str:
427431
"""
@@ -520,11 +524,42 @@ def _get_base_view(self, parent_view: Union[Subview, View]) -> View:
520524
return base_view
521525

522526
def __eq__(self, other):
523-
if isinstance(other, View):
524-
if len(self.data) == 0 and len(other.data) == 0:
525-
return True
526-
result_of_eq = self.data == other.data
527-
return result_of_eq
527+
# avoid circular import with scoped import
528+
from pykokkos.lib.ufuncs import equal
529+
if isinstance(other, float):
530+
new_other = pk.View((), dtype=pk.double)
531+
new_other[:] = other
532+
elif isinstance(other, bool):
533+
new_other = pk.View((), dtype=pk.bool)
534+
new_other[:] = other
535+
elif isinstance(other, int):
536+
if self.ndim == 0:
537+
ret = pk.View((), dtype=pk.bool)
538+
ret[:] = int(self) == other
539+
return ret
540+
if 0 <= other <= 255:
541+
other_dtype = pk.uint8
542+
elif 0 <= other <= 65535:
543+
other_dtype = pk.uint16
544+
elif 0 <= other <= 4294967295:
545+
other_dtype = pk.uint32
546+
elif 0 <= other <= 18446744073709551615:
547+
other_dtype = pk.uint64
548+
elif -128 <= other <= 127:
549+
other_dtype = pk.int8
550+
elif -32768 <= other <= 32767:
551+
other_dtype = pk.int16
552+
elif -2147483648 <= other <= 2147483647:
553+
other_dtype = pk.int32
554+
elif -9223372036854775808 <= other <= 9223372036854775807:
555+
other_dtype = pk.int64
556+
new_other = pk.View((), dtype=other_dtype)
557+
new_other[:] = other
558+
elif isinstance(other, pk.Subview):
559+
new_other = other
560+
else:
561+
raise ValueError("unexpected types!")
562+
return equal(self, new_other)
528563

529564

530565
def __add__(self, other):

pykokkos/lib/ufuncs.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,12 @@ def _typematch_views(view1, view2):
9090
if dtype_1_width >= dtype_2_width:
9191
effective_dtype = dtype1
9292
view2_new = pk.View([*view2.shape], dtype=effective_dtype)
93-
view2_new[:] = view2
93+
view2_new[:] = view2.data
9494
view2 = view2_new
9595
else:
9696
effective_dtype = dtype2
9797
view1_new = pk.View([*view1.shape], dtype=effective_dtype)
98-
view1_new[:] = view1
98+
view1_new[:] = view1.data
9999
view1 = view1_new
100100
return view1, view2, effective_dtype
101101

@@ -1152,15 +1152,6 @@ def negative(view):
11521152
return out
11531153

11541154

1155-
@pk.workunit
1156-
def positive_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.double]):
1157-
out[tid] = view[tid]
1158-
1159-
1160-
@pk.workunit
1161-
def positive_impl_1d_float(tid: int, view: pk.View1D[pk.float], out: pk.View1D[pk.float]):
1162-
out[tid] = view[tid]
1163-
11641155
def positive(view):
11651156
"""
11661157
Element-wise positive of the view;
@@ -1177,16 +1168,11 @@ def positive(view):
11771168
Output view.
11781169
11791170
"""
1180-
if len(view.shape) > 1:
1181-
raise NotImplementedError("only 1D views currently supported for positive() ufunc.")
1182-
if str(view.dtype) == "DataType.double":
1183-
out = pk.View([view.shape[0]], pk.double)
1184-
pk.parallel_for(view.shape[0], positive_impl_1d_double, view=view, out=out)
1185-
elif str(view.dtype) == "DataType.float":
1186-
out = pk.View([view.shape[0]], pk.float)
1187-
pk.parallel_for(view.shape[0], positive_impl_1d_float, view=view, out=out)
1171+
if view.shape == ():
1172+
out = pk.View((), dtype=view.dtype)
11881173
else:
1189-
raise NotImplementedError
1174+
out = pk.View([*view.shape], dtype=view.dtype)
1175+
out[...] = view
11901176
return out
11911177

11921178

@@ -2442,6 +2428,10 @@ def isnan(view):
24422428
tid = 1
24432429
else:
24442430
tid = view.shape[0]
2431+
if view.ndim == 0:
2432+
new_view = pk.View([1], dtype=view.dtype)
2433+
new_view[0] = view
2434+
view = new_view
24452435
_ufunc_kernel_dispatcher(tid=tid,
24462436
dtype=dtype,
24472437
ndims=ndims,
@@ -2493,7 +2483,9 @@ def equal(view1, view2):
24932483
Output view.
24942484
"""
24952485
if view1.size == 0 and view2.size == 0:
2496-
return pk.View((), dtype=pk.bool)
2486+
ret = pk.View((), dtype=pk.bool)
2487+
ret[...] = 1
2488+
return ret
24972489
view1, view2 = _broadcast_views(view1, view2)
24982490
dtype1 = view1.dtype
24992491
dtype2 = view2.dtype
@@ -2506,6 +2498,14 @@ def equal(view1, view2):
25062498
tid = 1
25072499
else:
25082500
tid = view1.shape[0]
2501+
if isinstance(view1, pk.Subview):
2502+
new_view = pk.View((), dtype=view1.dtype)
2503+
new_view[:] = view1.data
2504+
view1 = new_view
2505+
if isinstance(view2, pk.Subview):
2506+
new_view = pk.View((), dtype=view2.dtype)
2507+
new_view[:] = view2.data
2508+
view2 = new_view
25092509
_ufunc_kernel_dispatcher(tid=tid,
25102510
dtype=effective_dtype,
25112511
ndims=ndims,

0 commit comments

Comments
 (0)