Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions examples/kokkos/inclusive_scan_team.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pykokkos as pk


# Fill the view with the 1-based sequence: element i receives the value i + 1.
@pk.workunit
def init_data(i: int, view: pk.View1D[int]):
    value: int = i + 1
    view[i] = value


# Team-level inclusive scan: every team copies its own team_size-long slice of
# `view` into a scratch view, scans the scratch view, and writes the scanned
# values back to the same positions of `view`.
#
# NOTE(review): there is no bounds check, so global_idx is only valid when the
# length of `view` is a multiple of the team size.
@pk.workunit
def team_scan(team_member: pk.TeamMember, view: pk.View1D[int]):
    team_size: int = team_member.team_size()
    team_rank: int = team_member.team_rank()
    # A team's slice starts at league_rank * team_size; this thread handles
    # the element at its own rank inside that slice.
    global_idx: int = team_member.league_rank() * team_size + team_rank

    # Scratch view with team_size entries, obtained from team_scratch(0).
    scratch: pk.ScratchView1D[int] = pk.ScratchView1D(
        team_member.team_scratch(0), team_size
    )

    # Stage this thread's element, then wait for the rest of the team.
    scratch[team_rank] = view[global_idx]
    team_member.team_barrier()

    # Scan the scratch view; team_barrier() must be called before the scanned
    # data is read back.
    pk.inclusive_scan(team_member, scratch)
    team_member.team_barrier()

    view[global_idx] = scratch[team_rank]


def main():
    """
    Run the team-level inclusive scan example on the OpenMP execution space.

    Fills a view with 1..N, launches ``team_scan`` with one team per
    ``team_size`` elements, and prints the resulting view.

    :raises ValueError: if N is not a multiple of team_size
    """
    N = 64
    team_size = 32
    # num_teams is rounded up below, but team_scan has no bounds check, so a
    # partially filled last team would index past the end of the view.
    if N % team_size != 0:
        raise ValueError(
            f"N ({N}) must be a multiple of team_size ({team_size})"
        )
    num_teams = (N + team_size - 1) // team_size

    view: pk.View1D[int] = pk.View([N], int)
    p_init = pk.RangePolicy(pk.ExecutionSpace.OpenMP, 0, N)
    pk.parallel_for(p_init, init_data, view=view)

    print(f"Total elements: {N}, Team size: {team_size}, Number of teams: {num_teams}")

    # Use TeamPolicy
    team_policy = pk.TeamPolicy(pk.ExecutionSpace.OpenMP, num_teams, team_size)

    # The scratch-size calls below are kept for reference only: they are not
    # usable yet because they are not implemented correctly.
    # TODO: implement scratch size setting
    # scratch_size = pk.ScratchView1D[int].shmem_size(team_size)
    # team_policy.set_scratch_size(0, pk.PerTeam(scratch_size))

    # Kernel call - each team scans its own slice of the view
    print("Running kernel...")
    pk.parallel_for(team_policy, team_scan, view=view)
    print(f"View, splitted by two groups of size = {team_size}")
    print(view)


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions pykokkos/core/translators/static.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified to include inclusive_scan

Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def generate_includes(self) -> str:
"Kokkos_Core.hpp",
"Kokkos_Random.hpp",
"Kokkos_Sort.hpp",
"Kokkos_StdAlgorithms.hpp",
"fstream",
"iostream",
"cmath",
Expand All @@ -290,6 +291,7 @@ def generate_cast_includes(self) -> str:
"Kokkos_Core.hpp",
"Kokkos_Random.hpp",
"Kokkos_Sort.hpp",
"Kokkos_StdAlgorithms.hpp",
"fstream",
"iostream",
"cmath",
Expand Down
2 changes: 1 addition & 1 deletion pykokkos/core/translators/symbols_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, members: PyKokkosMembers, pk_import: str, path: str):
self.global_symbols.update(math_functions)
self.global_symbols.update(allowed_types)
self.global_symbols.update(view_dtypes)
self.global_symbols.update(["self", "range", "math", "List", "abs"])
self.global_symbols.update(["self", "range", "math", "List", "abs", "inclusive_scan"])
self.global_symbols.add(pk_import)

self.global_symbols.update([field.declname for field in members.fields])
Expand Down
45 changes: 45 additions & 0 deletions pykokkos/core/visitors/workunit_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,51 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:

return real_number_call

# Kokkos `inclusive_scan`
if name == "inclusive_scan":
# Check if it's called via pk.inclusive_scan
is_pk_call = (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == self.pk_import
)

if not is_pk_call:
return super().visit_Call(node)

# Expected signature: pk.inclusive_scan(team_member, view, length)
# or: pk.inclusive_scan(team_member, view) - uses view size
if len(args) < 2 or len(args) > 3:
self.error(
node,
"pk.inclusive_scan() takes 2 or 3 arguments: team_member, view, [length]",
)

team_member_expr = args[0]
view_expr = args[1]

# Create Kokkos::Experimental::begin() and end() calls
begin_function = cppast.DeclRefExpr("Kokkos::Experimental::begin")
end_function = cppast.DeclRefExpr("Kokkos::Experimental::end")
view_begin = cppast.CallExpr(begin_function, [view_expr])

if len(args) == 3:
# Use provided length: begin + length
length_expr = args[2]
view_begin_for_end = cppast.CallExpr(begin_function, [view_expr])
view_end = cppast.BinaryOperator(
view_begin_for_end, length_expr, cppast.BinaryOperatorKind.Add
)
else:
# Use end() when no length is provided
view_end = cppast.CallExpr(end_function, [view_expr])

# Create Kokkos::Experimental::inclusive_scan call
function = cppast.DeclRefExpr("Kokkos::Experimental::inclusive_scan")
scan_args = [team_member_expr, view_begin, view_end, view_begin]

return cppast.CallExpr(function, scan_args)

return super().visit_Call(node)

def is_nested_call(self, node: ast.FunctionDef) -> bool:
Expand Down
1 change: 1 addition & 0 deletions pykokkos/interface/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .accumulator import Acc
from .algorithms.inclusive_scan import inclusive_scan
from .atomic.atomic_fetch_op import (
atomic_fetch_add, atomic_fetch_and, atomic_fetch_div,
atomic_fetch_lshift, atomic_fetch_max, atomic_fetch_min,
Expand Down
3 changes: 3 additions & 0 deletions pykokkos/interface/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .inclusive_scan import inclusive_scan

__all__ = ["inclusive_scan"]
15 changes: 15 additions & 0 deletions pykokkos/interface/algorithms/inclusive_scan.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interfaces modified to support linting and highlighting

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from pykokkos.interface.hierarchical import TeamMember
from pykokkos.interface.views import ViewType


def inclusive_scan(team_member: TeamMember, view: ViewType, size: int = -1) -> None:
    """
    Perform an inclusive scan on a view using a team member.

    The Python body is an intentionally empty placeholder: calling it directly
    does nothing and returns ``None``. It exists so that
    ``pk.inclusive_scan(...)`` resolves for linters and editors.

    **`team_barrier()` should always be called before accessing scanned data.**

    :param team_member: the team member
    :param view: the view to scan
    :param size: (optional) the number of elements to scan; the default of
        ``-1`` only stands for "not provided", so omit the argument to scan
        the whole view
    """