Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions pykokkos/core/module_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,15 @@ def __init__(
# The path to the main file if using the console
self.console_main: str = "pk_console"

self.main: Path = self.get_main_path()
if self.metadata and self.metadata[0].path:
entity_path = Path(self.metadata[0].path)
if entity_path.suffix == ".py":
self.main = entity_path.with_suffix("")
else:
self.main = entity_path
else:
self.main: Path = self.get_main_path()

self.output_dir: Optional[Path] = self.get_output_dir(
self.main, self.metadata, space, types_signature, self.restrict_signature
)
Expand All @@ -127,7 +135,12 @@ def __init__(
]

if self.output_dir is not None:
self.path: str = os.path.join(self.output_dir, self.module_file)
output_dir_abs = (
Path(self.output_dir).resolve()
if not Path(self.output_dir).is_absolute()
else Path(self.output_dir)
)
self.path: str = str(output_dir_abs / self.module_file)
if km.is_multi_gpu_enabled():
self.gpu_module_paths: str = [
os.path.join(self.output_dir, module_file)
Expand Down Expand Up @@ -181,6 +194,7 @@ def get_entity_dir(self, main: Path, metadata: List[EntityMetadata]) -> Path:
:returns: the path to the base output directory
"""

base_dir = self.get_main_dir(main)
entity_dir: str = ""

for m in metadata[:5]:
Expand All @@ -195,7 +209,7 @@ def get_entity_dir(self, main: Path, metadata: List[EntityMetadata]) -> Path:
if remaining != "":
entity_dir += hashlib.md5(("".join(remaining)).encode()).hexdigest()

return self.get_main_dir(main) / Path(entity_dir)
return base_dir / Path(entity_dir)

@staticmethod
def get_main_dir(main: Path) -> Path:
Expand All @@ -206,13 +220,19 @@ def get_main_dir(main: Path) -> Path:
:returns: the path to the main directory
"""

# If the parent directory is root, remove it so we can
# concatenate it to pk_cpp
main_path: Path = main
if str(main).startswith("/"):
main_path = Path(str(main)[1:])
# convert to absolute path and make it relative to CWD
main_abs: Path = (
main.resolve() if main.is_absolute() else (Path.cwd() / main).resolve()
)
try:
main_rel: Path = main_abs.relative_to(Path.cwd())
except ValueError:
# main_abs is not under cwd - fall back to old behavior
main_rel = (
Path(str(main_abs)[1:]) if str(main_abs).startswith("/") else main_abs
)

return Path(BASE_DIR) / main_path
return Path(BASE_DIR) / main_rel

def get_main_path(self) -> Path:
"""
Expand Down
94 changes: 70 additions & 24 deletions pykokkos/core/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,30 +40,44 @@ class Parser:
Parse a PyKokkos workload and its dependencies
"""

def __init__(self, path: str):
def __init__(self, path: Optional[str], pk_import: Optional[str] = None):
"""
Parse the file and find all entities

:param path: the path to the file
:param path: the path to the file (None for fused workunits)
:param pk_import: the pykokkos import identifier (required when path is None)
"""

self.lines: List[str]
self.tree: ast.Module
with open(path, "r") as f:
self.lines = f.readlines()
self.tree = ast.parse("".join(self.lines))

self.path: str = path
self.pk_import: str = self.get_import()
self.workloads: Dict[str, PyKokkosEntity] = {}
self.classtypes: Dict[str, PyKokkosEntity] = {}
self.functors: Dict[str, PyKokkosEntity] = {}
self.workunits: Dict[str, PyKokkosEntity] = {}

self.workloads = self.get_entities(PyKokkosStyles.workload)
self.classtypes = self.get_entities(PyKokkosStyles.classtype)
self.functors = self.get_entities(PyKokkosStyles.functor)
self.workunits = self.get_entities(PyKokkosStyles.workunit)
if path is not None:
with open(path, "r") as f:
self.lines = f.readlines()
self.tree = ast.parse("".join(self.lines))
self.path: Optional[str] = path
self.pk_import: str = self.get_import()
self.workloads: Dict[str, PyKokkosEntity] = {}
self.classtypes: Dict[str, PyKokkosEntity] = {}
self.functors: Dict[str, PyKokkosEntity] = {}
self.workunits: Dict[str, PyKokkosEntity] = {}

self.workloads = self.get_entities(PyKokkosStyles.workload)
self.classtypes = self.get_entities(PyKokkosStyles.classtype)
self.functors = self.get_entities(PyKokkosStyles.functor)
self.workunits = self.get_entities(PyKokkosStyles.workunit)
else:
# For fused workunits, we don't have a file to parse
# but we still need a parser instance for helper methods
if pk_import is None:
raise ValueError("pk_import must be provided when path is None")
self.lines = []
self.tree = ast.Module(body=[])
self.path = None
self.pk_import = pk_import
self.workloads: Dict[str, PyKokkosEntity] = {}
self.classtypes: Dict[str, PyKokkosEntity] = {}
self.functors: Dict[str, PyKokkosEntity] = {}
self.workunits: Dict[str, PyKokkosEntity] = {}

def get_import(self) -> str:
"""
Expand Down Expand Up @@ -151,6 +165,21 @@ def get_entities(self, style: PyKokkosStyles) -> Dict[str, PyKokkosEntity]:

return entities

def _apply_inferred_types_to_args(
self, args: List[ast.arg], inferred_types: Dict[str, str]
) -> None:
"""
Helper method to apply inferred types to function arguments.
Used by both fix_types and fix_function_types.

:param args: List of argument AST nodes
:param inferred_types: Dictionary mapping parameter names to type strings
"""
for arg in args:
if arg.annotation is None and arg.arg in inferred_types:
type_str = inferred_types[arg.arg]
arg.annotation = self.get_annotation_node(type_str)

def fix_types(self, entity: PyKokkosEntity, updated_types: UpdatedTypes) -> ast.AST:
"""
Inject (into the entity AST) the missing annotations for datatypes that have been inferred.
Expand All @@ -172,17 +201,28 @@ def fix_types(self, entity: PyKokkosEntity, updated_types: UpdatedTypes) -> ast.
if needs_reset:
entity_tree = self.reset_entity_tree(entity_tree, updated_types)

for arg_obj in entity_tree.args.args:
# Type already provided by the user
if arg_obj.arg not in updated_types.inferred_types:
continue

update_type = updated_types.inferred_types[arg_obj.arg]
arg_obj.annotation = self.get_annotation_node(update_type)
# Reuse the shared logic
self._apply_inferred_types_to_args(
entity_tree.args.args, updated_types.inferred_types
)

assert entity_tree is not None
return entity_tree

def fix_function_types(
self, function: ast.FunctionDef, inferred_types: Dict[str, str]
) -> ast.FunctionDef:
"""
Apply inferred types to a Kokkos function AST.
Reuses the same logic as fix_types.

:param function: The function AST to modify
:param inferred_types: Dictionary mapping parameter names to type strings
:returns: The modified function AST
"""
self._apply_inferred_types_to_args(function.args.args, inferred_types)
return function

def check_self(self, entity_tree: ast.AST) -> bool:
"""
Check if self args exists in the AST, which implies this AST was already
Expand Down Expand Up @@ -295,6 +335,12 @@ def get_annotation_node(self, type: str) -> ast.AST:
attr="TeamMember",
ctx=ast.Load(),
)
elif type in ("double", "float"):
annotation_node = ast.Attribute(
value=ast.Name(id=self.pk_import, ctx=ast.Load()),
attr=type,
ctx=ast.Load(),
)
else:
raise ValueError(f"Type inference for {type} is not supported")

Expand Down
2 changes: 2 additions & 0 deletions pykokkos/core/translators/members.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ def get_real_views(self):

views: Set[cppast.DeclRefExpr] = set()
for n, t in self.views.items():
if not t.template_params:
continue
dtype: cppast.PrimitiveType = t.template_params[0]
if isinstance(dtype, cppast.PrimitiveType):
if (
Expand Down
144 changes: 142 additions & 2 deletions pykokkos/core/translators/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import copy
import os
import sys
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple, Union

from pykokkos.core import cppast
from pykokkos.core.keywords import Keywords
from pykokkos.core.optimizations import add_restrict_views
from pykokkos.core.parsers import PyKokkosEntity, PyKokkosStyles
from pykokkos.core.parsers import Parser, PyKokkosEntity, PyKokkosStyles
from pykokkos.core.type_inference.args_type_inference import _infer_type_from_value
from pykokkos.core.visitors import (
ClasstypeVisitor,
KokkosFunctionVisitor,
Expand Down Expand Up @@ -69,6 +70,12 @@ def translate(
"""

self.pk_import = entity.pk_import
# Create parser instance to reuse its methods
# For fused workunits, path is None, so we pass pk_import explicitly
if entity.path is not None:
self.parser = Parser(entity.path)
else:
self.parser = Parser(None, pk_import=entity.pk_import)

entity.AST = self.add_parent_refs(entity.AST)
for c in classtypes:
Expand Down Expand Up @@ -219,6 +226,137 @@ def translate_classtypes(

return declarations + definitions

def _extract_type_from_annotation(
self, annotation: Union[ast.Name, ast.Attribute]
) -> Optional[str]:
"""Extract type string from an annotation AST node."""
if isinstance(annotation, ast.Name):
return annotation.id
if (
isinstance(annotation, ast.Attribute)
and isinstance(annotation.value, ast.Name)
and annotation.value.id == self.pk_import
):
return annotation.attr
return None

def _find_type_in_workunits(self, arg_name: str) -> Optional[str]:
"""Find type of an argument by searching workunit parameters."""
for workunit in self.pk_members.pk_workunits.values():
for workunit_arg in workunit.args.args:
if workunit_arg.arg == arg_name and workunit_arg.annotation:
return self._extract_type_from_annotation(workunit_arg.annotation)
return None

def _find_type_in_functions(self, arg_name: str) -> Optional[str]:
"""Find type of an argument by searching other function parameters."""
for function in self.pk_members.pk_functions.values():
for func_arg in function.args.args:
if func_arg.arg == arg_name and func_arg.annotation:
return self._extract_type_from_annotation(func_arg.annotation)
return None

def _infer_type_from_argument(self, arg_node: ast.expr) -> Optional[str]:
"""Infer type from a single argument node."""
if isinstance(arg_node, ast.Constant):
return _infer_type_from_value(arg_node.value)

if isinstance(arg_node, ast.Name):
type_str = self._find_type_in_workunits(arg_node.id)
if type_str:
return type_str
return self._find_type_in_functions(arg_node.id)

return None

def _infer_function_parameter_types(
self, functiondef: ast.FunctionDef, call_sites: List[ast.Call]
) -> Dict[str, str]:
"""Infer types for function parameters from call sites."""
inferred_types: Dict[str, str] = {}
all_params = functiondef.args.args

for call_site in call_sites:
for arg_idx, arg_node in enumerate(call_site.args):
# Map call argument index to parameter index
param_idx = arg_idx
if all_params and all_params[0].arg == "self":
param_idx = arg_idx + 1

if param_idx >= len(all_params):
continue
param = all_params[param_idx]

# already infered
if param.annotation is not None:
continue

inferred_type = self._infer_type_from_argument(arg_node)
if inferred_type:
inferred_types[param.arg] = inferred_type

return inferred_types

def _find_all_call_sites(self, function_name: str) -> List[ast.Call]:
"""Find all call sites for a given function."""

class CallSiteFinder(ast.NodeVisitor):
def __init__(self, target_name: str, pk_import: str):
self.target_name = target_name
self.pk_import = pk_import
self.call_sites: List[ast.Call] = []

def visit_Call(self, node: ast.Call):
# Direct function call: function_name()
if isinstance(node.func, ast.Name) and node.func.id == self.target_name:
self.call_sites.append(node)
# Method call: self.function_name() or pk.function_name()
elif (
isinstance(node.func, ast.Attribute)
and node.func.attr == self.target_name
):
if isinstance(node.func.value, ast.Name):
# Accept both self.function_name and pk.function_name
if node.func.value.id in ("self", self.pk_import):
self.call_sites.append(node)

self.generic_visit(node)

finder = CallSiteFinder(function_name, self.pk_import)

# Search in workunits
for workunit in self.pk_members.pk_workunits.values():
finder.visit(workunit)

# Search in other functions for nested calls
for other_func in self.pk_members.pk_functions.values():
finder.visit(other_func)

return finder.call_sites

def _infer_function_types_iteratively(self) -> None:
"""Iteratively infer types for all functions until convergence."""
max_iterations = 20

for _ in range(max_iterations):
types_changed = False

for functiondef in self.pk_members.pk_functions.values():
call_sites = self._find_all_call_sites(functiondef.name)
inferred_types = self._infer_function_parameter_types(
functiondef, call_sites
)

if inferred_types:
self.parser.fix_function_types(functiondef, inferred_types)
types_changed = True
if not types_changed:
break
else:
print(
f"Warning: Type inference did not converge after {max_iterations} iterations"
)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have some sort of failure warning here if the iterations max out and there are still uninferred types?

def translate_functions(
self, source: Tuple[List[str], int], restrict_views: Set[str]
) -> List[cppast.MethodDecl]:
Expand All @@ -230,6 +368,8 @@ def translate_functions(
:returns: a list of method declarations
"""

self._infer_function_types_iteratively()

# The visitor might add views declared as parameters
views = copy.deepcopy(self.pk_members.views)

Expand Down
Loading
Loading