diff --git a/examples/kokkos-tutorials/workload/team_scratch_workunit.py b/examples/kokkos-tutorials/workload/team_scratch_workunit.py index 09a5c9c7..01e29931 100644 --- a/examples/kokkos-tutorials/workload/team_scratch_workunit.py +++ b/examples/kokkos-tutorials/workload/team_scratch_workunit.py @@ -71,10 +71,7 @@ def single_closure(): acc = 0 timer = pk.Timer() - # For workunits, pass M and let the C++ code compute scratch size - # Approximate scratch size: M * sizeof(double) = M * 8 bytes - # scratch_size: int = pk.ScratchView1D[float].shmem_size(M) - scratch_size: int = M * 8 + scratch_size: int = pk.ScratchView1D[float].shmem_size(M) print(f"Before: {N} | {M} | {E}") for i in range(nrepeat): @@ -90,4 +87,3 @@ def single_closure(): print( f"result: {result} | solution {solution} | result==soluton: {result==solution}" ) - print(f"Total size S = {N * M} N = {N} M = {M} E = {E}") diff --git a/pykokkos/interface/views.py b/pykokkos/interface/views.py index 787c5435..22274bd6 100644 --- a/pykokkos/interface/views.py +++ b/pykokkos/interface/views.py @@ -400,7 +400,7 @@ def _init_view( self.xp_array = cp_array else: self.xp_array = array - + else: if len(self.shape) == 0: shape = [1] @@ -488,8 +488,8 @@ def __hash__(self): def __index__(self) -> int: return int(self.data[0]) - - + + def __array__(self, dtype=None): return self.data @@ -819,12 +819,12 @@ def is_array(array) -> bool: :param array: the array of unknown type :returns: a true/false if object is an array-like struct """ - + test_attr = dir(array) if(not set(ARRAY_REQ_ATTR).issubset(set(test_attr))): return False - + for d in ARRAY_REQ_ATTR: if callable(getattr(array, d, None)): return False @@ -962,41 +962,246 @@ class View8D(Generic[T]): pass +def _get_type_size_from_generic(cls) -> int: + """ + Extract the type parameter from a parameterized generic class and return its size in bytes. + + :param cls: The parameterized generic class (e.g., ScratchView1D[float]) + :returns: The size of the type in bytes + """ + # Check if this is a parameterized generic (has __args__) + if not hasattr(cls, 'type_param') or not cls.type_param: + raise TypeError(f"Cannot determine type size for unparameterized {cls.__name__}. Use {cls.__name__}[type] format.") + + type_param = cls.type_param + + # Map Python types and DataTypeClass to numpy dtypes + # Check DataTypeClass first (like pk.float, pk.double) since they are classes + if isinstance(type_param, type) and issubclass(type_param, DataTypeClass): + # Get the numpy equivalent from the DataTypeClass + np_dtype_class = type_param.np_equiv + if np_dtype_class is None: + raise TypeError(f"Cannot determine size for type {type_param.__name__}") + # Convert numpy dtype class to dtype instance + dtype = np.dtype(np_dtype_class) + elif type_param is int: + dtype = np.dtype(np.int32) + elif type_param is float: + dtype = np.dtype(np.float64) + elif isinstance(type_param, DataType): + # Handle DataType enum values + if type_param is DataType.real: + default_prec = km.get_default_precision() + if default_prec is float32: + dtype = np.dtype(np.float32) + else: + dtype = np.dtype(np.float64) + elif type_param in {DataType.float, DataType.float32}: + dtype = np.dtype(np.float32) + elif type_param in {DataType.double, DataType.float64}: + dtype = np.dtype(np.float64) + elif type_param is DataType.int8: + dtype = np.dtype(np.int8) + elif type_param is DataType.int16: + dtype = np.dtype(np.int16) + elif type_param is DataType.int32: + dtype = np.dtype(np.int32) + elif type_param is DataType.int64: + dtype = np.dtype(np.int64) + elif type_param is DataType.uint8: + dtype = np.dtype(np.uint8) + elif type_param is DataType.uint16: + dtype = np.dtype(np.uint16) + elif type_param is DataType.uint32: + dtype = np.dtype(np.uint32) + elif type_param is DataType.uint64: + dtype = np.dtype(np.uint64) + elif type_param is DataType.complex64: + dtype = np.dtype(np.complex64) + elif type_param is DataType.complex128: + dtype = np.dtype(np.complex128) + else: + raise TypeError(f"Unsupported DataType: {type_param}") + else: + # Try to use the type directly as a numpy dtype + try: + dtype = np.dtype(type_param) + except (TypeError, ValueError): + raise TypeError(f"Unsupported type for scratch view: {type_param}") + + # Get itemsize as an int + type_size = int(dtype.itemsize) + return type_size + + +def _calculate_scratch_size(type_size: int, *dims: int, alignment: int = 8) -> int: + """ + Calculate scratch memory size for a scratch view. + + :param type_size: Size of the element type in bytes + :param dims: Dimensions of the scratch view + :param alignment: Alignment requirement (default 8 bytes, matching Kokkos) + :returns: Total scratch memory size in bytes + """ + # Calculate total number of elements + total_elements = 1 + for dim in dims: + total_elements *= dim + + # Calculate raw size + raw_size = total_elements * type_size + + # Align to the specified alignment (typically 8 bytes for Kokkos) + aligned_size = ((raw_size + alignment - 1) // alignment) * alignment + + return aligned_size + + class ScratchView: - def shmem_size(i: int): - pass + def __class_getitem__(cls, item): + generic_alias = super().__class_getitem__(item) + generic_alias.type_param = item + return generic_alias + + @staticmethod + def shmem_size(i: int) -> int: + """ + Calculate shared memory size for a scratch view. + This is a base implementation that should be overridden by specific view types. + """ + raise NotImplementedError("shmem_size must be implemented by specific ScratchView types") class ScratchView1D(ScratchView, Generic[T]): - pass + @classmethod + def shmem_size(cls, dim0: int) -> int: + """ + Calculate shared memory size for a 1D scratch view. + + :param dim0: Size of the first dimension + :returns: Total scratch memory size in bytes + """ + type_size = _get_type_size_from_generic(cls) + return _calculate_scratch_size(type_size, dim0) class ScratchView2D(ScratchView, Generic[T]): - pass + @classmethod + def shmem_size(cls, dim0: int, dim1: int) -> int: + """ + Calculate shared memory size for a 2D scratch view. + + :param dim0: Size of the first dimension + :param dim1: Size of the second dimension + :returns: Total scratch memory size in bytes + """ + type_size = _get_type_size_from_generic(cls) + return _calculate_scratch_size(type_size, dim0, dim1) class ScratchView3D(ScratchView, Generic[T]): - pass + @classmethod + def shmem_size(cls, dim0: int, dim1: int, dim2: int) -> int: + """ + Calculate shared memory size for a 3D scratch view. + + :param dim0: Size of the first dimension + :param dim1: Size of the second dimension + :param dim2: Size of the third dimension + :returns: Total scratch memory size in bytes + """ + type_size = _get_type_size_from_generic(cls) + return _calculate_scratch_size(type_size, dim0, dim1, dim2) class ScratchView4D(ScratchView, Generic[T]): - pass + @classmethod + def shmem_size(cls, dim0: int, dim1: int, dim2: int, dim3: int) -> int: + """ + Calculate shared memory size for a 4D scratch view. + + :param dim0: Size of the first dimension + :param dim1: Size of the second dimension + :param dim2: Size of the third dimension + :param dim3: Size of the fourth dimension + :returns: Total scratch memory size in bytes + """ + type_size = _get_type_size_from_generic(cls) + return _calculate_scratch_size(type_size, dim0, dim1, dim2, dim3) class ScratchView5D(ScratchView, Generic[T]): - pass + @classmethod + def shmem_size(cls, dim0: int, dim1: int, dim2: int, dim3: int, dim4: int) -> int: + """ + Calculate shared memory size for a 5D scratch view. + + :param dim0: Size of the first dimension + :param dim1: Size of the second dimension + :param dim2: Size of the third dimension + :param dim3: Size of the fourth dimension + :param dim4: Size of the fifth dimension + :returns: Total scratch memory size in bytes + """ + type_size = _get_type_size_from_generic(cls) + return _calculate_scratch_size(type_size, dim0, dim1, dim2, dim3, dim4) class ScratchView6D(ScratchView, Generic[T]): - pass + @classmethod + def shmem_size(cls, dim0: int, dim1: int, dim2: int, dim3: int, dim4: int, dim5: int) -> int: + """ + Calculate shared memory size for a 6D scratch view. + + :param dim0: Size of the first dimension + :param dim1: Size of the second dimension + :param dim2: Size of the third dimension + :param dim3: Size of the fourth dimension + :param dim4: Size of the fifth dimension + :param dim5: Size of the sixth dimension + :returns: Total scratch memory size in bytes + """ + type_size = _get_type_size_from_generic(cls) + return _calculate_scratch_size(type_size, dim0, dim1, dim2, dim3, dim4, dim5) class ScratchView7D(ScratchView, Generic[T]): - pass + @classmethod + def shmem_size(cls, dim0: int, dim1: int, dim2: int, dim3: int, dim4: int, dim5: int, dim6: int) -> int: + """ + Calculate shared memory size for a 7D scratch view. + + :param dim0: Size of the first dimension + :param dim1: Size of the second dimension + :param dim2: Size of the third dimension + :param dim3: Size of the fourth dimension + :param dim4: Size of the fifth dimension + :param dim5: Size of the sixth dimension + :param dim6: Size of the seventh dimension + :returns: Total scratch memory size in bytes + """ + type_size = _get_type_size_from_generic(cls) + return _calculate_scratch_size(type_size, dim0, dim1, dim2, dim3, dim4, dim5, dim6) class ScratchView8D(ScratchView, Generic[T]): - pass + @classmethod + def shmem_size(cls, dim0: int, dim1: int, dim2: int, dim3: int, dim4: int, dim5: int, dim6: int, dim7: int) -> int: + """ + Calculate shared memory size for an 8D scratch view. + + :param dim0: Size of the first dimension + :param dim1: Size of the second dimension + :param dim2: Size of the third dimension + :param dim3: Size of the fourth dimension + :param dim4: Size of the fifth dimension + :param dim5: Size of the sixth dimension + :param dim6: Size of the seventh dimension + :param dim7: Size of the eighth dimension + :returns: Total scratch memory size in bytes + """ + type_size = _get_type_size_from_generic(cls) + return _calculate_scratch_size(type_size, dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7) def astype(view, dtype): diff --git a/tests/test_scratch_size.py b/tests/test_scratch_size.py index b4b16100..fb5a22a4 100644 --- a/tests/test_scratch_size.py +++ b/tests/test_scratch_size.py @@ -1,4 +1,5 @@ import unittest +import numpy as np import pykokkos as pk @@ -174,5 +175,278 @@ def test_scratch_size_multiple_iterations(self): self.assertAlmostEqual(expected_result, result, places=5) +class TestScratchViewShmemSizeFail(unittest.TestCase): + """ + Tests that scratch view size selection fails when the type is not specified. + """ + + def test_scratch_view_unparameterized_error(self): + """Test that unparameterized ScratchView raises error""" + try: + result = pk.ScratchView1D.shmem_size(10) + print(f"No exception raised. Result: {result}") + except Exception as e: + print(f"Exception type: {type(e).__name__}") + print(f"Exception message: {e}") + + with self.assertRaises(TypeError): + pk.ScratchView1D.shmem_size(10) + + +class TestScratchViewShmemSize(unittest.TestCase): + """ + Unit tests for ScratchView shmem_size methods. + Tests correctness of return values for all dimensions and types. + """ + + def _calculate_expected_size(self, dtype_or_size, *dims: int, alignment: int = 8) -> int: + """ + Helper to calculate expected scratch size with alignment. + + :param dtype_or_size: Either a numpy dtype class (e.g., np.float64) or an int size in bytes + :param dims: Dimensions of the scratch view + :param alignment: Alignment requirement (default 8 bytes) + :returns: Total scratch memory size in bytes + """ + if isinstance(dtype_or_size, int): + type_size = dtype_or_size + else: + # Assume it's a numpy dtype class, convert to dtype instance and get itemsize + type_size = int(np.dtype(dtype_or_size).itemsize) + + total_elements = 1 + for dim in dims: + total_elements *= dim + raw_size = total_elements * type_size + aligned_size = ((raw_size + alignment - 1) // alignment) * alignment + return aligned_size + + # Test ScratchView1D + def test_scratch_view_1d_float(self): + """Test ScratchView1D[float].shmem_size""" + dim = 10 + result = pk.ScratchView1D[float].shmem_size(dim) + expected = self._calculate_expected_size(np.float64, dim) # float is float64 + self.assertEqual(result, expected) + + def test_scratch_view_1d_pk_float(self): + """Test ScratchView1D[pk.float].shmem_size""" + dim = 10 + result = pk.ScratchView1D[pk.float].shmem_size(dim) + expected = self._calculate_expected_size(np.float32, dim) # pk.float is float32 + self.assertEqual(result, expected) + + def test_scratch_view_1d_pk_double(self): + """Test ScratchView1D[pk.double].shmem_size""" + dim = 10 + result = pk.ScratchView1D[pk.double].shmem_size(dim) + expected = self._calculate_expected_size(np.float64, dim) # pk.double is float64 + self.assertEqual(result, expected) + + def test_scratch_view_1d_int(self): + """Test ScratchView1D[int].shmem_size""" + dim = 10 + result = pk.ScratchView1D[int].shmem_size(dim) + expected = self._calculate_expected_size(np.int32, dim) # int is int32 + self.assertEqual(result, expected) + + def test_scratch_view_1d_pk_int32(self): + """Test ScratchView1D[pk.int32].shmem_size""" + dim = 10 + result = pk.ScratchView1D[pk.int32].shmem_size(dim) + expected = self._calculate_expected_size(np.int32, dim) # int32 + self.assertEqual(result, expected) + + def test_scratch_view_1d_pk_int64(self): + """Test ScratchView1D[pk.int64].shmem_size""" + dim = 10 + result = pk.ScratchView1D[pk.int64].shmem_size(dim) + expected = self._calculate_expected_size(np.int64, dim) # int64 + self.assertEqual(result, expected) + + def test_scratch_view_1d_pk_uint8(self): + """Test ScratchView1D[pk.uint8].shmem_size""" + dim = 10 + result = pk.ScratchView1D[pk.uint8].shmem_size(dim) + expected = self._calculate_expected_size(np.uint8, dim) # uint8 + self.assertEqual(result, expected) + + def test_scratch_view_1d_alignment(self): + """Test that ScratchView1D properly aligns to 8 bytes""" + # 3 elements of float32 (4 bytes) = 12 bytes, should align to 16 bytes + dim = 3 + result = pk.ScratchView1D[pk.float].shmem_size(dim) + expected = self._calculate_expected_size(np.float32, dim) # 12 bytes -> 16 bytes aligned + self.assertEqual(result, expected) + # Explicit check: 3 * 4 = 12, aligned to 8 = 16 + self.assertEqual(result, 16) + + # Test ScratchView2D + def test_scratch_view_2d_float(self): + """Test ScratchView2D[float].shmem_size""" + dim0, dim1 = 5, 10 + result = pk.ScratchView2D[float].shmem_size(dim0, dim1) + expected = self._calculate_expected_size(np.float64, dim0, dim1) + self.assertEqual(result, expected) + + def test_scratch_view_2d_pk_float(self): + """Test ScratchView2D[pk.float].shmem_size""" + dim0, dim1 = 5, 10 + result = pk.ScratchView2D[pk.float].shmem_size(dim0, dim1) + expected = self._calculate_expected_size(np.float32, dim0, dim1) + self.assertEqual(result, expected) + + def test_scratch_view_2d_alignment(self): + """Test that ScratchView2D properly aligns""" + # 3*3*4 = 36 bytes, should align to 40 bytes + dim0, dim1 = 3, 3 + result = pk.ScratchView2D[pk.float].shmem_size(dim0, dim1) + expected = self._calculate_expected_size(np.float32, dim0, dim1) + self.assertEqual(result, expected) + # Explicit check: 3*3*4 = 36, aligned to 8 = 40 + self.assertEqual(result, 40) + + # Test ScratchView3D + def test_scratch_view_3d_float(self): + """Test ScratchView3D[float].shmem_size""" + dim0, dim1, dim2 = 2, 3, 4 + result = pk.ScratchView3D[float].shmem_size(dim0, dim1, dim2) + expected = self._calculate_expected_size(np.float64, dim0, dim1, dim2) + self.assertEqual(result, expected) + + def test_scratch_view_3d_pk_double(self): + """Test ScratchView3D[pk.double].shmem_size""" + dim0, dim1, dim2 = 2, 3, 4 + result = pk.ScratchView3D[pk.double].shmem_size(dim0, dim1, dim2) + expected = self._calculate_expected_size(np.float64, dim0, dim1, dim2) + self.assertEqual(result, expected) + + # Test ScratchView4D + def test_scratch_view_4d_float(self): + """Test ScratchView4D[float].shmem_size""" + dim0, dim1, dim2, dim3 = 2, 2, 2, 2 + result = pk.ScratchView4D[float].shmem_size(dim0, dim1, dim2, dim3) + expected = self._calculate_expected_size(np.float64, dim0, dim1, dim2, dim3) + self.assertEqual(result, expected) + + def test_scratch_view_4d_pk_int32(self): + """Test ScratchView4D[pk.int32].shmem_size""" + dim0, dim1, dim2, dim3 = 2, 2, 2, 2 + result = pk.ScratchView4D[pk.int32].shmem_size(dim0, dim1, dim2, dim3) + expected = self._calculate_expected_size(np.int32, dim0, dim1, dim2, dim3) + self.assertEqual(result, expected) + + # Test ScratchView5D + def test_scratch_view_5d_float(self): + """Test ScratchView5D[float].shmem_size""" + dim0, dim1, dim2, dim3, dim4 = 2, 2, 2, 2, 2 + result = pk.ScratchView5D[float].shmem_size(dim0, dim1, dim2, dim3, dim4) + expected = self._calculate_expected_size(np.float64, dim0, dim1, dim2, dim3, dim4) + self.assertEqual(result, expected) + + # Test ScratchView6D + def test_scratch_view_6d_float(self): + """Test ScratchView6D[float].shmem_size""" + dim0, dim1, dim2, dim3, dim4, dim5 = 2, 2, 2, 2, 2, 2 + result = pk.ScratchView6D[float].shmem_size(dim0, dim1, dim2, dim3, dim4, dim5) + expected = self._calculate_expected_size(np.float64, dim0, dim1, dim2, dim3, dim4, dim5) + self.assertEqual(result, expected) + + # Test ScratchView7D + def test_scratch_view_7d_float(self): + """Test ScratchView7D[float].shmem_size""" + dim0, dim1, dim2, dim3, dim4, dim5, dim6 = 2, 2, 2, 2, 2, 2, 2 + result = pk.ScratchView7D[float].shmem_size(dim0, dim1, dim2, dim3, dim4, dim5, dim6) + expected = self._calculate_expected_size(np.float64, dim0, dim1, dim2, dim3, dim4, dim5, dim6) + self.assertEqual(result, expected) + + # Test ScratchView8D + def test_scratch_view_8d_float(self): + """Test ScratchView8D[float].shmem_size""" + dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7 = 2, 2, 2, 2, 2, 2, 2, 2 + result = pk.ScratchView8D[float].shmem_size(dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7) + expected = self._calculate_expected_size(np.float64, dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7) + self.assertEqual(result, expected) + + # Test edge cases + def test_scratch_view_1d_small_size(self): + """Test ScratchView1D with small size""" + dim = 1 + result = pk.ScratchView1D[float].shmem_size(dim) + expected = self._calculate_expected_size(np.float64, dim) + self.assertEqual(result, expected) + # Explicit check: 1 * 8 = 8, already aligned + self.assertEqual(result, int(np.dtype(np.float64).itemsize)) + + def test_scratch_view_1d_large_size(self): + """Test ScratchView1D with large size""" + dim = 1000 + result = pk.ScratchView1D[float].shmem_size(dim) + expected = self._calculate_expected_size(np.float64, dim) + self.assertEqual(result, expected) + # Explicit check: 1000 * 8 = 8000, already aligned + self.assertEqual(result, 1000 * int(np.dtype(np.float64).itemsize)) + + def test_scratch_view_1d_odd_alignment(self): + """Test ScratchView1D with size requiring alignment""" + # 5 elements of float32 = 20 bytes, should align to 24 bytes + dim = 5 + result = pk.ScratchView1D[pk.float].shmem_size(dim) + expected = self._calculate_expected_size(np.float32, dim) + self.assertEqual(result, expected) + # Explicit check: 5 * 4 = 20, aligned to 8 = 24 + self.assertEqual(result, 24) + + # Test different integer types + def test_scratch_view_1d_all_int_types(self): + """Test ScratchView1D with all integer types""" + dim = 10 + test_cases = [ + (pk.int8, np.int8), + (pk.int16, np.int16), + (pk.int32, np.int32), + (pk.int64, np.int64), + (pk.uint8, np.uint8), + (pk.uint16, np.uint16), + (pk.uint32, np.uint32), + (pk.uint64, np.uint64), + ] + for type_class, np_dtype in test_cases: + with self.subTest(type_class=type_class): + result = pk.ScratchView1D[type_class].shmem_size(dim) + expected = self._calculate_expected_size(np_dtype, dim) + self.assertEqual(result, expected) + + def test_scratch_view_2d_various_types(self): + """Test ScratchView2D with various types""" + dim0, dim1 = 4, 4 + test_cases = [ + (float, np.float64), + (pk.float, np.float32), + (pk.double, np.float64), + (int, np.int32), + (pk.int32, np.int32), + ] + for type_class, np_dtype in test_cases: + with self.subTest(type_class=type_class): + result = pk.ScratchView2D[type_class].shmem_size(dim0, dim1) + expected = self._calculate_expected_size(np_dtype, dim0, dim1) + self.assertEqual(result, expected) + + def test_scratch_view_3d_various_types(self): + """Test ScratchView3D with various types""" + dim0, dim1, dim2 = 3, 3, 3 + test_cases = [ + (float, np.float64), + (pk.float, np.float32), + (pk.int32, np.int32), + ] + for type_class, np_dtype in test_cases: + with self.subTest(type_class=type_class): + result = pk.ScratchView3D[type_class].shmem_size(dim0, dim1, dim2) + expected = self._calculate_expected_size(np_dtype, dim0, dim1, dim2) + self.assertEqual(result, expected) + + if __name__ == "__main__": unittest.main()