Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions examples/kokkos/lower_bound_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pykokkos as pk


@pk.workunit
def init_data(i: int, view: pk.View1D[int]):
    # Element i receives i + 1, so the view ends up holding 1, 2, 3, ...
    # in ascending order.
    value: int = i + 1
    view[i] = value


# Exercise pk.lower_bound on a per-team scratch view
@pk.workunit
def team_lower_bound(team_member: pk.TeamMember, view: pk.View1D[int], result_view: pk.View1D[int]):
    # Locate this thread: its rank inside the team and the first global
    # index owned by the team
    team_rank: int = team_member.team_rank()
    team_size: int = team_member.team_size()
    team_start: int = team_member.league_rank() * team_size
    global_idx: int = team_start + team_rank

    # Scratch buffer with one slot per thread of the team
    scratch: pk.ScratchView1D[int] = pk.ScratchView1D(
        team_member.team_scratch(0), team_size
    )

    # Every thread copies its own element; nothing is sorted here, so the
    # search relies on the input view already being ascending (main fills
    # it through init_data). All threads pass the team barrier before any
    # of them searches the scratch view.
    scratch[team_rank] = view[global_idx]
    team_member.team_barrier()

    # Each thread looks for a different value: twice its rank in the team
    search_value: int = team_rank * 2

    # Index of the first scratch element that is not less than search_value
    bound_idx: int = pk.lower_bound(scratch, team_size, search_value)

    result_view[global_idx] = bound_idx


# Exercise pk.lower_bound directly on a regular view
@pk.workunit
def lower_bound_view(i: int, view: pk.View1D[int], result_view: pk.View1D[int]):
    # Only the first 10 elements of the view are searched; the value looked
    # up is the work index itself.
    target: int = i
    result_view[i] = pk.lower_bound(view, 10, target)


def main():
    # Problem size and team decomposition; the ceiling division yields
    # enough teams to cover all N elements.
    N = 64
    team_size = 32
    num_teams = (N + team_size - 1) // team_size

    space = pk.ExecutionSpace.OpenMP

    view: pk.View1D[int] = pk.View([N], int)
    result_view: pk.View1D[int] = pk.View([N], int)

    # One range policy over [0, N) is shared by the initialisation kernel
    # and the regular-view search below.
    range_policy = pk.RangePolicy(space, 0, N)
    pk.parallel_for(range_policy, init_data, view=view)

    print(f"Total elements: {N}, Team size: {team_size}, Number of teams: {num_teams}")
    print(f"Initial view: {view}")

    # Scratch-memory variant, launched with a team policy.
    # NOTE(review): no scratch size is configured on this policy although
    # the workunit requests level-0 team scratch -- confirm whether a
    # set_scratch_size() call is required here.
    team_policy = pk.TeamPolicy(space, num_teams, team_size)

    print("\nRunning lower_bound with scratch memory...")
    pk.parallel_for(team_policy, team_lower_bound, view=view, result_view=result_view)
    print(f"Result (scratch lower_bound): {result_view}")

    # Regular-view variant; result_view is overwritten by this second run
    print("\nRunning lower_bound with regular view...")
    pk.parallel_for(range_policy, lower_bound_view, view=view, result_view=result_view)
    print(f"Result (view lower_bound): {result_view}")


if __name__ == "__main__":
    main()
75 changes: 75 additions & 0 deletions examples/kokkos/upper_bound_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pykokkos as pk


@pk.workunit
def init_data(i: int, view: pk.View1D[int]):
    # Element i receives i + 1, so the view ends up holding 1, 2, 3, ...
    # in ascending order.
    value: int = i + 1
    view[i] = value


