From 2c44aa2565538a2de1c9e1d9544aae57033a86d7 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Thu, 16 Feb 2023 11:37:45 -0700 Subject: [PATCH 1/3] ENH: add full_like * add `full_like` creation function and turn the matching array API standard test on in the CI --- .github/workflows/array_api.yml | 2 +- pykokkos/__init__.py | 3 ++- pykokkos/lib/create.py | 9 +++++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array_api.yml b/.github/workflows/array_api.yml index 2226b675..72cb61e2 100644 --- a/.github/workflows/array_api.yml +++ b/.github/workflows/array_api.yml @@ -49,4 +49,4 @@ jobs: # for hypothesis-driven test case generation pytest $GITHUB_WORKSPACE/pre_compile_tools/pre_compile_ufuncs.py -s # only run a subset of the conformance tests to get started - pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor array_api_tests/test_operators_and_elementwise_functions.py::test_exp array_api_tests/test_operators_and_elementwise_functions.py::test_sin array_api_tests/test_operators_and_elementwise_functions.py::test_tan array_api_tests/test_operators_and_elementwise_functions.py::test_tanh array_api_tests/test_creation_functions.py::test_zeros array_api_tests/test_creation_functions.py::test_zeros_like + pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor array_api_tests/test_operators_and_elementwise_functions.py::test_exp array_api_tests/test_operators_and_elementwise_functions.py::test_sin array_api_tests/test_operators_and_elementwise_functions.py::test_tan array_api_tests/test_operators_and_elementwise_functions.py::test_tanh array_api_tests/test_creation_functions.py::test_zeros array_api_tests/test_creation_functions.py::test_zeros_like array_api_tests/test_creation_functions.py::test_full_like diff --git a/pykokkos/__init__.py b/pykokkos/__init__.py index 0aca2372..bdf4c87c 100644 --- a/pykokkos/__init__.py +++ b/pykokkos/__init__.py @@ -68,7 +68,8 @@ zeros_like, ones, ones_like, - full) + full, + full_like) from pykokkos.lib.manipulate import reshape, ravel, expand_dims from pykokkos.lib.util import all, any, sum, find_max, searchsorted, col, linspace, logspace from pykokkos.lib.constants import e, pi, inf, nan diff --git a/pykokkos/lib/create.py b/pykokkos/lib/create.py index 70f14ee2..ec916f4f 100644 --- a/pykokkos/lib/create.py +++ b/pykokkos/lib/create.py @@ -47,3 +47,12 @@ def full(shape, fill_value, *, dtype=None, device=None): view: pk.View = pk.View([shape], dtype=dtype) view[:] = fill_value return view + + +def full_like(x, /, fill_value, *, dtype=None, device=None): + if dtype is None: + dtype = x.dtype + shape = x.shape + view: pk.View = pk.View([*shape], dtype=dtype) + view[:] = fill_value + return view From 2f8b845c8caaa97348e879cfbf368ecf79a54c7c Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Thu, 29 Dec 2022 12:33:26 -0700 Subject: [PATCH 2/3] ENH: isnan to API standard * `isnan()` already existed, but was not passing the API standard test, so it has been adjusted to pass the test, which is now enforced in CI --- pykokkos/lib/ufunc_workunits.py | 110 ++++++++++++++++++++++++++++++++ pykokkos/lib/ufuncs.py | 39 +++++------ 2 files changed, 126 insertions(+), 23 deletions(-) diff --git a/pykokkos/lib/ufunc_workunits.py b/pykokkos/lib/ufunc_workunits.py index 89b190ca..60a4fffd 100644 --- a/pykokkos/lib/ufunc_workunits.py +++ b/pykokkos/lib/ufunc_workunits.py @@ -225,6 +225,116 @@ def round_impl_3d_float(tid: int, view: pk.View3D[pk.float], out: pk.View3D[pk.f for j in range(view.extent(2)): out[tid][i][j] = round(view[tid][i][j]) +@pk.workunit +def isnan_impl_2d_double(tid: int, view: pk.View2D[pk.double], out: pk.View2D[pk.uint8]): + for i in range(view.extent(1)): + out[tid][i] = isnan(view[tid][i]) + + +@pk.workunit +def isnan_impl_2d_float(tid: int, view: pk.View2D[pk.float], out: pk.View2D[pk.uint8]): + for i in range(view.extent(1)): + out[tid][i] = isnan(view[tid][i]) + + +@pk.workunit +def isnan_impl_2d_uint8(tid: int, view: pk.View2D[pk.uint8], out: pk.View2D[pk.uint8]): + for i in range(view.extent(1)): + out[tid][i] = isnan(view[tid][i]) + + +@pk.workunit +def isnan_impl_2d_uint16(tid: int, view: pk.View2D[pk.uint16], out: pk.View2D[pk.uint8]): + for i in range(view.extent(1)): + out[tid][i] = isnan(view[tid][i]) + + +@pk.workunit +def isnan_impl_2d_uint32(tid: int, view: pk.View2D[pk.uint32], out: pk.View2D[pk.uint8]): + for i in range(view.extent(1)): + out[tid][i] = isnan(view[tid][i]) + + +@pk.workunit +def isnan_impl_2d_uint64(tid: int, view: pk.View2D[pk.uint64], out: pk.View2D[pk.uint8]): + for i in range(view.extent(1)): + out[tid][i] = isnan(view[tid][i]) + + +@pk.workunit +def isnan_impl_2d_int8(tid: int, view: pk.View2D[pk.int8], out: pk.View2D[pk.uint8]): + for i in range(view.extent(1)): + out[tid][i] = isnan(view[tid][i]) + + +@pk.workunit +def isnan_impl_2d_int16(tid: int, view: pk.View2D[pk.int16], out: pk.View2D[pk.uint8]): + for i in range(view.extent(1)): + out[tid][i] = isnan(view[tid][i]) + + +@pk.workunit +def isnan_impl_2d_int32(tid: int, view: pk.View2D[pk.int32], out: pk.View2D[pk.uint8]): + for i in range(view.extent(1)): + out[tid][i] = isnan(view[tid][i]) + + +@pk.workunit +def isnan_impl_2d_int64(tid: int, view: pk.View2D[pk.int64], out: pk.View2D[pk.uint8]): + for i in range(view.extent(1)): + out[tid][i] = isnan(view[tid][i]) + + +@pk.workunit +def isnan_impl_1d_uint8(tid: int, view: pk.View1D[pk.uint8], out: pk.View1D[pk.uint8]): + out[tid] = isnan(view[tid]) + + +@pk.workunit +def isnan_impl_1d_uint16(tid: int, view: pk.View1D[pk.uint16], out: pk.View1D[pk.uint8]): + out[tid] = isnan(view[tid]) + + +@pk.workunit +def isnan_impl_1d_uint32(tid: int, view: pk.View1D[pk.uint32], out: pk.View1D[pk.uint8]): + out[tid] = isnan(view[tid]) + + +@pk.workunit +def isnan_impl_1d_uint64(tid: int, view: pk.View1D[pk.uint64], out: pk.View1D[pk.uint8]): + out[tid] = isnan(view[tid]) + + +@pk.workunit +def isnan_impl_1d_int8(tid: int, view: pk.View1D[pk.int8], out: pk.View1D[pk.uint8]): + out[tid] = isnan(view[tid]) + + +@pk.workunit +def isnan_impl_1d_int16(tid: int, view: pk.View1D[pk.int16], out: pk.View1D[pk.uint8]): + out[tid] = isnan(view[tid]) + + +@pk.workunit +def isnan_impl_1d_int32(tid: int, view: pk.View1D[pk.int32], out: pk.View1D[pk.uint8]): + out[tid] = isnan(view[tid]) + + +@pk.workunit +def isnan_impl_1d_int64(tid: int, view: pk.View1D[pk.int64], out: pk.View1D[pk.uint8]): + out[tid] = isnan(view[tid]) + + +@pk.workunit +def isnan_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.uint8]): + out[tid] = isnan(view[tid]) + + +@pk.workunit +def isnan_impl_1d_float(tid: int, view: pk.View1D[pk.float], out: pk.View1D[pk.uint8]): + out[tid] = isnan(view[tid]) + + @pk.workunit def isfinite_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.uint8]): out[tid] = isfinite(view[tid]) diff --git a/pykokkos/lib/ufuncs.py b/pykokkos/lib/ufuncs.py index 93bb5e6d..bffa6279 100644 --- a/pykokkos/lib/ufuncs.py +++ b/pykokkos/lib/ufuncs.py @@ -2374,30 +2374,23 @@ def index(viewA, viewB): return out -@pk.workunit -def isnan_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.uint8]): - out[tid] = isnan(view[tid]) - - -@pk.workunit -def isnan_impl_1d_float(tid: int, view: pk.View1D[pk.float], out: pk.View1D[pk.uint8]): - out[tid] = isnan(view[tid]) - - def isnan(view): - if len(view.shape) > 1: - raise NotImplementedError("isnan() ufunc only supports 1D views") - out = pk.View([*view.shape], dtype=pk.uint8) - if "double" in str(view.dtype) or "float64" in str(view.dtype): - pk.parallel_for(view.shape[0], - isnan_impl_1d_double, - view=view, - out=out) - elif "float" in str(view.dtype): - pk.parallel_for(view.shape[0], - isnan_impl_1d_float, - view=view, - out=out) + dtype = view.dtype + ndims = len(view.shape) + if ndims > 2: + raise NotImplementedError("isnan() ufunc only supports up to 2D views") + out = pk.View([*view.shape], dtype=pk.bool) + if view.shape == (): + tid = 1 + else: + tid = view.shape[0] + _ufunc_kernel_dispatcher(tid=tid, + dtype=dtype, + ndims=ndims, + op="isnan", + sub_dispatcher=pk.parallel_for, + out=out, + view=view) return out From 2f348558bf2528a2d3d1a53416b2a4d90d618a8e Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Thu, 29 Dec 2022 12:10:00 -0700 Subject: [PATCH 3/3] ENH: equal() to API standard * adjust the `equal()` ufunc to pass its array API standard test, and turn this test on in the CI * the changes include improved broadcasting and (type) casting support * because equality testing is so common/rampant, this required a few more shims than I'd like... --- .github/workflows/array_api.yml | 2 +- .github/workflows/main_ci.yml | 2 +- pykokkos/__init__.py | 1 + pykokkos/interface/views.py | 48 ++- pykokkos/lib/ufunc_workunits.py | 642 +++++++++++++++++++++++++++----- pykokkos/lib/ufuncs.py | 116 ++++-- pykokkos/lib/util.py | 6 +- 7 files changed, 696 insertions(+), 121 deletions(-) diff --git a/.github/workflows/array_api.yml b/.github/workflows/array_api.yml index 72cb61e2..729c0b7b 100644 --- a/.github/workflows/array_api.yml +++ b/.github/workflows/array_api.yml @@ -30,7 +30,7 @@ jobs: cd /tmp git clone https://github.com/kokkos/pykokkos-base.git cd pykokkos-base - python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF + python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF -DENABLE_VIEW_RANKS=5 - name: Install pykokkos run: | python -m pip install . diff --git a/.github/workflows/main_ci.yml b/.github/workflows/main_ci.yml index 000af412..35326b0c 100644 --- a/.github/workflows/main_ci.yml +++ b/.github/workflows/main_ci.yml @@ -30,7 +30,7 @@ jobs: cd /tmp git clone https://github.com/kokkos/pykokkos-base.git cd pykokkos-base - python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF + python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF -DENABLE_VIEW_RANKS=5 - name: Install pykokkos run: | python -m pip install . diff --git a/pykokkos/__init__.py b/pykokkos/__init__.py index bdf4c87c..9c966187 100644 --- a/pykokkos/__init__.py +++ b/pykokkos/__init__.py @@ -73,6 +73,7 @@ from pykokkos.lib.manipulate import reshape, ravel, expand_dims from pykokkos.lib.util import all, any, sum, find_max, searchsorted, col, linspace, logspace from pykokkos.lib.constants import e, pi, inf, nan +from pykokkos.interface.views import astype __array_api_version__ = "2021.12" diff --git a/pykokkos/interface/views.py b/pykokkos/interface/views.py index 7b169874..0c25dd15 100644 --- a/pykokkos/interface/views.py +++ b/pykokkos/interface/views.py @@ -367,13 +367,43 @@ def _get_type(self, dtype: Union[DataType, type]) -> Optional[DataType]: def __eq__(self, other): - if not isinstance(other, pk.View) and self.rank() > 0: - return [i == other for i in self] - - if self.array == other: - return True + # avoid circular import with scoped import + from pykokkos.lib.ufuncs import equal + if isinstance(other, float): + new_other = pk.View((), dtype=pk.double) + new_other[:] = other + elif isinstance(other, bool): + new_other = pk.View((), dtype=pk.bool) + new_other[:] = other + elif isinstance(other, int): + if self.ndim == 0: + ret = pk.View((), dtype=pk.bool) + ret[:] = int(self) == other + return ret + if 0 <= other <= 255: + other_dtype = pk.uint8 + elif 0 <= other <= 65535: + other_dtype = pk.uint16 + elif 0 <= other <= 4294967295: + other_dtype = pk.uint32 + elif 0 <= other <= 18446744073709551615: + other_dtype = pk.uint64 + elif -128 <= other <= 127: + other_dtype = pk.int8 + elif -32768 <= other <= 32767: + other_dtype = pk.int16 + elif -2147483648 <= other <= 2147483647: + other_dtype = pk.int32 + elif -9223372036854775808 <= other <= 9223372036854775807: + other_dtype = pk.int64 + new_other = pk.View((), dtype=other_dtype) + new_other[:] = other + elif isinstance(other, pk.View): + new_other = other else: - return False + raise ValueError("unexpected types!") + return equal(self, new_other) + def __hash__(self): @@ -785,3 +815,9 @@ class ScratchView7D(ScratchView, Generic[T]): class ScratchView8D(ScratchView, Generic[T]): pass + + +def astype(view, dtype): + new_view = pk.View([*view.shape], dtype=dtype) + new_view[:] = view + return new_view diff --git a/pykokkos/lib/ufunc_workunits.py b/pykokkos/lib/ufunc_workunits.py index 60a4fffd..b4259c9c 100644 --- a/pykokkos/lib/ufunc_workunits.py +++ b/pykokkos/lib/ufunc_workunits.py @@ -89,6 +89,562 @@ def tanh_impl_2d_float(tid: int, view: pk.View2D[pk.float], out: pk.View2D[pk.fl out[tid][i] = tanh(view[tid][i]) +def equal_impl_5d_int8(tid: int, + view1: pk.View5D[pk.int8], + view2: pk.View5D[pk.int8], + out: pk.View5D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + for l in range(view1.extent(4)): + out[tid][i][j][k][l] = view1[tid][i][j][k][l] == view2[tid][i][j][k][l] + + +@pk.workunit +def equal_impl_5d_float(tid: int, + view1: pk.View5D[pk.float], + view2: pk.View5D[pk.float], + out: pk.View5D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + for l in range(view1.extent(4)): + out[tid][i][j][k][l] = view1[tid][i][j][k][l] == view2[tid][i][j][k][l] + + +@pk.workunit +def equal_impl_5d_double(tid: int, + view1: pk.View5D[pk.double], + view2: pk.View5D[pk.double], + out: pk.View5D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + for l in range(view1.extent(4)): + out[tid][i][j][k][l] = view1[tid][i][j][k][l] == view2[tid][i][j][k][l] + + +@pk.workunit +def equal_impl_5d_int16(tid: int, + view1: pk.View5D[pk.int16], + view2: pk.View5D[pk.int16], + out: pk.View5D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + for l in range(view1.extent(4)): + out[tid][i][j][k][l] = view1[tid][i][j][k][l] == view2[tid][i][j][k][l] + + +@pk.workunit +def equal_impl_5d_int32(tid: int, + view1: pk.View5D[pk.int32], + view2: pk.View5D[pk.int32], + out: pk.View5D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + for l in range(view1.extent(4)): + out[tid][i][j][k][l] = view1[tid][i][j][k][l] == view2[tid][i][j][k][l] + + +@pk.workunit +def equal_impl_5d_int64(tid: int, + view1: pk.View5D[pk.int64], + view2: pk.View5D[pk.int64], + out: pk.View5D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + for l in range(view1.extent(4)): + out[tid][i][j][k][l] = view1[tid][i][j][k][l] == view2[tid][i][j][k][l] + + +@pk.workunit +def equal_impl_5d_uint8(tid: int, + view1: pk.View5D[pk.uint8], + view2: pk.View5D[pk.uint8], + out: pk.View5D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + for l in range(view1.extent(4)): + out[tid][i][j][k][l] = view1[tid][i][j][k][l] == view2[tid][i][j][k][l] + + +@pk.workunit +def equal_impl_5d_bool(tid: int, + view1: pk.View5D[pk.uint8], + view2: pk.View5D[pk.uint8], + out: pk.View5D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + for l in range(view1.extent(4)): + out[tid][i][j][k][l] = view1[tid][i][j][k][l] == view2[tid][i][j][k][l] + + +@pk.workunit +def equal_impl_5d_uint16(tid: int, + view1: pk.View5D[pk.uint16], + view2: pk.View5D[pk.uint16], + out: pk.View5D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + for l in range(view1.extent(4)): + out[tid][i][j][k][l] = view1[tid][i][j][k][l] == view2[tid][i][j][k][l] + + +@pk.workunit +def equal_impl_5d_uint32(tid: int, + view1: pk.View5D[pk.uint32], + view2: pk.View5D[pk.uint32], + out: pk.View5D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + for l in range(view1.extent(4)): + out[tid][i][j][k][l] = view1[tid][i][j][k][l] == view2[tid][i][j][k][l] + + +@pk.workunit +def equal_impl_5d_uint64(tid: int, + view1: pk.View5D[pk.uint64], + view2: pk.View5D[pk.uint64], + out: pk.View5D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + for l in range(view1.extent(4)): + out[tid][i][j][k][l] = view1[tid][i][j][k][l] == view2[tid][i][j][k][l] + + +@pk.workunit +def equal_impl_4d_uint8(tid: int, + view1: pk.View4D[pk.uint8], + view2: pk.View4D[pk.uint8], + out: pk.View4D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + out[tid][i][j][k] = view1[tid][i][j][k] == view2[tid][i][j][k] + + +@pk.workunit +def equal_impl_4d_bool(tid: int, + view1: pk.View4D[pk.uint8], + view2: pk.View4D[pk.uint8], + out: pk.View4D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + out[tid][i][j][k] = view1[tid][i][j][k] == view2[tid][i][j][k] + + +@pk.workunit +def equal_impl_4d_float(tid: int, + view1: pk.View4D[pk.float], + view2: pk.View4D[pk.float], + out: pk.View4D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + out[tid][i][j][k] = view1[tid][i][j][k] == view2[tid][i][j][k] + + +@pk.workunit +def equal_impl_4d_double(tid: int, + view1: pk.View4D[pk.double], + view2: pk.View4D[pk.double], + out: pk.View4D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + out[tid][i][j][k] = view1[tid][i][j][k] == view2[tid][i][j][k] + + +@pk.workunit +def equal_impl_4d_uint16(tid: int, + view1: pk.View4D[pk.uint16], + view2: pk.View4D[pk.uint16], + out: pk.View4D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + out[tid][i][j][k] = view1[tid][i][j][k] == view2[tid][i][j][k] + + +@pk.workunit +def equal_impl_4d_uint32(tid: int, + view1: pk.View4D[pk.uint32], + view2: pk.View4D[pk.uint32], + out: pk.View4D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + out[tid][i][j][k] = view1[tid][i][j][k] == view2[tid][i][j][k] + + +@pk.workunit +def equal_impl_4d_uint64(tid: int, + view1: pk.View4D[pk.uint64], + view2: pk.View4D[pk.uint64], + out: pk.View4D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + out[tid][i][j][k] = view1[tid][i][j][k] == view2[tid][i][j][k] + + +@pk.workunit +def equal_impl_3d_uint8(tid: int, + view1: pk.View3D[pk.uint8], + view2: pk.View3D[pk.uint8], + out: pk.View3D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + out[tid][i][j] = view1[tid][i][j] == view2[tid][i][j] + + +@pk.workunit +def equal_impl_3d_bool(tid: int, + view1: pk.View3D[pk.uint8], + view2: pk.View3D[pk.uint8], + out: pk.View3D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + out[tid][i][j] = view1[tid][i][j] == view2[tid][i][j] + + +@pk.workunit +def equal_impl_3d_uint16(tid: int, + view1: pk.View3D[pk.uint16], + view2: pk.View3D[pk.uint16], + out: pk.View3D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + out[tid][i][j] = view1[tid][i][j] == view2[tid][i][j] + + +@pk.workunit +def equal_impl_3d_uint32(tid: int, + view1: pk.View3D[pk.uint32], + view2: pk.View3D[pk.uint32], + out: pk.View3D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + out[tid][i][j] = view1[tid][i][j] == view2[tid][i][j] + + +@pk.workunit +def equal_impl_3d_uint64(tid: int, + view1: pk.View3D[pk.uint64], + view2: pk.View3D[pk.uint64], + out: pk.View3D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + out[tid][i][j] = view1[tid][i][j] == view2[tid][i][j] + + +@pk.workunit +def equal_impl_3d_float(tid: int, + view1: pk.View3D[pk.float], + view2: pk.View3D[pk.float], + out: pk.View3D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + out[tid][i][j] = view1[tid][i][j] == view2[tid][i][j] + + +@pk.workunit +def equal_impl_3d_double(tid: int, + view1: pk.View3D[pk.double], + view2: pk.View3D[pk.double], + out: pk.View3D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + out[tid][i][j] = view1[tid][i][j] == view2[tid][i][j] + + +@pk.workunit +def equal_impl_2d_uint8(tid: int, + view1: pk.View2D[pk.uint8], + view2: pk.View2D[pk.uint8], + out: pk.View2D[pk.uint8]): + for i in range(view1.extent(1)): + out[tid][i] = view1[tid][i] == view2[tid][i] + + +@pk.workunit +def equal_impl_2d_uint16(tid: int, + view1: pk.View2D[pk.uint16], + view2: pk.View2D[pk.uint16], + out: pk.View2D[pk.uint8]): + for i in range(view1.extent(1)): + out[tid][i] = view1[tid][i] == view2[tid][i] + + +@pk.workunit +def equal_impl_2d_uint32(tid: int, + view1: pk.View2D[pk.uint32], + view2: pk.View2D[pk.uint32], + out: pk.View2D[pk.uint8]): + for i in range(view1.extent(1)): + out[tid][i] = view1[tid][i] == view2[tid][i] + + +@pk.workunit +def equal_impl_2d_uint64(tid: int, + view1: pk.View2D[pk.uint64], + view2: pk.View2D[pk.uint64], + out: pk.View2D[pk.uint8]): + for i in range(view1.extent(1)): + out[tid][i] = view1[tid][i] == view2[tid][i] + + +@pk.workunit +def equal_impl_2d_float(tid: int, + view1: pk.View2D[pk.float], + view2: pk.View2D[pk.float], + out: pk.View2D[pk.uint8]): + for i in range(view1.extent(1)): + out[tid][i] = view1[tid][i] == view2[tid][i] + + +@pk.workunit +def equal_impl_2d_double(tid: int, + view1: pk.View2D[pk.double], + view2: pk.View2D[pk.double], + out: pk.View2D[pk.uint8]): + for i in range(view1.extent(1)): + out[tid][i] = view1[tid][i] == view2[tid][i] + + +@pk.workunit +def equal_impl_1d_uint8(tid: int, + view1: pk.View1D[pk.uint8], + view2: pk.View1D[pk.uint8], + out: pk.View1D[pk.uint8]): + out[tid] = view1[tid] == view2[tid] + + +@pk.workunit +def equal_impl_1d_bool(tid: int, + view1: pk.View1D[pk.uint8], + view2: pk.View1D[pk.uint8], + out: pk.View1D[pk.uint8]): + out[tid] = view1[tid] == view2[tid] + + +@pk.workunit +def equal_impl_1d_float(tid: int, + view1: pk.View1D[pk.float], + view2: pk.View1D[pk.float], + out: pk.View1D[pk.uint8]): + out[tid] = view1[tid] == view2[tid] + + +@pk.workunit +def equal_impl_1d_double(tid: int, + view1: pk.View1D[pk.double], + view2: pk.View1D[pk.double], + out: pk.View1D[pk.uint8]): + out[tid] = view1[tid] == view2[tid] + + +@pk.workunit +def equal_impl_1d_int8(tid: int, + view1: pk.View1D[pk.int8], + view2: pk.View1D[pk.int8], + out: pk.View1D[pk.uint8]): + out[tid] = view1[tid] == view2[tid] + + +@pk.workunit +def equal_impl_1d_int16(tid: int, + view1: pk.View1D[pk.int16], + view2: pk.View1D[pk.int16], + out: pk.View1D[pk.uint8]): + out[tid] = view1[tid] == view2[tid] + + +@pk.workunit +def equal_impl_1d_int32(tid: int, + view1: pk.View1D[pk.int32], + view2: pk.View1D[pk.int32], + out: pk.View1D[pk.uint8]): + out[tid] = view1[tid] == view2[tid] + + +@pk.workunit +def equal_impl_1d_int64(tid: int, + view1: pk.View1D[pk.int64], + view2: pk.View1D[pk.int64], + out: pk.View1D[pk.uint8]): + out[tid] = view1[tid] == view2[tid] + + +@pk.workunit +def equal_impl_1d_uint16(tid: int, + view1: pk.View1D[pk.uint16], + view2: pk.View1D[pk.uint16], + out: pk.View1D[pk.uint8]): + out[tid] = view1[tid] == view2[tid] + + +@pk.workunit +def equal_impl_1d_uint32(tid: int, + view1: pk.View1D[pk.uint32], + view2: pk.View1D[pk.uint32], + out: pk.View1D[pk.uint8]): + out[tid] = view1[tid] == view2[tid] + + +@pk.workunit +def equal_impl_1d_uint64(tid: int, + view1: pk.View1D[pk.uint64], + view2: pk.View1D[pk.uint64], + out: pk.View1D[pk.uint8]): + out[tid] = view1[tid] == view2[tid] + + +@pk.workunit +def equal_impl_1d_int64(tid: int, + view1: pk.View1D[pk.int64], + view2: pk.View1D[pk.int64], + out: pk.View1D[pk.uint8]): + out[tid] = view1[tid] == view2[tid] + +@pk.workunit +def equal_impl_2d_int8(tid: int, + view1: pk.View2D[pk.int8], + view2: pk.View2D[pk.int8], + out: pk.View2D[pk.uint8]): + for i in range(view1.extent(1)): + out[tid][i] = view1[tid][i] == view2[tid][i] + + +@pk.workunit +def equal_impl_2d_bool(tid: int, + view1: pk.View2D[pk.uint8], + view2: pk.View2D[pk.uint8], + out: pk.View2D[pk.uint8]): + for i in range(view1.extent(1)): + out[tid][i] = view1[tid][i] == view2[tid][i] + + +@pk.workunit +def equal_impl_2d_int16(tid: int, + view1: pk.View2D[pk.int16], + view2: pk.View2D[pk.int16], + out: pk.View2D[pk.uint8]): + for i in range(view1.extent(1)): + out[tid][i] = view1[tid][i] == view2[tid][i] + + +@pk.workunit +def equal_impl_2d_int32(tid: int, + view1: pk.View2D[pk.int32], + view2: pk.View2D[pk.int32], + out: pk.View2D[pk.uint8]): + for i in range(view1.extent(1)): + out[tid][i] = view1[tid][i] == view2[tid][i] + + +@pk.workunit +def equal_impl_2d_int64(tid: int, + view1: pk.View2D[pk.int64], + view2: pk.View2D[pk.int64], + out: pk.View2D[pk.uint8]): + for i in range(view1.extent(1)): + out[tid][i] = view1[tid][i] == view2[tid][i] + + +@pk.workunit +def equal_impl_3d_int8(tid: int, + view1: pk.View3D[pk.int8], + view2: pk.View3D[pk.int8], + out: pk.View3D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + out[tid][i][j] = view1[tid][i][j] == view2[tid][i][j] + + +@pk.workunit +def equal_impl_3d_int16(tid: int, + view1: pk.View3D[pk.int16], + view2: pk.View3D[pk.int16], + out: pk.View3D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + out[tid][i][j] = view1[tid][i][j] == view2[tid][i][j] + + +@pk.workunit +def equal_impl_3d_int32(tid: int, + view1: pk.View3D[pk.int32], + view2: pk.View3D[pk.int32], + out: pk.View3D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + out[tid][i][j] = view1[tid][i][j] == view2[tid][i][j] + + +@pk.workunit +def equal_impl_3d_int64(tid: int, + view1: pk.View3D[pk.int64], + view2: pk.View3D[pk.int64], + out: pk.View3D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + out[tid][i][j] = view1[tid][i][j] == view2[tid][i][j] + + +@pk.workunit +def equal_impl_4d_int8(tid: int, + view1: pk.View4D[pk.int8], + view2: pk.View4D[pk.int8], + out: pk.View4D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + out[tid][i][j][k] = view1[tid][i][j][k] == view2[tid][i][j][k] + + +@pk.workunit +def equal_impl_4d_int16(tid: int, + view1: pk.View4D[pk.int16], + view2: pk.View4D[pk.int16], + out: pk.View4D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + out[tid][i][j][k] = view1[tid][i][j][k] == view2[tid][i][j][k] + + +@pk.workunit +def equal_impl_4d_int32(tid: int, + view1: pk.View4D[pk.int32], + view2: pk.View4D[pk.int32], + out: pk.View4D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + out[tid][i][j][k] = view1[tid][i][j][k] == view2[tid][i][j][k] + + +@pk.workunit +def equal_impl_4d_int64(tid: int, + view1: pk.View4D[pk.int64], + view2: pk.View4D[pk.int64], + out: pk.View4D[pk.uint8]): + for i in range(view1.extent(1)): + for j in range(view1.extent(2)): + for k in range(view1.extent(3)): + out[tid][i][j][k] = view1[tid][i][j][k] == view2[tid][i][j][k] + + @pk.workunit def floor_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.double]): out[tid] = floor(view[tid]) @@ -225,6 +781,7 @@ def round_impl_3d_float(tid: int, view: pk.View3D[pk.float], out: pk.View3D[pk.f for j in range(view.extent(2)): out[tid][i][j] = round(view[tid][i][j]) + @pk.workunit def isnan_impl_2d_double(tid: int, view: pk.View2D[pk.double], out: pk.View2D[pk.uint8]): for i in range(view.extent(1)): @@ -445,91 +1002,6 @@ def isfinite_impl_2d_uint64(tid: int, view: pk.View2D[pk.uint64], out: pk.View2D out[tid][i] = isfinite(view[tid][i]) # type: ignore -@pk.workunit -def equal_impl_1d_double(tid: int, - view1: pk.View1D[pk.double], - view2: pk.View1D[pk.double], - view2_size: int, - view_result: pk.View1D[pk.uint8]): - view2_idx: int = 0 - if view2_size == 1: - view2_idx = 0 - else: - view2_idx = tid - if view1[tid] == view2[view2_idx]: - view_result[tid] = 1 - else: - view_result[tid] = 0 - - -@pk.workunit -def equal_impl_1d_uint16(tid: int, - view1: pk.View1D[pk.uint16], - view2: pk.View1D[pk.uint16], - view2_size: int, - view_result: pk.View1D[pk.uint8]): - view2_idx: int = 0 - if view2_size == 1: - view2_idx = 0 - else: - view2_idx = tid - if view1[tid] == view2[view2_idx]: - view_result[tid] = 1 - else: - view_result[tid] = 0 - - -@pk.workunit -def equal_impl_1d_int16(tid: int, - view1: pk.View1D[pk.int16], - view2: pk.View1D[pk.int16], - view2_size: int, - view_result: pk.View1D[pk.uint8]): - view2_idx: int = 0 - if view2_size == 1: - view2_idx = 0 - else: - view2_idx = tid - if view1[tid] == view2[view2_idx]: - view_result[tid] = 1 - else: - view_result[tid] = 0 - - -@pk.workunit -def equal_impl_1d_int32(tid: int, - view1: pk.View1D[pk.int32], - view2: pk.View1D[pk.int32], - view2_size: int, - view_result: pk.View1D[pk.uint8]): - view2_idx: int = 0 - if view2_size == 1: - view2_idx = 0 - else: - view2_idx = tid - if view1[tid] == view2[view2_idx]: - view_result[tid] = 1 - else: - view_result[tid] = 0 - - -@pk.workunit -def equal_impl_1d_int64(tid: int, - view1: pk.View1D[pk.int64], - view2: pk.View1D[pk.int64], - view2_size: int, - view_result: pk.View1D[pk.uint8]): - view2_idx: int = 0 - if view2_size == 1: - view2_idx = 0 - else: - view2_idx = tid - if view1[tid] == view2[view2_idx]: - view_result[tid] = 1 - else: - view_result[tid] = 0 - - @pk.workunit def isinf_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.uint8]): out[tid] = isinf(view[tid]) diff --git a/pykokkos/lib/ufuncs.py b/pykokkos/lib/ufuncs.py index bffa6279..2bdf1a7a 100644 --- a/pykokkos/lib/ufuncs.py +++ b/pykokkos/lib/ufuncs.py @@ -42,6 +42,64 @@ def _ufunc_kernel_dispatcher(tid, return ret +def _broadcast_views(view1, view2): + # support broadcasting by using the same + # shape matching rules as NumPy + # TODO: determine if this can be done with + # more memory efficiency? + if view1.shape != view2.shape: + new_shape = np.broadcast_shapes(view1.shape, view2.shape) + view1_new = pk.View([*new_shape], dtype=view1.dtype) + view1_new[:] = view1 + view1 = view1_new + view2_new = pk.View([*new_shape], dtype=view2.dtype) + view2_new[:] = view2 + view2 = view2_new + return view1, view2 + + +def _typematch_views(view1, view2): + # very crude casting implementation + # for binary ufuncs + dtype1 = view1.dtype + dtype2 = view2.dtype + dtype_extractor = re.compile(r".*(?:data_types|DataType)\.(\w+)") + res1 = dtype_extractor.match(str(dtype1)) + res2 = dtype_extractor.match(str(dtype2)) + effective_dtype = dtype1 + if res1 is not None and res2 is not None: + res1_dtype_str = res1.group(1) + res2_dtype_str = res2.group(1) + if res1_dtype_str == "double": + res1_dtype_str = "float64" + elif res1_dtype_str == "float": + res1_dtype_str = "float32" + if res2_dtype_str == "double": + res2_dtype_str = "float64" + elif res2_dtype_str == "float": + res2_dtype_str = "float32" + if res1_dtype_str == "bool" or res2_dtype_str == "bool": + res1_dtype_str = "uint8" + dtype1 = pk.uint8 + res2_dtype_str = "uint8" + dtype2 = pk.uint8 + if (("int" in res1_dtype_str and "int" in res2_dtype_str) or + ("float" in res1_dtype_str and "float" in res2_dtype_str)): + dtype_1_width = int(res1_dtype_str.split("t")[1]) + dtype_2_width = int(res2_dtype_str.split("t")[1]) + if dtype_1_width >= dtype_2_width: + effective_dtype = dtype1 + view2_new = pk.View([*view2.shape], dtype=effective_dtype) + view2_new[:] = view2 + view2 = view2_new + else: + effective_dtype = dtype2 + view1_new = pk.View([*view1.shape], dtype=effective_dtype) + view1_new[:] = view1 + view1 = view1_new + return view1, view2, effective_dtype + + def reciprocal(view): """ Return the reciprocal of the argument, element-wise. @@ -2415,36 +2473,48 @@ def isinf(view): def equal(view1, view2): - # TODO: write even more dispatching for cases where view1 and view2 - # have different, but comparable, types (like float32 vs. float64?) - # this may "explode" without templating - - ndims = len(view2.shape) - dtype = view1.dtype - # array API suite will fail if we check view1.shape here... - if ndims > 1: - raise NotImplementedError("only 1D views currently supported for equal() ufunc.") + """ + Computes the truth value of ``view1_i`` == ``view2_i`` for each element + ``x1_i`` of the input view ``view1`` with the respective element ``x2_i`` + of the input view ``view2``. - if sum(view1.shape) == 0 or sum(view2.shape) == 0: - return np.empty(shape=(0,)) - if view1.shape != view2.shape: - if not view1.size <= 1 and not view2.size <= 1: - # TODO: supporting __eq__ over broadcasted shapes beyond - # scalar (i.e., matching number of columns) - raise ValueError("view1 and view2 have incompatible shapes") + Parameters + ---------- + view1 : pykokkos view + Input view. May have any data type. + view2 : pykokkos view + Input view. May have any data type, but must be shape-compatible + with ``view1`` via broadcasting. - view_result = pk.View([*view1.shape], dtype=pk.uint8) - _ufunc_kernel_dispatcher(tid=view1.size, - dtype=dtype, + Returns + ------- + out : pykokkos view (bool) + Output view. + """ + if view1.size == 0 and view2.size == 0: + return pk.View((), dtype=pk.bool) + view1, view2 = _broadcast_views(view1, view2) + dtype1 = view1.dtype + dtype2 = view2.dtype + view1, view2, effective_dtype = _typematch_views(view1, view2) + ndims = len(view1.shape) + if ndims > 5: + raise NotImplementedError("equal() ufunc only supports up to 5D views") + out = pk.View([*view1.shape], dtype=pk.bool) + if view1.shape == (): + tid = 1 + else: + tid = view1.shape[0] + _ufunc_kernel_dispatcher(tid=tid, + dtype=effective_dtype, ndims=ndims, op="equal", sub_dispatcher=pk.parallel_for, - view_result=view_result, + out=out, view1=view1, - view2=view2, - view2_size=view2.size) - return view_result + view2=view2) + return out def isfinite(view): diff --git a/pykokkos/lib/util.py b/pykokkos/lib/util.py index bf2a87c9..d3217461 100644 --- a/pykokkos/lib/util.py +++ b/pykokkos/lib/util.py @@ -8,10 +8,6 @@ # https://data-apis.org/array-api/2021.12/API_specification/utility_functions.html def all(x, /, *, axis=None, keepdims=False): - if x == True: - return True - elif x == False: - return False np_result = np.all(x) ret_val = pk.from_numpy(np_result) return ret_val @@ -111,4 +107,4 @@ def linspace(start, stop, num=50): def logspace(start, stop, num=50, base=10): y = linspace(start, stop, num) - return power(base, y) \ No newline at end of file + return power(base, y)