Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions pykokkos/core/module_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,15 @@ def __init__(
# The path to the main file if using the console
self.console_main: str = "pk_console"

self.main: Path = self.get_main_path()
if self.metadata and self.metadata[0].path:
entity_path = Path(self.metadata[0].path)
if entity_path.suffix == ".py":
self.main = entity_path.with_suffix("")
else:
self.main = entity_path
else:
self.main: Path = self.get_main_path()

self.output_dir: Optional[Path] = self.get_output_dir(
self.main, self.metadata, space, types_signature, self.restrict_signature
)
Expand All @@ -127,7 +135,12 @@ def __init__(
]

if self.output_dir is not None:
self.path: str = os.path.join(self.output_dir, self.module_file)
output_dir_abs = (
Path(self.output_dir).resolve()
if not Path(self.output_dir).is_absolute()
else Path(self.output_dir)
)
self.path: str = str(output_dir_abs / self.module_file)
if km.is_multi_gpu_enabled():
self.gpu_module_paths: str = [
os.path.join(self.output_dir, module_file)
Expand Down Expand Up @@ -181,6 +194,7 @@ def get_entity_dir(self, main: Path, metadata: List[EntityMetadata]) -> Path:
:returns: the path to the base output directory
"""

base_dir = self.get_main_dir(main)
entity_dir: str = ""

for m in metadata[:5]:
Expand All @@ -195,7 +209,7 @@ def get_entity_dir(self, main: Path, metadata: List[EntityMetadata]) -> Path:
if remaining != "":
entity_dir += hashlib.md5(("".join(remaining)).encode()).hexdigest()

return self.get_main_dir(main) / Path(entity_dir)
return base_dir / Path(entity_dir)

@staticmethod
def get_main_dir(main: Path) -> Path:
Expand All @@ -206,13 +220,19 @@ def get_main_dir(main: Path) -> Path:
:returns: the path to the main directory
"""

# If the parent directory is root, remove it so we can
# concatenate it to pk_cpp
main_path: Path = main
if str(main).startswith("/"):
main_path = Path(str(main)[1:])
# convert to absolute path and make it relative to CWD
main_abs: Path = (
main.resolve() if main.is_absolute() else (Path.cwd() / main).resolve()
)
try:
main_rel: Path = main_abs.relative_to(Path.cwd())
except ValueError:
# main_abs is not under cwd - fall back to old behavior
main_rel = (
Path(str(main_abs)[1:]) if str(main_abs).startswith("/") else main_abs
)

return Path(BASE_DIR) / main_path
return Path(BASE_DIR) / main_rel

def get_main_path(self) -> Path:
"""
Expand Down
94 changes: 70 additions & 24 deletions pykokkos/core/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,30 +40,44 @@ class Parser:
Parse a PyKokkos workload and its dependencies
"""

def __init__(self, path: str):
def __init__(self, path: Optional[str], pk_import: Optional[str] = None):
"""
Parse the file and find all entities

:param path: the path to the file
:param path: the path to the file (None for fused workunits)
:param pk_import: the pykokkos import identifier (required when path is None)
"""

self.lines: List[str]
self.tree: ast.Module
with open(path, "r") as f:
self.lines = f.readlines()
self.tree = ast.parse("".join(self.lines))

self.path: str = path
self.pk_import: str = self.get_import()
self.workloads: Dict[str, PyKokkosEntity] = {}
self.classtypes: Dict[str, PyKokkosEntity] = {}
self.functors: Dict[str, PyKokkosEntity] = {}
self.workunits: Dict[str, PyKokkosEntity] = {}

self.workloads = self.get_entities(PyKokkosStyles.workload)
self.classtypes = self.get_entities(PyKokkosStyles.classtype)
self.functors = self.get_entities(PyKokkosStyles.functor)
self.workunits = self.get_entities(PyKokkosStyles.workunit)
if path is not None:
with open(path, "r") as f:
self.lines = f.readlines()
self.tree = ast.parse("".join(self.lines))
self.path: Optional[str] = path
self.pk_import: str = self.get_import()
self.workloads: Dict[str, PyKokkosEntity] = {}
self.classtypes: Dict[str, PyKokkosEntity] = {}
self.functors: Dict[str, PyKokkosEntity] = {}
self.workunits: Dict[str, PyKokkosEntity] = {}

self.workloads = self.get_entities(PyKokkosStyles.workload)
self.classtypes = self.get_entities(PyKokkosStyles.classtype)
self.functors = self.get_entities(PyKokkosStyles.functor)
self.workunits = self.get_entities(PyKokkosStyles.workunit)
else:
# For fused workunits, we don't have a file to parse
# but we still need a parser instance for helper methods
if pk_import is None:
raise ValueError("pk_import must be provided when path is None")
self.lines = []
self.tree = ast.Module(body=[])
self.path = None
self.pk_import = pk_import
self.workloads: Dict[str, PyKokkosEntity] = {}
self.classtypes: Dict[str, PyKokkosEntity] = {}
self.functors: Dict[str, PyKokkosEntity] = {}
self.workunits: Dict[str, PyKokkosEntity] = {}

def get_import(self) -> str:
"""
Expand Down Expand Up @@ -151,6 +165,21 @@ def get_entities(self, style: PyKokkosStyles) -> Dict[str, PyKokkosEntity]:

return entities

def _apply_inferred_types_to_args(
self, args: List[ast.arg], inferred_types: Dict[str, str]
) -> None:
"""
Helper method to apply inferred types to function arguments.
Used by both fix_types and fix_function_types.

:param args: List of argument AST nodes
:param inferred_types: Dictionary mapping parameter names to type strings
"""
for arg in args:
if arg.annotation is None and arg.arg in inferred_types:
type_str = inferred_types[arg.arg]
arg.annotation = self.get_annotation_node(type_str)

def fix_types(self, entity: PyKokkosEntity, updated_types: UpdatedTypes) -> ast.AST:
"""
Inject (into the entity AST) the missing annotations for datatypes that have been inferred.
Expand All @@ -172,17 +201,28 @@ def fix_types(self, entity: PyKokkosEntity, updated_types: UpdatedTypes) -> ast.
if needs_reset:
entity_tree = self.reset_entity_tree(entity_tree, updated_types)

for arg_obj in entity_tree.args.args:
# Type already provided by the user
if arg_obj.arg not in updated_types.inferred_types:
continue

update_type = updated_types.inferred_types[arg_obj.arg]
arg_obj.annotation = self.get_annotation_node(update_type)
# Reuse the shared logic
self._apply_inferred_types_to_args(
entity_tree.args.args, updated_types.inferred_types
)

assert entity_tree is not None
return entity_tree

def fix_function_types(
self, function: ast.FunctionDef, inferred_types: Dict[str, str]
) -> ast.FunctionDef:
"""
Apply inferred types to a Kokkos function AST.
Reuses the same logic as fix_types.

:param function: The function AST to modify
:param inferred_types: Dictionary mapping parameter names to type strings
:returns: The modified function AST
"""
self._apply_inferred_types_to_args(function.args.args, inferred_types)
return function

def check_self(self, entity_tree: ast.AST) -> bool:
"""
Check if self args exists in the AST, which implies this AST was already
Expand Down Expand Up @@ -295,6 +335,12 @@ def get_annotation_node(self, type: str) -> ast.AST:
attr="TeamMember",
ctx=ast.Load(),
)
elif type in ("double", "float"):
annotation_node = ast.Attribute(
value=ast.Name(id=self.pk_import, ctx=ast.Load()),
attr=type,
ctx=ast.Load(),
)
else:
raise ValueError(f"Type inference for {type} is not supported")

Expand Down
2 changes: 2 additions & 0 deletions pykokkos/core/translators/members.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ def get_real_views(self):

views: Set[cppast.DeclRefExpr] = set()
for n, t in self.views.items():
if not t.template_params:
continue
dtype: cppast.PrimitiveType = t.template_params[0]
if isinstance(dtype, cppast.PrimitiveType):
if (
Expand Down
119 changes: 118 additions & 1 deletion pykokkos/core/translators/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from pykokkos.core import cppast
from pykokkos.core.keywords import Keywords
from pykokkos.core.optimizations import add_restrict_views
from pykokkos.core.parsers import PyKokkosEntity, PyKokkosStyles
from pykokkos.core.parsers import Parser, PyKokkosEntity, PyKokkosStyles
from pykokkos.core.type_inference.args_type_inference import _infer_type_from_value
from pykokkos.core.visitors import (
ClasstypeVisitor,
KokkosFunctionVisitor,
Expand Down Expand Up @@ -69,6 +70,12 @@ def translate(
"""

self.pk_import = entity.pk_import
# Create parser instance to reuse its methods
# For fused workunits, path is None, so we pass pk_import explicitly
if entity.path is not None:
self.parser = Parser(entity.path)
else:
self.parser = Parser(None, pk_import=entity.pk_import)

entity.AST = self.add_parent_refs(entity.AST)
for c in classtypes:
Expand Down Expand Up @@ -230,6 +237,116 @@ def translate_functions(
:returns: a list of method declarations
"""

# Infer types for Kokkos function parameters from call sites in workunits
for functiondef in self.pk_members.pk_functions.values():
inferred_types: Dict[str, str] = {}
function_name = functiondef.name
param_names = [arg.arg for arg in functiondef.args.args]

class CallSiteFinder(ast.NodeVisitor):
def __init__(self, target_name: str, pk_import: str):
self.target_name = target_name
self.pk_import = pk_import
self.call_sites: List[ast.Call] = []

def visit_Call(self, node: ast.Call):
if (
isinstance(node.func, ast.Name)
and node.func.id == self.target_name
):
self.call_sites.append(node)
elif isinstance(node.func, ast.Attribute):
if (
isinstance(node.func.value, ast.Name)
and node.func.value.id == self.pk_import
and node.func.attr == self.target_name
):
self.call_sites.append(node)
self.generic_visit(node)

finder = CallSiteFinder(function_name, self.pk_import)
for workunit in self.pk_members.pk_workunits.values():
finder.visit(workunit)

# Check return type for float preference (same logic as infer_other_args pattern)
prefer_float = False
if functiondef.returns and isinstance(functiondef.returns, ast.Attribute):
if (
isinstance(functiondef.returns.value, ast.Name)
and functiondef.returns.value.id == self.pk_import
):
if functiondef.returns.attr in ("double", "float"):
prefer_float = True

# extract values from call sites and infer types
for call_site in finder.call_sites:
for i, arg_node in enumerate(call_site.args):
if i >= len(param_names) or param_names[i] in inferred_types:
continue
param_name = param_names[i]
if isinstance(arg_node, ast.Constant):
value = arg_node.value
inferred_types[param_name] = _infer_type_from_value(
value, prefer_float
)
elif isinstance(arg_node, ast.Name):
found_type = False
for workunit in self.pk_members.pk_workunits.values():
for workunit_arg in workunit.args.args:
if (
workunit_arg.arg == arg_node.id
and workunit_arg.annotation
):
ann = workunit_arg.annotation
if isinstance(ann, ast.Name):
inferred_types[param_name] = ann.id
found_type = True
elif (
isinstance(ann, ast.Attribute)
and isinstance(ann.value, ast.Name)
and ann.value.id == self.pk_import
):
inferred_types[param_name] = ann.attr
found_type = True
break
if found_type:
break

if not found_type:
for other_function in self.pk_members.pk_functions.values():
for func_arg in other_function.args.args:
if (
func_arg.arg == arg_node.id
and func_arg.annotation
):
ann = func_arg.annotation
if isinstance(ann, ast.Name):
inferred_types[param_name] = ann.id
found_type = True
elif (
isinstance(ann, ast.Attribute)
and isinstance(ann.value, ast.Name)
and ann.value.id == self.pk_import
):
inferred_types[param_name] = ann.attr
found_type = True
break
if found_type:
break

# if parameter name is a common thread ID name, default to int
if not found_type and param_name in (
"tid",
"i",
"j",
"k",
"idx",
"index",
):
inferred_types[param_name] = "int"
if inferred_types:
self.parser.fix_function_types(functiondef, inferred_types)

# The visitor might add views declared as parameters
views = copy.deepcopy(self.pk_members.views)

Expand Down
Loading
Loading