# Exercise pk.upper_bound on a per-team scratch view
@pk.workunit
def team_upper_bound(team_member: pk.TeamMember, view: pk.View1D[int], result_view: pk.View1D[int]):
    # Locate this thread: its rank inside the team and the first global
    # index owned by the team
    team_rank: int = team_member.team_rank()
    team_size: int = team_member.team_size()
    team_start: int = team_member.league_rank() * team_size
    global_idx: int = team_start + team_rank

    # Scratch buffer with one slot per thread of the team
    scratch: pk.ScratchView1D[int] = pk.ScratchView1D(
        team_member.team_scratch(0), team_size
    )

    # Every thread copies its own element; nothing is sorted here, so the
    # search relies on the input view already being ascending (main fills
    # it through init_data). All threads pass the team barrier before any
    # of them searches the scratch view.
    scratch[team_rank] = view[global_idx]
    team_member.team_barrier()

    # Each thread looks for a different value: twice its rank in the team
    search_value: int = team_rank * 2

    # Result of the upper-bound search over the team_size scratch elements
    bound_idx: int = pk.upper_bound(scratch, team_size, search_value)

    result_view[global_idx] = bound_idx


# Exercise pk.upper_bound directly on a regular view
@pk.workunit
def upper_bound_view(i: int, view: pk.View1D[int], result_view: pk.View1D[int]):
    # Only the first 10 elements of the view are searched; the value looked
    # up is the work index itself.
    target: int = i
    result_view[i] = pk.upper_bound(view, 10, target)


def main():
    # Problem size and team decomposition; the ceiling division yields
    # enough teams to cover all N elements.
    N = 64
    team_size = 32
    num_teams = (N + team_size - 1) // team_size

    space = pk.ExecutionSpace.Cuda

    view: pk.View1D[int] = pk.View([N], int)
    result_view: pk.View1D[int] = pk.View([N], int)

    # One range policy over [0, N) is shared by the initialisation kernel
    # and the regular-view search below.
    range_policy = pk.RangePolicy(space, 0, N)
    pk.parallel_for(range_policy, init_data, view=view)

    print(f"Total elements: {N}, Team size: {team_size}, Number of teams: {num_teams}")
    print(f"Initial view: {view}")

    # Scratch-memory variant, launched with a team policy.
    # NOTE(review): no scratch size is configured on this policy although
    # the workunit requests level-0 team scratch -- confirm whether a
    # set_scratch_size() call is required here.
    team_policy = pk.TeamPolicy(space, num_teams, team_size)

    print("\nRunning upper_bound with scratch memory...")
    pk.parallel_for(team_policy, team_upper_bound, view=view, result_view=result_view)
    print(f"Result (scratch upper_bound): {result_view}")

    # Regular-view variant; result_view is overwritten by this second run
    print("\nRunning upper_bound with regular view...")
    pk.parallel_for(range_policy, upper_bound_view, view=view, result_view=result_view)
    print(f"Result (view upper_bound): {result_view}")


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions pykokkos/core/translators/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def generate_includes(self) -> str:
"Kokkos_Core.hpp",
"Kokkos_Random.hpp",
"Kokkos_Sort.hpp",
"Kokkos_StdAlgorithms.hpp",
"fstream",
"iostream",
"cmath",
Expand All @@ -290,6 +291,7 @@ def generate_cast_includes(self) -> str:
"Kokkos_Core.hpp",
"Kokkos_Random.hpp",
"Kokkos_Sort.hpp",
"Kokkos_StdAlgorithms.hpp",
"fstream",
"iostream",
"cmath",
Expand Down
2 changes: 1 addition & 1 deletion pykokkos/core/translators/symbols_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, members: PyKokkosMembers, pk_import: str, path: str):
self.global_symbols.update(math_functions)
self.global_symbols.update(allowed_types)
self.global_symbols.update(view_dtypes)
self.global_symbols.update(["self", "range", "math", "List", "abs"])
self.global_symbols.update(["self", "range", "math", "List", "abs", "upper_bound", "lower_bound"])
self.global_symbols.add(pk_import)

self.global_symbols.update([field.declname for field in members.fields])
Expand Down
70 changes: 70 additions & 0 deletions pykokkos/core/visitors/workunit_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,76 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:

return real_number_call

