import pykokkos as pk


# Seed the view with 1, 2, ..., N so the scanned output is easy to eyeball.
@pk.workunit
def init_data(i: int, view: pk.View1D[int]):
    view[i] = i + 1


# Team-level inclusive scan: every team copies its slice of `view` into
# level-0 team scratch memory, scans the scratch buffer in place, and then
# writes the scanned values back to the same slice of `view`.
@pk.workunit
def team_scan(team_member: pk.TeamMember, view: pk.View1D[int]):
    # `lane` is this thread's slot inside its team; `slot` is the matching
    # index into the global view (teams own consecutive, equally sized slices).
    lane: int = team_member.team_rank()
    threads: int = team_member.team_size()
    slot: int = team_member.league_rank() * threads + lane

    staging: pk.ScratchView1D[int] = pk.ScratchView1D(
        team_member.team_scratch(0), threads
    )

    # One element per thread; wait until the whole team has stored its value
    # before the scan reads the buffer.
    staging[lane] = view[slot]
    team_member.team_barrier()

    # Barrier again before any thread reads the scanned data back.
    pk.inclusive_scan(team_member, staging)
    team_member.team_barrier()

    view[slot] = staging[lane]


def main():
    """Initialise a 64-element view and run the per-team inclusive scan on it."""
    N = 64
    team_size = 32
    # Ceiling division: enough teams to cover all N elements.
    num_teams = -(-N // team_size)
    space = pk.ExecutionSpace.OpenMP

    view = pk.View([N], int)
    pk.parallel_for(pk.RangePolicy(space, 0, N), init_data, view=view)

    print(f"Total elements: {N}, Team size: {team_size}, Number of teams: {num_teams}")

    # One team per `team_size`-sized slice of the view.
    team_policy = pk.TeamPolicy(space, num_teams, team_size)

    # TODO: implement scratch size setting. The two calls below are not
    # implemented correctly yet, so they are left disabled for now.
    # scratch_size = pk.ScratchView1D[int].shmem_size(team_size)
    # team_policy.set_scratch_size(0, pk.PerTeam(scratch_size))

    # Launch the team scan and show the resulting view.
    print("Running kernel...")
    pk.parallel_for(team_policy, team_scan, view=view)
    print(f"View, splitted by two groups of size = {team_size}")
    print(view)


if __name__ == "__main__":
    main()
100644 --- a/pykokkos/core/translators/static.py +++ b/pykokkos/core/translators/static.py @@ -268,6 +268,7 @@ def generate_includes(self) -> str: "Kokkos_Core.hpp", "Kokkos_Random.hpp", "Kokkos_Sort.hpp", + "Kokkos_StdAlgorithms.hpp", "fstream", "iostream", "cmath", @@ -290,6 +291,7 @@ def generate_cast_includes(self) -> str: "Kokkos_Core.hpp", "Kokkos_Random.hpp", "Kokkos_Sort.hpp", + "Kokkos_StdAlgorithms.hpp", "fstream", "iostream", "cmath", diff --git a/pykokkos/core/translators/symbols_pass.py b/pykokkos/core/translators/symbols_pass.py index d8b997af..e21dbc36 100644 --- a/pykokkos/core/translators/symbols_pass.py +++ b/pykokkos/core/translators/symbols_pass.py @@ -54,7 +54,7 @@ def __init__(self, members: PyKokkosMembers, pk_import: str, path: str): self.global_symbols.update(math_functions) self.global_symbols.update(allowed_types) self.global_symbols.update(view_dtypes) - self.global_symbols.update(["self", "range", "math", "List", "abs"]) + self.global_symbols.update(["self", "range", "math", "List", "abs", "inclusive_scan"]) self.global_symbols.add(pk_import) self.global_symbols.update([field.declname for field in members.fields]) diff --git a/pykokkos/core/visitors/workunit_visitor.py b/pykokkos/core/visitors/workunit_visitor.py index 3d52a7b4..31ebdad1 100644 --- a/pykokkos/core/visitors/workunit_visitor.py +++ b/pykokkos/core/visitors/workunit_visitor.py @@ -318,6 +318,51 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr: return real_number_call + # Kokkos `inclusive_scan` + if name == "inclusive_scan": + # Check if it's called via pk.inclusive_scan + is_pk_call = ( + isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == self.pk_import + ) + + if not is_pk_call: + return super().visit_Call(node) + + # Expected signature: pk.inclusive_scan(team_member, view, length) + # or: pk.inclusive_scan(team_member, view) - uses view size + if len(args) < 2 or len(args) > 3: + self.error( + node, + 
"pk.inclusive_scan() takes 2 or 3 arguments: team_member, view, [length]", + ) + + team_member_expr = args[0] + view_expr = args[1] + + # Create Kokkos::Experimental::begin() and end() calls + begin_function = cppast.DeclRefExpr("Kokkos::Experimental::begin") + end_function = cppast.DeclRefExpr("Kokkos::Experimental::end") + view_begin = cppast.CallExpr(begin_function, [view_expr]) + + if len(args) == 3: + # Use provided length: begin + length + length_expr = args[2] + view_begin_for_end = cppast.CallExpr(begin_function, [view_expr]) + view_end = cppast.BinaryOperator( + view_begin_for_end, length_expr, cppast.BinaryOperatorKind.Add + ) + else: + # Use end() when no length is provided + view_end = cppast.CallExpr(end_function, [view_expr]) + + # Create Kokkos::Experimental::inclusive_scan call + function = cppast.DeclRefExpr("Kokkos::Experimental::inclusive_scan") + scan_args = [team_member_expr, view_begin, view_end, view_begin] + + return cppast.CallExpr(function, scan_args) + return super().visit_Call(node) def is_nested_call(self, node: ast.FunctionDef) -> bool: diff --git a/pykokkos/interface/__init__.py b/pykokkos/interface/__init__.py index 1439eed9..1bf74342 100644 --- a/pykokkos/interface/__init__.py +++ b/pykokkos/interface/__init__.py @@ -1,4 +1,5 @@ from .accumulator import Acc +from .algorithms.inclusive_scan import inclusive_scan from .atomic.atomic_fetch_op import ( atomic_fetch_add, atomic_fetch_and, atomic_fetch_div, atomic_fetch_lshift, atomic_fetch_max, atomic_fetch_min, diff --git a/pykokkos/interface/algorithms/__init__.py b/pykokkos/interface/algorithms/__init__.py new file mode 100644 index 00000000..6fb66fa3 --- /dev/null +++ b/pykokkos/interface/algorithms/__init__.py @@ -0,0 +1,3 @@ +from .inclusive_scan import inclusive_scan + +__all__ = ["inclusive_scan"] diff --git a/pykokkos/interface/algorithms/inclusive_scan.py b/pykokkos/interface/algorithms/inclusive_scan.py new file mode 100644 index 00000000..607c181d --- /dev/null +++ 
def inclusive_scan(team_member: TeamMember, view: ViewType, size: int = -1):
    """
    Team-level inclusive scan of a view, for use inside workunits.

    This Python definition is only a placeholder: it performs no computation,
    leaves ``view`` untouched and returns ``None``. The workunit translator
    rewrites ``pk.inclusive_scan(...)`` calls into
    ``Kokkos::Experimental::inclusive_scan`` over the view's ``begin``/``end``
    iterators (or ``begin + size`` when a third argument is passed), scanning
    in place.

    **`team_barrier()` should always be called before accessing scanned data.**

    :param team_member: the team member handle of the calling thread
    :param view: the view to scan
    :param size: (optional) the number of elements to scan; the default of -1
        presumably stands for "scan the whole view" -- confirm against the
        translator, which keys on the argument count rather than this value
    """
    # Intentionally a no-op on the Python side.
    return None