diff --git a/examples/kokkos/lower_bound_example.py b/examples/kokkos/lower_bound_example.py
new file mode 100644
index 00000000..20a0c91e
--- /dev/null
+++ b/examples/kokkos/lower_bound_example.py
@@ -0,0 +1,90 @@
+"""Example: pk.lower_bound on team scratch memory and on a regular view."""
+
+import bisect
+
+import pykokkos as pk
+
+
+@pk.workunit
+def init_data(i, view):
+    view[i] = i + 1
+
+
+# Test lower_bound with scratch memory
+@pk.workunit
+def team_lower_bound(team_member, view, result_view):
+    team_size: int = team_member.team_size()
+    team_rank: int = team_member.team_rank()
+    globalIdx: int = team_member.league_rank() * team_size + team_rank
+
+    scratch: pk.ScratchView1D[pk.int32] = pk.ScratchView1D(
+        team_member.team_scratch(0), team_size
+    )
+
+    scratch[team_rank] = view[globalIdx]
+    team_member.team_barrier()
+    search_value: int = team_rank * 2  # Search for a value
+    bound_idx: int = pk.lower_bound(scratch, team_size, search_value)
+    result_view[globalIdx] = bound_idx
+
+
+# Test lower_bound with regular view
+# Find lower bound for value i in the first 10 elements
+@pk.workunit
+def lower_bound_view(i, view, result_view):
+    search_value: int = i
+    bound_idx: int = pk.lower_bound(view, 10, search_value)
+    result_view[i] = bound_idx
+
+
+def main():
+    N = 64
+    team_size = 32
+    num_teams = (N + team_size - 1) // team_size
+
+    view: pk.View1D[pk.int32] = pk.View([N], pk.int32)
+    result_view: pk.View1D[pk.int32] = pk.View([N], pk.int32)
+
+    p_init = pk.RangePolicy(pk.ExecutionSpace.OpenMP, 0, N)
+    pk.parallel_for(p_init, init_data, view=view)
+
+    print(f"Total elements: {N}, Team size: {team_size}, Number of teams: {num_teams}")
+    print(f"Initial view: {view}")
+
+    # Host-side copy of what init_data wrote; bisect is the reference oracle
+    data = [i + 1 for i in range(N)]
+
+    # Test with TeamPolicy (scratch memory)
+    team_policy = pk.TeamPolicy(pk.ExecutionSpace.OpenMP, num_teams, team_size)
+    # team_scratch(0) has to be backed by an explicit level-0 scratch request
+    scratch_size = pk.ScratchView1D[pk.int32].shmem_size(team_size)
+    team_policy.set_scratch_size(0, pk.PerTeam(scratch_size))
+
+    pk.parallel_for(team_policy, team_lower_bound, view=view, result_view=result_view)
+    print(f"Result (scratch lower_bound): {result_view}")
+
+    # Assert scratch lower_bound results: each team searches its own slice
+    for i in range(N):
+        team_start = (i // team_size) * team_size
+        team_data = data[team_start : team_start + team_size]
+        expected = bisect.bisect_left(team_data, (i % team_size) * 2)
+        assert (
+            result_view[i] == expected
+        ), f"Mismatch at index {i}: got {result_view[i]}, expected {expected}"
+    print("Scratch lower_bound test passed")
+
+    # Test with RangePolicy (regular view)
+    pk.parallel_for(p_init, lower_bound_view, view=view, result_view=result_view)
+    print(f"Result (view lower_bound): {result_view}")
+
+    # Assert view lower_bound results: only the first 10 elements are searched
+    for i in range(N):
+        expected = bisect.bisect_left(data, i, 0, 10)
+        assert (
+            result_view[i] == expected
+        ), f"Mismatch at index {i}: got {result_view[i]}, expected {expected}"
+    print("View lower_bound test passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/kokkos/upper_bound_example.py b/examples/kokkos/upper_bound_example.py
new file mode 100644
index 00000000..606719b1
--- /dev/null
+++ b/examples/kokkos/upper_bound_example.py
@@ -0,0 +1,90 @@
+"""Example: pk.upper_bound on team scratch memory and on a regular view."""
+
+import bisect
+
+import pykokkos as pk
+
+
+@pk.workunit
+def init_data(i, view):
+    view[i] = i + 1
+
+
+# Test upper_bound with scratch memory
+@pk.workunit
+def team_upper_bound(team_member, view, result_view):
+    team_size: int = team_member.team_size()
+    team_rank: int = team_member.team_rank()
+    globalIdx: int = team_member.league_rank() * team_size + team_rank
+
+    scratch: pk.ScratchView1D[pk.int32] = pk.ScratchView1D(
+        team_member.team_scratch(0), team_size
+    )
+
+    scratch[team_rank] = view[globalIdx]
+    team_member.team_barrier()
+    search_value: int = team_rank * 2  # Search for a value
+    bound_idx: int = pk.upper_bound(scratch, team_size, search_value)
+    result_view[globalIdx] = bound_idx
+
+
+# Test upper_bound with regular view
+# Find upper bound for value i in the first 10 elements
+@pk.workunit
+def upper_bound_view(i, view, result_view):
+    search_value: int = i
+    bound_idx: int = pk.upper_bound(view, 10, search_value)
+    result_view[i] = bound_idx
+
+
+def main():
+    N = 64
+    team_size = 32
+    num_teams = (N + team_size - 1) // team_size
+
+    view: pk.View1D[pk.int32] = pk.View([N], pk.int32)
+    result_view: pk.View1D[pk.int32] = pk.View([N], pk.int32)
+
+    p_init = pk.RangePolicy(pk.ExecutionSpace.OpenMP, 0, N)
+    pk.parallel_for(p_init, init_data, view=view)
+
+    print(f"Total elements: {N}, Team size: {team_size}, Number of teams: {num_teams}")
+    print(f"Initial view: {view}")
+
+    # Host-side copy of what init_data wrote; bisect is the reference oracle
+    data = [i + 1 for i in range(N)]
+
+    # Test with TeamPolicy (scratch memory)
+    team_policy = pk.TeamPolicy(pk.ExecutionSpace.OpenMP, num_teams, team_size)
+    # team_scratch(0) has to be backed by an explicit level-0 scratch request
+    scratch_size = pk.ScratchView1D[pk.int32].shmem_size(team_size)
+    team_policy.set_scratch_size(0, pk.PerTeam(scratch_size))
+
+    pk.parallel_for(team_policy, team_upper_bound, view=view, result_view=result_view)
+    print(f"Result (scratch upper_bound): {result_view}")
+
+    # Assert scratch upper_bound results: each team searches its own slice
+    for i in range(N):
+        team_start = (i // team_size) * team_size
+        team_data = data[team_start : team_start + team_size]
+        expected = bisect.bisect_right(team_data, (i % team_size) * 2)
+        assert (
+            result_view[i] == expected
+        ), f"Mismatch at index {i}: got {result_view[i]}, expected {expected}"
+    print("Scratch upper_bound test passed")
+
+    # Test with RangePolicy (regular view)
+    pk.parallel_for(p_init, upper_bound_view, view=view, result_view=result_view)
+    print(f"Result (view upper_bound): {result_view}")
+
+    # Assert view upper_bound results: only the first 10 elements are searched
+    for i in range(N):
+        expected = bisect.bisect_right(data, i, 0, 10)
+        assert (
+            result_view[i] == expected
+        ), f"Mismatch at index {i}: got {result_view[i]}, expected {expected}"
+    print("View upper_bound test passed")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pykokkos/core/translators/static.py b/pykokkos/core/translators/static.py
index 2d830b23..faeafe26 100644
--- a/pykokkos/core/translators/static.py
+++ b/pykokkos/core/translators/static.py
@@ -268,6 +268,7 @@ def generate_includes(self) -> str:
             "Kokkos_Core.hpp",
             "Kokkos_Random.hpp",
             "Kokkos_Sort.hpp",
+            "Kokkos_StdAlgorithms.hpp",
             "fstream",
             "iostream",
             "cmath",
@@ -290,6 +291,7 @@ def generate_cast_includes(self) -> str:
             "Kokkos_Core.hpp",
             "Kokkos_Random.hpp",
             "Kokkos_Sort.hpp",
+            "Kokkos_StdAlgorithms.hpp",
             "fstream",
             "iostream",
             "cmath",
diff --git a/pykokkos/core/translators/symbols_pass.py b/pykokkos/core/translators/symbols_pass.py
index d8b997af..84707b8d 100644
--- a/pykokkos/core/translators/symbols_pass.py
+++ b/pykokkos/core/translators/symbols_pass.py
@@ -54,7 +54,7 @@ def __init__(self, members: PyKokkosMembers, pk_import: str, path: str):
         self.global_symbols.update(math_functions)
         self.global_symbols.update(allowed_types)
         self.global_symbols.update(view_dtypes)
-        self.global_symbols.update(["self", "range", "math", "List", "abs"])
+        self.global_symbols.update(["self", "range", "math", "List", "abs", "upper_bound", "lower_bound"])
         self.global_symbols.add(pk_import)
         self.global_symbols.update([field.declname for field in members.fields])
 
diff --git a/pykokkos/core/visitors/workunit_visitor.py b/pykokkos/core/visitors/workunit_visitor.py
index 3d52a7b4..3ba8cd0d 100644
--- a/pykokkos/core/visitors/workunit_visitor.py
+++ b/pykokkos/core/visitors/workunit_visitor.py
@@ -318,6 +318,44 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:
 
             return real_number_call
 
+        # Custom `upper_bound` / `lower_bound` implementations using binary search
+        if name in ("upper_bound", "lower_bound"):
+            # Only intercept calls spelled pk.upper_bound / pk.lower_bound
+            is_pk_call = (
+                isinstance(node.func, ast.Attribute)
+                and isinstance(node.func.value, ast.Name)
+                and node.func.value.id == self.pk_import
+            )
+
+            if not is_pk_call:
+                return super().visit_Call(node)
+
+            # Expected signature: pk.<name>(view, size, value)
+            if len(args) != 3:
+                self.error(
+                    node,
+                    f"pk.{name}() takes 3 arguments: view, size, value",
+                )
+
+            view_expr, size_expr, value_expr = args
+
+            # Imported here rather than at module level, as in the original change
+            if name == "upper_bound":
+                from pykokkos.interface.algorithms.upper_bound import (
+                    generate_upper_bound_binary_search as generate_binary_search,
+                )
+            else:
+                from pykokkos.interface.algorithms.lower_bound import (
+                    generate_lower_bound_binary_search as generate_binary_search,
+                )
+
+            # Emit the search as an immediately-invoked lambda so that the
+            # call stays usable as an expression
+            lambda_body = generate_binary_search(view_expr, size_expr, value_expr)
+            lambda_expr = cppast.LambdaExpr("[&]", [], lambda_body)
+
+            return cppast.CallExpr(lambda_expr, [])
+
         return super().visit_Call(node)
 
     def is_nested_call(self, node: ast.FunctionDef) -> bool:
diff --git a/pykokkos/interface/__init__.py b/pykokkos/interface/__init__.py
index 1439eed9..7cecd9fc 100644
--- a/pykokkos/interface/__init__.py
+++ b/pykokkos/interface/__init__.py
@@ -1,4 +1,5 @@
 from .accumulator import Acc
+from .algorithms import lower_bound, upper_bound
 from .atomic.atomic_fetch_op import (
     atomic_fetch_add, atomic_fetch_and, atomic_fetch_div,
     atomic_fetch_lshift, atomic_fetch_max, atomic_fetch_min,
diff --git a/pykokkos/interface/algorithms/__init__.py b/pykokkos/interface/algorithms/__init__.py
new file mode 100644
index 00000000..2c2e8805
--- /dev/null
+++ b/pykokkos/interface/algorithms/__init__.py
@@ -0,0 +1,4 @@
+from .lower_bound import lower_bound
+from .upper_bound import upper_bound
+
+__all__ = ["lower_bound", "upper_bound"]
diff --git a/pykokkos/interface/algorithms/lower_bound.py b/pykokkos/interface/algorithms/lower_bound.py
new file mode 100644
index 00000000..b4c4b785
--- /dev/null
+++ b/pykokkos/interface/algorithms/lower_bound.py
@@ -0,0 +1,106 @@
+import bisect
+
+from pykokkos.interface.views import ViewType
+from pykokkos.core import cppast
+
+
+def lower_bound(view: ViewType, size: int, value) -> int:
+    """
+    Perform a lower bound search on a view
+
+    Returns the index of the first element not less than (i.e. greater or equal to) value,
+    similar to std::lower_bound or thrust::lower_bound.
+
+    Inside a workunit the call is replaced by a generated C++ binary search;
+    called directly from Python it falls back to bisect.bisect_left.
+
+    Supported types: All orderable numeric types (int8, int16, int32, int64,
+    uint8, uint16, uint32, uint64, float, double). Complex types are not
+    supported as they cannot be ordered.
+
+    :param view: the view to search (must be sorted in ascending order)
+    :param size: the number of elements to search
+    :param value: the value to search for (must match view's element type)
+    :returns: the index of the first element >= value
+    """
+
+    return bisect.bisect_left(view, value, 0, size)
+
+
+def generate_lower_bound_binary_search(
+    view_expr: cppast.Expr, size_expr: cppast.Expr, value_expr: cppast.Expr
+) -> cppast.CompoundStmt:
+    """
+    Generate binary search implementation for lower_bound.
+    Returns a CompoundStmt that implements:
+
+    int32_t pk_lb_left = 0;
+    int32_t pk_lb_right = size;
+    int32_t pk_lb_mid;
+    while (pk_lb_left < pk_lb_right) {
+        pk_lb_mid = pk_lb_left + (pk_lb_right - pk_lb_left) / 2;
+        if (view(pk_lb_mid) < value) {
+            pk_lb_left = pk_lb_mid + 1;
+        } else {
+            pk_lb_right = pk_lb_mid;
+        }
+    }
+    return pk_lb_left;
+
+    The locals are prefixed because the body is emitted inside a [&] lambda,
+    where a plain "left", "right" or "mid" would shadow a caller variable of
+    the same name appearing in the size or value expression.
+    """
+
+    # Variable declarations
+    int_type = cppast.PrimitiveType("int32_t")
+
+    # int32_t pk_lb_left = 0;
+    left_var = cppast.DeclRefExpr("pk_lb_left")
+    left_init = cppast.IntegerLiteral(0)
+    left_decl = cppast.VarDecl(int_type, left_var, left_init)
+    left_stmt = cppast.DeclStmt(left_decl)
+
+    # int32_t pk_lb_right = size;
+    right_var = cppast.DeclRefExpr("pk_lb_right")
+    right_decl = cppast.VarDecl(int_type, right_var, size_expr)
+    right_stmt = cppast.DeclStmt(right_decl)
+
+    # int32_t pk_lb_mid;
+    mid_var = cppast.DeclRefExpr("pk_lb_mid")
+    mid_decl = cppast.VarDecl(int_type, mid_var, None)
+    mid_stmt = cppast.DeclStmt(mid_decl)
+
+    # while (pk_lb_left < pk_lb_right)
+    while_cond = cppast.BinaryOperator(left_var, right_var, cppast.BinaryOperatorKind.LT)
+
+    # pk_lb_mid = pk_lb_left + (pk_lb_right - pk_lb_left) / 2;
+    right_minus_left = cppast.BinaryOperator(right_var, left_var, cppast.BinaryOperatorKind.Sub)
+    div_expr = cppast.BinaryOperator(right_minus_left, cppast.IntegerLiteral(2), cppast.BinaryOperatorKind.Div)
+    mid_calc = cppast.BinaryOperator(left_var, div_expr, cppast.BinaryOperatorKind.Add)
+    mid_assign = cppast.AssignOperator([mid_var], mid_calc, cppast.BinaryOperatorKind.Assign)
+
+    # if (view(pk_lb_mid) < value)
+    # Index the translated view expression itself; never substitute a guessed name
+    view_access = cppast.CallExpr(view_expr, [mid_var])
+    if_cond = cppast.BinaryOperator(view_access, value_expr, cppast.BinaryOperatorKind.LT)
+
+    # pk_lb_left = pk_lb_mid + 1;
+    mid_plus_one = cppast.BinaryOperator(mid_var, cppast.IntegerLiteral(1), cppast.BinaryOperatorKind.Add)
+    left_assign = cppast.AssignOperator([left_var], mid_plus_one, cppast.BinaryOperatorKind.Assign)
+
+    # pk_lb_right = pk_lb_mid;
+    right_assign = cppast.AssignOperator([right_var], mid_var, cppast.BinaryOperatorKind.Assign)
+
+    # if-else statement
+    if_stmt = cppast.IfStmt(if_cond, left_assign, right_assign)
+
+    # while body
+    while_body = cppast.CompoundStmt([mid_assign, if_stmt])
+    while_stmt = cppast.WhileStmt(while_cond, while_body)
+
+    # return pk_lb_left;
+    return_stmt = cppast.ReturnStmt(left_var)
+
+    # Complete function body
+    return cppast.CompoundStmt([left_stmt, right_stmt, mid_stmt, while_stmt, return_stmt])
diff --git a/pykokkos/interface/algorithms/upper_bound.py b/pykokkos/interface/algorithms/upper_bound.py
new file mode 100644
index 00000000..41bc6b68
--- /dev/null
+++ b/pykokkos/interface/algorithms/upper_bound.py
@@ -0,0 +1,106 @@
+import bisect
+
+from pykokkos.interface.views import ViewType
+from pykokkos.core import cppast
+
+
+def upper_bound(view: ViewType, size: int, value) -> int:
+    """
+    Perform an upper bound search on a view
+
+    Returns the index of the first element greater than value,
+    similar to std::upper_bound or thrust::upper_bound.
+
+    Inside a workunit the call is replaced by a generated C++ binary search;
+    called directly from Python it falls back to bisect.bisect_right.
+
+    Supported types: All orderable numeric types (int8, int16, int32, int64,
+    uint8, uint16, uint32, uint64, float, double). Complex types are not
+    supported as they cannot be ordered.
+
+    :param view: the view to search (must be sorted in ascending order)
+    :param size: the number of elements to search
+    :param value: the value to search for (must match view's element type)
+    :returns: the index of the first element greater than value
+    """
+
+    return bisect.bisect_right(view, value, 0, size)
+
+
+def generate_upper_bound_binary_search(
+    view_expr: cppast.Expr, size_expr: cppast.Expr, value_expr: cppast.Expr
+) -> cppast.CompoundStmt:
+    """
+    Generate binary search implementation for upper_bound.
+    Returns a CompoundStmt that implements:
+
+    int32_t pk_ub_left = -1;
+    int32_t pk_ub_right = size;
+    int32_t pk_ub_mid;
+    while (pk_ub_left + 1 < pk_ub_right) {
+        pk_ub_mid = pk_ub_left + ((pk_ub_right - pk_ub_left) >> 1);
+        if (view(pk_ub_mid) > value) {
+            pk_ub_right = pk_ub_mid;
+        } else {
+            pk_ub_left = pk_ub_mid;
+        }
+    }
+    return pk_ub_right;
+
+    The locals are prefixed because the body is emitted inside a [&] lambda,
+    where a plain "left", "right" or "mid" would shadow a caller variable of
+    the same name appearing in the size or value expression.
+    """
+
+    # Variable declarations
+    int_type = cppast.PrimitiveType("int32_t")
+
+    # int32_t pk_ub_left = -1;
+    left_var = cppast.DeclRefExpr("pk_ub_left")
+    left_init = cppast.IntegerLiteral(-1)
+    left_decl = cppast.VarDecl(int_type, left_var, left_init)
+    left_stmt = cppast.DeclStmt(left_decl)
+
+    # int32_t pk_ub_right = size;
+    right_var = cppast.DeclRefExpr("pk_ub_right")
+    right_decl = cppast.VarDecl(int_type, right_var, size_expr)
+    right_stmt = cppast.DeclStmt(right_decl)
+
+    # int32_t pk_ub_mid;
+    mid_var = cppast.DeclRefExpr("pk_ub_mid")
+    mid_decl = cppast.VarDecl(int_type, mid_var, None)
+    mid_stmt = cppast.DeclStmt(mid_decl)
+
+    # while (pk_ub_left + 1 < pk_ub_right)
+    left_plus_one = cppast.BinaryOperator(left_var, cppast.IntegerLiteral(1), cppast.BinaryOperatorKind.Add)
+    while_cond = cppast.BinaryOperator(left_plus_one, right_var, cppast.BinaryOperatorKind.LT)
+
+    # pk_ub_mid = pk_ub_left + ((pk_ub_right - pk_ub_left) >> 1);
+    right_minus_left = cppast.BinaryOperator(right_var, left_var, cppast.BinaryOperatorKind.Sub)
+    shift_expr = cppast.BinaryOperator(right_minus_left, cppast.IntegerLiteral(1), cppast.BinaryOperatorKind.Shr)
+    mid_calc = cppast.BinaryOperator(left_var, shift_expr, cppast.BinaryOperatorKind.Add)
+    mid_assign = cppast.AssignOperator([mid_var], mid_calc, cppast.BinaryOperatorKind.Assign)
+
+    # if (view(pk_ub_mid) > value)
+    # Index the translated view expression itself; never substitute a guessed name
+    view_access = cppast.CallExpr(view_expr, [mid_var])
+    if_cond = cppast.BinaryOperator(view_access, value_expr, cppast.BinaryOperatorKind.GT)
+
+    # pk_ub_right = pk_ub_mid;
+    right_assign = cppast.AssignOperator([right_var], mid_var, cppast.BinaryOperatorKind.Assign)
+
+    # pk_ub_left = pk_ub_mid;
+    left_assign = cppast.AssignOperator([left_var], mid_var, cppast.BinaryOperatorKind.Assign)
+
+    # if-else statement
+    if_stmt = cppast.IfStmt(if_cond, right_assign, left_assign)
+
+    # while body
+    while_body = cppast.CompoundStmt([mid_assign, if_stmt])
+    while_stmt = cppast.WhileStmt(while_cond, while_body)
+
+    # return pk_ub_right;
+    return_stmt = cppast.ReturnStmt(right_var)
+
+    # Complete function body
+    return cppast.CompoundStmt([left_stmt, right_stmt, mid_stmt, while_stmt, return_stmt])
diff --git a/pykokkos/interface/views.py b/pykokkos/interface/views.py
index 536b6719..3d335346 100644
--- a/pykokkos/interface/views.py
+++ b/pykokkos/interface/views.py
@@ -1102,6 +1102,7 @@ def _calculate_scratch_size(type_size: int, *dims: int, alignment: int = 8) -> i
 
 
 class ScratchView:
+    @classmethod
     def __class_getitem__(cls, item):
         generic_alias = super().__class_getitem__(item)
         generic_alias.type_param = item
diff --git a/runtests.py b/runtests.py
index 39c4ff70..37b45a9e 100644
--- a/runtests.py
+++ b/runtests.py
@@ -18,6 +18,12 @@
 # pytest
 pytest_args = []
 
+# Ignore build directories to avoid collecting tests from dependencies
+pytest_args.extend([
+    "--ignore=base/_skbuild",
+    "--ignore=base/.eggs",
+])
+
 parser = argparse.ArgumentParser()
 parser.add_argument('-t', '--specifictests', type=str)
 parser.add_argument('-d', '--durations', type=int)
@@ -25,6 +31,10 @@
 
 if args.specifictests:
     pytest_args.append(args.specifictests)
+else:
+    # Default to tests/ directory if no specific test is provided
+    pytest_args.append("tests/")
+
 if args.durations:
     pytest_args.append(f"--durations={args.durations}")
 
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
new file mode 100644
index 00000000..619453c5
--- /dev/null
+++ b/tests/test_algorithms.py
@@ -0,0 +1,259 @@
+"""
+Test algorithms like lower_bound and upper_bound
+"""
+
+import bisect
+
+import pykokkos as pk
+import pytest
+
+# Must match the literal size passed to pk.lower_bound / pk.upper_bound below
+SEARCH_SIZE = 10
+
+
+@pk.workunit
+def init_int_data(i, view):
+    view[i] = (i + 1) * 2
+
+
+@pk.workunit
+def init_float_data(i, view: pk.View1D[pk.float]):
+    view[i] = (i + 1) * 2.0
+
+
+@pk.workunit
+def init_double_data(i, view: pk.View1D[pk.double]):
+    view[i] = (i + 1) * 2.0
+
+
+@pk.workunit
+def lower_bound_int(i, view, result_view):
+    search_value: int = i * 2
+    bound_idx: int = pk.lower_bound(view, 10, search_value)
+    result_view[i] = bound_idx
+
+
+@pk.workunit
+def lower_bound_float(i, view, result_view):
+    search_value: float = float(i * 2)
+    bound_idx: int = pk.lower_bound(view, 10, search_value)
+    result_view[i] = bound_idx
+
+
+@pk.workunit
+def lower_bound_double(i, view, result_view):
+    search_value: float = float(i * 2)
+    bound_idx: int = pk.lower_bound(view, 10, search_value)
+    result_view[i] = bound_idx
+
+
+@pk.workunit
+def upper_bound_int(i, view, result_view):
+    search_value: int = i * 2
+    bound_idx: int = pk.upper_bound(view, 10, search_value)
+    result_view[i] = bound_idx
+
+
+@pk.workunit
+def upper_bound_float(i, view, result_view):
+    search_value: float = float(i * 2)
+    bound_idx: int = pk.upper_bound(view, 10, search_value)
+    result_view[i] = bound_idx
+
+
+@pk.workunit
+def upper_bound_double(i, view, result_view):
+    search_value: float = float(i * 2)
+    bound_idx: int = pk.upper_bound(view, 10, search_value)
+    result_view[i] = bound_idx
+
+
+@pk.workunit
+def team_lower_bound_int(team_member, view, result_view):
+    team_size: int = team_member.team_size()
+    team_rank: int = team_member.team_rank()
+    globalIdx: int = team_member.league_rank() * team_size + team_rank
+
+    scratch: pk.ScratchView1D[pk.int32] = pk.ScratchView1D(
+        team_member.team_scratch(0), team_size
+    )
+
+    scratch[team_rank] = view[globalIdx]
+    team_member.team_barrier()
+    search_value: int = team_rank * 4
+    bound_idx: int = pk.lower_bound(scratch, team_size, search_value)
+    result_view[globalIdx] = bound_idx
+
+
+@pk.workunit
+def team_lower_bound_float(team_member, view, result_view):
+    team_size: int = team_member.team_size()
+    team_rank: int = team_member.team_rank()
+    globalIdx: int = team_member.league_rank() * team_size + team_rank
+
+    scratch: pk.ScratchView1D[pk.float] = pk.ScratchView1D(
+        team_member.team_scratch(0), team_size
+    )
+
+    scratch[team_rank] = view[globalIdx]
+    team_member.team_barrier()
+    search_value: float = float(team_rank * 4)
+    bound_idx: int = pk.lower_bound(scratch, team_size, search_value)
+    result_view[globalIdx] = bound_idx
+
+
+@pk.workunit
+def team_lower_bound_double(team_member, view, result_view):
+    team_size: int = team_member.team_size()
+    team_rank: int = team_member.team_rank()
+    globalIdx: int = team_member.league_rank() * team_size + team_rank
+
+    scratch: pk.ScratchView1D[pk.double] = pk.ScratchView1D(
+        team_member.team_scratch(0), team_size
+    )
+
+    scratch[team_rank] = view[globalIdx]
+    team_member.team_barrier()
+    search_value: float = float(team_rank * 4)
+    bound_idx: int = pk.lower_bound(scratch, team_size, search_value)
+    result_view[globalIdx] = bound_idx
+
+
+@pk.workunit
+def team_upper_bound_int(team_member, view, result_view):
+    team_size: int = team_member.team_size()
+    team_rank: int = team_member.team_rank()
+    globalIdx: int = team_member.league_rank() * team_size + team_rank
+
+    scratch: pk.ScratchView1D[pk.int32] = pk.ScratchView1D(
+        team_member.team_scratch(0), team_size
+    )
+
+    scratch[team_rank] = view[globalIdx]
+    team_member.team_barrier()
+    search_value: int = team_rank * 4
+    bound_idx: int = pk.upper_bound(scratch, team_size, search_value)
+    result_view[globalIdx] = bound_idx
+
+
+@pk.workunit
+def team_upper_bound_float(team_member, view, result_view):
+    team_size: int = team_member.team_size()
+    team_rank: int = team_member.team_rank()
+    globalIdx: int = team_member.league_rank() * team_size + team_rank
+
+    scratch: pk.ScratchView1D[pk.float] = pk.ScratchView1D(
+        team_member.team_scratch(0), team_size
+    )
+
+    scratch[team_rank] = view[globalIdx]
+    team_member.team_barrier()
+    search_value: float = float(team_rank * 4)
+    bound_idx: int = pk.upper_bound(scratch, team_size, search_value)
+    result_view[globalIdx] = bound_idx
+
+
+@pk.workunit
+def team_upper_bound_double(team_member, view, result_view):
+    team_size: int = team_member.team_size()
+    team_rank: int = team_member.team_rank()
+    globalIdx: int = team_member.league_rank() * team_size + team_rank
+
+    scratch: pk.ScratchView1D[pk.double] = pk.ScratchView1D(
+        team_member.team_scratch(0), team_size
+    )
+
+    scratch[team_rank] = view[globalIdx]
+    team_member.team_barrier()
+    search_value: float = float(team_rank * 4)
+    bound_idx: int = pk.upper_bound(scratch, team_size, search_value)
+    result_view[globalIdx] = bound_idx
+
+
+@pytest.mark.parametrize(
+    "init, dtype, workunit, reference",
+    [
+        (init_int_data, pk.int32, lower_bound_int, bisect.bisect_left),
+        (init_float_data, pk.float, lower_bound_float, bisect.bisect_left),
+        (init_double_data, pk.double, lower_bound_double, bisect.bisect_left),
+        (init_int_data, pk.int32, upper_bound_int, bisect.bisect_right),
+        (init_float_data, pk.float, upper_bound_float, bisect.bisect_right),
+        (init_double_data, pk.double, upper_bound_double, bisect.bisect_right),
+    ],
+    ids=[
+        "lower_bound_int",
+        "lower_bound_float",
+        "lower_bound_double",
+        "upper_bound_int",
+        "upper_bound_float",
+        "upper_bound_double",
+    ],
+)
+def test_bound_on_view(init, dtype, workunit, reference):
+    """Check pk.lower_bound / pk.upper_bound on a regular view against bisect."""
+    N = 20
+    view = pk.View([N], dtype)
+    result_view: pk.View1D[pk.int32] = pk.View([N], pk.int32)
+
+    pk.parallel_for(N, init, view=view)
+    pk.parallel_for(N, workunit, view=view, result_view=result_view)
+
+    # View contains [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, ...]; only the first
+    # SEARCH_SIZE entries are searched and work item i looks up i * 2.
+    # lower_bound returns the first element >= value (bisect_left),
+    # upper_bound the first element > value (bisect_right).
+    searched = [(i + 1) * 2 for i in range(SEARCH_SIZE)]
+
+    for i in range(N):
+        expected = reference(searched, i * 2)
+        assert (
+            result_view[i] == expected
+        ), f"Failed at index {i}: expected {expected}, got {result_view[i]}"
+
+
+@pytest.mark.parametrize(
+    "init, dtype, workunit",
+    [
+        (init_int_data, pk.int32, team_lower_bound_int),
+        (init_float_data, pk.float, team_lower_bound_float),
+        (init_double_data, pk.double, team_lower_bound_double),
+        (init_int_data, pk.int32, team_upper_bound_int),
+        (init_float_data, pk.float, team_upper_bound_float),
+        (init_double_data, pk.double, team_upper_bound_double),
+    ],
+    ids=[
+        "team_lower_bound_int",
+        "team_lower_bound_float",
+        "team_lower_bound_double",
+        "team_upper_bound_int",
+        "team_upper_bound_float",
+        "team_upper_bound_double",
+    ],
+)
+def test_bound_on_team_scratch(init, dtype, workunit):
+    """Run the scratch-memory workunits and check the team 0 / rank 0 result."""
+    N = 32
+    num_teams = 2
+
+    view = pk.View([N], dtype)
+    result_view: pk.View1D[pk.int32] = pk.View([N], pk.int32)
+    # Sentinel: a kernel that never writes index 0 must not pass the assertion
+    for i in range(N):
+        result_view[i] = -1
+
+    pk.parallel_for(N, init, view=view)
+
+    max_team_size = N // num_teams
+    scratch_size = pk.ScratchView1D[dtype].shmem_size(max_team_size)
+
+    # NOTE(review): scratch and views are sized for N // num_teams threads per
+    # team; if "auto" ever selects a larger team the workunit indexes past both.
+    # Confirm the bound for the backends used in CI.
+    policy = pk.TeamPolicy(num_teams, "auto")
+    policy.set_scratch_size(0, pk.PerTeam(scratch_size))
+
+    pk.parallel_for(policy, workunit, view=view, result_view=result_view)
+
+    # Team 0, rank 0 searches for 0 in [2, 4, ...]: both bounds are index 0
+    # whatever team size "auto" picks.
+    assert result_view[0] == 0, f"expected 0, got {result_view[0]}"