# Custom `upper_bound` implementation using binary search
if name == "upper_bound":
# Check if it's called via pk.upper_bound
is_pk_call = (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == self.pk_import
)

if not is_pk_call:
return super().visit_Call(node)

# Expected signature: pk.upper_bound(view, size, value)
if len(args) != 3:
self.error(
node,
"pk.upper_bound() takes 3 arguments: view, size, value",
)

view_expr = args[0]
size_expr = args[1]
value_expr = args[2]

# Generate binary search lambda inline
from pykokkos.interface.algorithms.upper_bound import generate_upper_bound_binary_search

# Create lambda body with binary search
lambda_body = generate_upper_bound_binary_search(view_expr, size_expr, value_expr)

# Create and invoke lambda
lambda_expr = cppast.LambdaExpr("[&]", [], lambda_body)
lambda_call = cppast.CallExpr(lambda_expr, [])

return lambda_call

# Custom `lower_bound` implementation using binary search
if name == "lower_bound":
# Check if it's called via pk.lower_bound
is_pk_call = (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == self.pk_import
)

if not is_pk_call:
return super().visit_Call(node)

# Expected signature: pk.lower_bound(view, size, value)
if len(args) != 3:
self.error(
node,
"pk.lower_bound() takes 3 arguments: view, size, value",
)

view_expr = args[0]
size_expr = args[1]
value_expr = args[2]

# Generate binary search lambda inline
from pykokkos.interface.algorithms.lower_bound import generate_lower_bound_binary_search

# Create lambda body with binary search
lambda_body = generate_lower_bound_binary_search(view_expr, size_expr, value_expr)

# Create and invoke lambda
lambda_expr = cppast.LambdaExpr("[&]", [], lambda_body)
lambda_call = cppast.CallExpr(lambda_expr, [])

return lambda_call

return super().visit_Call(node)

