Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions examples/kokkos-tutorials/workload/team_scratch_workunit.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,7 @@ def single_closure():
acc = 0

timer = pk.Timer()
# For workunits, pass M and let the C++ code compute scratch size
# Approximate scratch size: M * sizeof(double) = M * 8 bytes
# scratch_size: int = pk.ScratchView1D[float].shmem_size(M)
scratch_size: int = M * 8
scratch_size: int = pk.ScratchView1D[float].shmem_size(M)
print(f"Before: {N} | {M} | {E}")

for i in range(nrepeat):
Expand All @@ -90,4 +87,3 @@ def single_closure():
print(
f"result: {result} | solution {solution} | result==soluton: {result==solution}"
)
print(f"Total size S = {N * M} N = {N} M = {M} E = {E}")
235 changes: 220 additions & 15 deletions pykokkos/interface/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def _init_view(
self.xp_array = cp_array
else:
self.xp_array = array

else:
if len(self.shape) == 0:
shape = [1]
Expand Down Expand Up @@ -488,8 +488,8 @@ def __hash__(self):

def __index__(self) -> int:
return int(self.data[0])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we formatting code?
Can we please check formatting of all files in our CI.

Copy link
Collaborator Author

@IvanGrigorik IvanGrigorik Dec 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a bit problem with file formatting in PyKokkos
I will open another PR in which I will format all pykokkos file in a single black stype


def __array__(self, dtype=None):
return self.data

Expand Down Expand Up @@ -819,12 +819,12 @@ def is_array(array) -> bool:
:param array: the array of unknown type
:returns: a true/false if object is an array-like struct
"""

test_attr = dir(array)

if(not set(ARRAY_REQ_ATTR).issubset(set(test_attr))):
return False

for d in ARRAY_REQ_ATTR:
if callable(getattr(array, d, None)):
return False
Expand Down Expand Up @@ -962,41 +962,246 @@ class View8D(Generic[T]):
pass


def _get_type_size_from_generic(cls) -> int:
"""
Extract the type parameter from a parameterized generic class and return its size in bytes.

:param cls: The parameterized generic class (e.g., ScratchView1D[float])
:returns: The size of the type in bytes
"""
# Check if this is a parameterized generic (has __args__)
if not hasattr(cls, 'type_param') or not cls.type_param:
raise TypeError(f"Cannot determine type size for unparameterized {cls.__name__}. Use {cls.__name__}[type] format.")

type_param = cls.type_param

# Map Python types and DataTypeClass to numpy dtypes
# Check DataTypeClass first (like pk.float, pk.double) since they are classes
if isinstance(type_param, type) and issubclass(type_param, DataTypeClass):
# Get the numpy equivalent from the DataTypeClass
np_dtype_class = type_param.np_equiv
if np_dtype_class is None:
raise TypeError(f"Cannot determine size for type {type_param.__name__}")
# Convert numpy dtype class to dtype instance
dtype = np.dtype(np_dtype_class)
elif type_param is int:
dtype = np.dtype(np.int32)
elif type_param is float:
dtype = np.dtype(np.float64)
elif isinstance(type_param, DataType):
# Handle DataType enum values
if type_param is DataType.real:
default_prec = km.get_default_precision()
if default_prec is float32:
dtype = np.dtype(np.float32)
else:
dtype = np.dtype(np.float64)
elif type_param in {DataType.float, DataType.float32}:
dtype = np.dtype(np.float32)
elif type_param in {DataType.double, DataType.float64}:
dtype = np.dtype(np.float64)
elif type_param is DataType.int8:
dtype = np.dtype(np.int8)
elif type_param is DataType.int16:
dtype = np.dtype(np.int16)
elif type_param is DataType.int32:
dtype = np.dtype(np.int32)
elif type_param is DataType.int64:
dtype = np.dtype(np.int64)
elif type_param is DataType.uint8:
dtype = np.dtype(np.uint8)
elif type_param is DataType.uint16:
dtype = np.dtype(np.uint16)
elif type_param is DataType.uint32:
dtype = np.dtype(np.uint32)
elif type_param is DataType.uint64:
dtype = np.dtype(np.uint64)
elif type_param is DataType.complex64:
dtype = np.dtype(np.complex64)
elif type_param is DataType.complex128:
dtype = np.dtype(np.complex128)
else:
raise TypeError(f"Unsupported DataType: {type_param}")
else:
# Try to use the type directly as a numpy dtype
try:
dtype = np.dtype(type_param)
except (TypeError, ValueError):
raise TypeError(f"Unsupported type for scratch view: {type_param}")

# Get itemsize as an int
type_size = int(dtype.itemsize)
return type_size


def _calculate_scratch_size(type_size: int, *dims: int, alignment: int = 8) -> int:
"""
Calculate scratch memory size for a scratch view.

:param type_size: Size of the element type in bytes
:param dims: Dimensions of the scratch view
:param alignment: Alignment requirement (default 8 bytes, matching Kokkos)
:returns: Total scratch memory size in bytes
"""
# Calculate total number of elements
total_elements = 1
for dim in dims:
total_elements *= dim

# Calculate raw size
raw_size = total_elements * type_size

# Align to the specified alignment (typically 8 bytes for Kokkos)
aligned_size = ((raw_size + alignment - 1) // alignment) * alignment

return aligned_size


class ScratchView:
def shmem_size(i: int):
pass
def __class_getitem__(cls, item):
generic_alias = super().__class_getitem__(item)
generic_alias.type_param = item
return generic_alias

@staticmethod
def shmem_size(i: int) -> int:
"""
Calculate shared memory size for a scratch view.
This is a base implementation that should be overridden by specific view types.
"""
raise NotImplementedError("shmem_size must be implemented by specific ScratchView types")


class ScratchView1D(ScratchView, Generic[T]):
pass
@classmethod
def shmem_size(cls, dim0: int) -> int:
"""
Calculate shared memory size for a 1D scratch view.

:param dim0: Size of the first dimension
:returns: Total scratch memory size in bytes
"""
type_size = _get_type_size_from_generic(cls)
return _calculate_scratch_size(type_size, dim0)


class ScratchView2D(ScratchView, Generic[T]):
pass
@classmethod
def shmem_size(cls, dim0: int, dim1: int) -> int:
"""
Calculate shared memory size for a 2D scratch view.

:param dim0: Size of the first dimension
:param dim1: Size of the second dimension
:returns: Total scratch memory size in bytes
"""
type_size = _get_type_size_from_generic(cls)
return _calculate_scratch_size(type_size, dim0, dim1)


class ScratchView3D(ScratchView, Generic[T]):
pass
@classmethod
def shmem_size(cls, dim0: int, dim1: int, dim2: int) -> int:
"""
Calculate shared memory size for a 3D scratch view.

:param dim0: Size of the first dimension
:param dim1: Size of the second dimension
:param dim2: Size of the third dimension
:returns: Total scratch memory size in bytes
"""
type_size = _get_type_size_from_generic(cls)
return _calculate_scratch_size(type_size, dim0, dim1, dim2)


class ScratchView4D(ScratchView, Generic[T]):
pass
@classmethod
def shmem_size(cls, dim0: int, dim1: int, dim2: int, dim3: int) -> int:
"""
Calculate shared memory size for a 4D scratch view.

:param dim0: Size of the first dimension
:param dim1: Size of the second dimension
:param dim2: Size of the third dimension
:param dim3: Size of the fourth dimension
:returns: Total scratch memory size in bytes
"""
type_size = _get_type_size_from_generic(cls)
return _calculate_scratch_size(type_size, dim0, dim1, dim2, dim3)


class ScratchView5D(ScratchView, Generic[T]):
pass
@classmethod
def shmem_size(cls, dim0: int, dim1: int, dim2: int, dim3: int, dim4: int) -> int:
"""
Calculate shared memory size for a 5D scratch view.

:param dim0: Size of the first dimension
:param dim1: Size of the second dimension
:param dim2: Size of the third dimension
:param dim3: Size of the fourth dimension
:param dim4: Size of the fifth dimension
:returns: Total scratch memory size in bytes
"""
type_size = _get_type_size_from_generic(cls)
return _calculate_scratch_size(type_size, dim0, dim1, dim2, dim3, dim4)


class ScratchView6D(ScratchView, Generic[T]):
pass
@classmethod
def shmem_size(cls, dim0: int, dim1: int, dim2: int, dim3: int, dim4: int, dim5: int) -> int:
"""
Calculate shared memory size for a 6D scratch view.

:param dim0: Size of the first dimension
:param dim1: Size of the second dimension
:param dim2: Size of the third dimension
:param dim3: Size of the fourth dimension
:param dim4: Size of the fifth dimension
:param dim5: Size of the sixth dimension
:returns: Total scratch memory size in bytes
"""
type_size = _get_type_size_from_generic(cls)
return _calculate_scratch_size(type_size, dim0, dim1, dim2, dim3, dim4, dim5)


class ScratchView7D(ScratchView, Generic[T]):
pass
@classmethod
def shmem_size(cls, dim0: int, dim1: int, dim2: int, dim3: int, dim4: int, dim5: int, dim6: int) -> int:
"""
Calculate shared memory size for a 7D scratch view.

:param dim0: Size of the first dimension
:param dim1: Size of the second dimension
:param dim2: Size of the third dimension
:param dim3: Size of the fourth dimension
:param dim4: Size of the fifth dimension
:param dim5: Size of the sixth dimension
:param dim6: Size of the seventh dimension
:returns: Total scratch memory size in bytes
"""
type_size = _get_type_size_from_generic(cls)
return _calculate_scratch_size(type_size, dim0, dim1, dim2, dim3, dim4, dim5, dim6)


class ScratchView8D(ScratchView, Generic[T]):
pass
@classmethod
def shmem_size(cls, dim0: int, dim1: int, dim2: int, dim3: int, dim4: int, dim5: int, dim6: int, dim7: int) -> int:
"""
Calculate shared memory size for an 8D scratch view.

:param dim0: Size of the first dimension
:param dim1: Size of the second dimension
:param dim2: Size of the third dimension
:param dim3: Size of the fourth dimension
:param dim4: Size of the fifth dimension
:param dim5: Size of the sixth dimension
:param dim6: Size of the seventh dimension
:param dim7: Size of the eighth dimension
:returns: Total scratch memory size in bytes
"""
type_size = _get_type_size_from_generic(cls)
return _calculate_scratch_size(type_size, dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7)


def astype(view, dtype):
Expand Down
Loading