diff --git a/.github/workflows/array_api.yml b/.github/workflows/array_api.yml index 61bb5eca..d9116af3 100644 --- a/.github/workflows/array_api.yml +++ b/.github/workflows/array_api.yml @@ -49,4 +49,4 @@ jobs: # for hypothesis-driven test case generation pytest $GITHUB_WORKSPACE/pre_compile_tools/pre_compile_ufuncs.py -s # only run a subset of the conformance tests to get started - pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor array_api_tests/test_operators_and_elementwise_functions.py::test_exp array_api_tests/test_operators_and_elementwise_functions.py::test_sin array_api_tests/test_operators_and_elementwise_functions.py::test_tan array_api_tests/test_operators_and_elementwise_functions.py::test_tanh array_api_tests/test_creation_functions.py::test_zeros array_api_tests/test_creation_functions.py::test_zeros_like array_api_tests/test_creation_functions.py::test_full_like + pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor array_api_tests/test_operators_and_elementwise_functions.py::test_exp array_api_tests/test_operators_and_elementwise_functions.py::test_sin array_api_tests/test_operators_and_elementwise_functions.py::test_tan array_api_tests/test_operators_and_elementwise_functions.py::test_tanh array_api_tests/test_creation_functions.py::test_zeros array_api_tests/test_creation_functions.py::test_zeros_like array_api_tests/test_creation_functions.py::test_full_like array_api_tests/test_operators_and_elementwise_functions.py::test_positive array_api_tests/test_operators_and_elementwise_functions.py::test_isnan array_api_tests/test_operators_and_elementwise_functions.py::test_equal "array_api_tests/test_has_names.py::test_has_names[array_method-__pos__]" diff --git a/pykokkos/interface/views.py b/pykokkos/interface/views.py index 0c25dd15..7005fd52 100644 --- a/pykokkos/interface/views.py +++ b/pykokkos/interface/views.py @@ -422,6 +422,10 @@ def __array__(self, dtype=None): return self.data + def __pos__(self): + return pk.positive(self) + + @staticmethod def _get_dtype_name(type_name: str) -> str: """ @@ -520,11 +524,42 @@ def _get_base_view(self, parent_view: Union[Subview, View]) -> View: return base_view def __eq__(self, other): - if isinstance(other, View): - if len(self.data) == 0 and len(other.data) == 0: - return True - result_of_eq = self.data == other.data - return result_of_eq + # avoid circular import with scoped import + from pykokkos.lib.ufuncs import equal + if isinstance(other, float): + new_other = pk.View((), dtype=pk.double) + new_other[:] = other + elif isinstance(other, bool): + new_other = pk.View((), dtype=pk.bool) + new_other[:] = other + elif isinstance(other, int): + if self.ndim == 0: + ret = pk.View((), dtype=pk.bool) + ret[:] = int(self) == other + return ret + if 0 <= other <= 255: + other_dtype = pk.uint8 + elif 0 <= other <= 65535: + other_dtype = pk.uint16 + elif 0 <= other <= 4294967295: + other_dtype = pk.uint32 + elif 0 <= other <= 18446744073709551615: + other_dtype = pk.uint64 + elif -128 <= other <= 127: + other_dtype = pk.int8 + elif -32768 <= other <= 32767: + other_dtype = pk.int16 + elif -2147483648 <= other <= 2147483647: + other_dtype = pk.int32 + elif -9223372036854775808 <= other <= 9223372036854775807: + other_dtype = pk.int64 + new_other = pk.View((), dtype=other_dtype) + new_other[:] = other + elif isinstance(other, pk.Subview): + new_other = other + else: + raise ValueError("unexpected types!") + return equal(self, new_other) def __add__(self, other): diff --git a/pykokkos/lib/ufuncs.py b/pykokkos/lib/ufuncs.py index 2bdf1a7a..1f20a993 100644 --- a/pykokkos/lib/ufuncs.py +++ b/pykokkos/lib/ufuncs.py @@ -90,12 +90,12 @@ def _typematch_views(view1, view2): if dtype_1_width >= dtype_2_width: effective_dtype = dtype1 view2_new = pk.View([*view2.shape], dtype=effective_dtype) - view2_new[:] = view2 + view2_new[:] = view2.data view2 = view2_new else: effective_dtype = dtype2 view1_new = pk.View([*view1.shape], dtype=effective_dtype) - view1_new[:] = view1 + view1_new[:] = view1.data view1 = view1_new return view1, view2, effective_dtype @@ -1152,15 +1152,6 @@ def negative(view): return out -@pk.workunit -def positive_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.double]): - out[tid] = view[tid] - - -@pk.workunit -def positive_impl_1d_float(tid: int, view: pk.View1D[pk.float], out: pk.View1D[pk.float]): - out[tid] = view[tid] - def positive(view): """ Element-wise positive of the view; @@ -1177,16 +1168,11 @@ def positive(view): Output view. """ - if len(view.shape) > 1: - raise NotImplementedError("only 1D views currently supported for positive() ufunc.") - if str(view.dtype) == "DataType.double": - out = pk.View([view.shape[0]], pk.double) - pk.parallel_for(view.shape[0], positive_impl_1d_double, view=view, out=out) - elif str(view.dtype) == "DataType.float": - out = pk.View([view.shape[0]], pk.float) - pk.parallel_for(view.shape[0], positive_impl_1d_float, view=view, out=out) + if view.shape == (): + out = pk.View((), dtype=view.dtype) else: - raise NotImplementedError + out = pk.View([*view.shape], dtype=view.dtype) + out[...] = view return out @@ -2442,6 +2428,10 @@ def isnan(view): tid = 1 else: tid = view.shape[0] + if view.ndim == 0: + new_view = pk.View([1], dtype=view.dtype) + new_view[0] = view + view = new_view _ufunc_kernel_dispatcher(tid=tid, dtype=dtype, ndims=ndims, @@ -2493,7 +2483,9 @@ def equal(view1, view2): Output view. """ if view1.size == 0 and view2.size == 0: - return pk.View((), dtype=pk.bool) + ret = pk.View((), dtype=pk.bool) + ret[...] = 1 + return ret view1, view2 = _broadcast_views(view1, view2) dtype1 = view1.dtype dtype2 = view2.dtype @@ -2506,6 +2498,14 @@ def equal(view1, view2): tid = 1 else: tid = view1.shape[0] + if isinstance(view1, pk.Subview): + new_view = pk.View((), dtype=view1.dtype) + new_view[:] = view1.data + view1 = new_view + if isinstance(view2, pk.Subview): + new_view = pk.View((), dtype=view2.dtype) + new_view[:] = view2.data + view2 = new_view _ufunc_kernel_dispatcher(tid=tid, dtype=effective_dtype, ndims=ndims,