def is_nested_call(self, node: ast.FunctionDef) -> bool:
Expand Down
1 change: 1 addition & 0 deletions pykokkos/interface/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .accumulator import Acc
from .algorithms import lower_bound, upper_bound
from .atomic.atomic_fetch_op import (
atomic_fetch_add, atomic_fetch_and, atomic_fetch_div,
atomic_fetch_lshift, atomic_fetch_max, atomic_fetch_min,
Expand Down
4 changes: 4 additions & 0 deletions pykokkos/interface/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .lower_bound import lower_bound
from .upper_bound import upper_bound

__all__ = ["lower_bound", "upper_bound"]
91 changes: 91 additions & 0 deletions pykokkos/interface/algorithms/lower_bound.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from pykokkos.interface.views import ViewType
from pykokkos.core import cppast

def lower_bound(view: ViewType, size: int, value) -> int:
    """
    Perform a lower bound search on a view

    Returns the index of the first element not less than (i.e. greater or
    equal to) value, similar to std::lower_bound or thrust::lower_bound.
    Only the first ``size`` elements are examined; when all of them are
    less than value, ``size`` itself is returned.

    Calls made as ``pk.lower_bound(...)`` inside a workunit are replaced by
    generated C++ during translation. This Python body performs the same
    search so that the function also honours its contract when it is
    executed directly by the interpreter, instead of returning None.

    :param view: the view to search (its first size elements must be sorted
        in ascending order)
    :param size: the number of elements to search
    :param value: the value to search for
    :returns: the index of the first element >= value, or size if there is
        no such element
    """

    left: int = 0
    right: int = size

    # Half-open [left, right) bisection, written out by hand so it only
    # needs view[index] and follows the same steps as the generated C++.
    while left < right:
        mid: int = left + (right - left) // 2
        if view[mid] < value:
            left = mid + 1
        else:
            right = mid

    return left


def generate_lower_bound_binary_search(
    view_expr: cppast.Expr, size_expr: cppast.Expr, value_expr: cppast.Expr
) -> cppast.CompoundStmt:
    """
    Generate binary search implementation for lower_bound.
    Returns a CompoundStmt that implements:

        int32_t pk_lower_bound_left = 0;
        int32_t pk_lower_bound_right = size;
        int32_t pk_lower_bound_mid;
        while (pk_lower_bound_left < pk_lower_bound_right) {
            pk_lower_bound_mid = pk_lower_bound_left
                + (pk_lower_bound_right - pk_lower_bound_left) / 2;
            if (view[pk_lower_bound_mid] < value) {
                pk_lower_bound_left = pk_lower_bound_mid + 1;
            } else {
                pk_lower_bound_right = pk_lower_bound_mid;
            }
        }
        return pk_lower_bound_left;

    The generated locals carry a ``pk_lower_bound_`` prefix because the
    caller-supplied expressions are spliced into the same scope: with plain
    names such as ``right``, a size expression that refers to a user
    variable called ``right`` would produce ``int32_t right = right;`` and
    read the new, uninitialised local instead of the user's variable.

    :param view_expr: expression for the view to search; it is subscripted
        as given
    :param size_expr: expression for the number of elements to search
    :param value_expr: expression for the value to search for; it appears
        inside the loop condition, so it is evaluated on every iteration
    :returns: the compound statement forming the body of the search
    """

    # 32-bit signed type used for all three index variables
    int_type = cppast.PrimitiveType("int32_t")

    # Prefixed names keep the generated locals from shadowing user variables
    left_var = cppast.DeclRefExpr("pk_lower_bound_left")
    right_var = cppast.DeclRefExpr("pk_lower_bound_right")
    mid_var = cppast.DeclRefExpr("pk_lower_bound_mid")

    # int32_t pk_lower_bound_left = 0;
    left_stmt = cppast.DeclStmt(
        cppast.VarDecl(int_type, left_var, cppast.IntegerLiteral(0))
    )

    # int32_t pk_lower_bound_right = size;
    right_stmt = cppast.DeclStmt(cppast.VarDecl(int_type, right_var, size_expr))

    # int32_t pk_lower_bound_mid;
    mid_stmt = cppast.DeclStmt(cppast.VarDecl(int_type, mid_var, None))

    # while (left < right)
    while_cond = cppast.BinaryOperator(left_var, right_var, cppast.BinaryOperatorKind.LT)

    # mid = left + (right - left) / 2;
    right_minus_left = cppast.BinaryOperator(right_var, left_var, cppast.BinaryOperatorKind.Sub)
    half_width = cppast.BinaryOperator(
        right_minus_left, cppast.IntegerLiteral(2), cppast.BinaryOperatorKind.Div
    )
    mid_calc = cppast.BinaryOperator(left_var, half_width, cppast.BinaryOperatorKind.Add)
    mid_assign = cppast.AssignOperator([mid_var], mid_calc, cppast.BinaryOperatorKind.Assign)

    # if (view[mid] < value)
    # The caller's view expression is subscripted as-is; substituting a
    # hard-coded identifier for it would silently search a different (or
    # non-existent) variable.
    view_access = cppast.ArraySubscriptExpr(view_expr, [mid_var])
    if_cond = cppast.BinaryOperator(view_access, value_expr, cppast.BinaryOperatorKind.LT)

    # left = mid + 1;
    mid_plus_one = cppast.BinaryOperator(
        mid_var, cppast.IntegerLiteral(1), cppast.BinaryOperatorKind.Add
    )
    left_assign = cppast.AssignOperator([left_var], mid_plus_one, cppast.BinaryOperatorKind.Assign)

    # right = mid;
    right_assign = cppast.AssignOperator([right_var], mid_var, cppast.BinaryOperatorKind.Assign)

    # if-else statement
    if_stmt = cppast.IfStmt(if_cond, left_assign, right_assign)

    # while body
    while_body = cppast.CompoundStmt([mid_assign, if_stmt])
    while_stmt = cppast.WhileStmt(while_cond, while_body)

    # return left;
    return_stmt = cppast.ReturnStmt(left_var)

    # Complete function body
    return cppast.CompoundStmt([left_stmt, right_stmt, mid_stmt, while_stmt, return_stmt])
Loading