diff --git a/pykokkos/core/module_setup.py b/pykokkos/core/module_setup.py index fb70e690..b05aabb5 100644 --- a/pykokkos/core/module_setup.py +++ b/pykokkos/core/module_setup.py @@ -116,7 +116,15 @@ def __init__( # The path to the main file if using the console self.console_main: str = "pk_console" - self.main: Path = self.get_main_path() + if self.metadata and self.metadata[0].path: + entity_path = Path(self.metadata[0].path) + if entity_path.suffix == ".py": + self.main = entity_path.with_suffix("") + else: + self.main = entity_path + else: + self.main: Path = self.get_main_path() + self.output_dir: Optional[Path] = self.get_output_dir( self.main, self.metadata, space, types_signature, self.restrict_signature ) @@ -127,7 +135,12 @@ def __init__( ] if self.output_dir is not None: - self.path: str = os.path.join(self.output_dir, self.module_file) + output_dir_abs = ( + Path(self.output_dir).resolve() + if not Path(self.output_dir).is_absolute() + else Path(self.output_dir) + ) + self.path: str = str(output_dir_abs / self.module_file) if km.is_multi_gpu_enabled(): self.gpu_module_paths: str = [ os.path.join(self.output_dir, module_file) @@ -181,6 +194,7 @@ def get_entity_dir(self, main: Path, metadata: List[EntityMetadata]) -> Path: :returns: the path to the base output directory """ + base_dir = self.get_main_dir(main) entity_dir: str = "" for m in metadata[:5]: @@ -195,7 +209,7 @@ def get_entity_dir(self, main: Path, metadata: List[EntityMetadata]) -> Path: if remaining != "": entity_dir += hashlib.md5(("".join(remaining)).encode()).hexdigest() - return self.get_main_dir(main) / Path(entity_dir) + return base_dir / Path(entity_dir) @staticmethod def get_main_dir(main: Path) -> Path: @@ -206,13 +220,19 @@ def get_main_dir(main: Path) -> Path: :returns: the path to the main directory """ - # If the parent directory is root, remove it so we can - # concatenate it to pk_cpp - main_path: Path = main - if str(main).startswith("/"): - main_path = Path(str(main)[1:]) + # convert to absolute path and make it relative to CWD + main_abs: Path = ( + main.resolve() if main.is_absolute() else (Path.cwd() / main).resolve() + ) + try: + main_rel: Path = main_abs.relative_to(Path.cwd()) + except ValueError: + # main_abs is not under cwd - fall back to old behavior + main_rel = ( + Path(str(main_abs)[1:]) if str(main_abs).startswith("/") else main_abs + ) - return Path(BASE_DIR) / main_path + return Path(BASE_DIR) / main_rel def get_main_path(self) -> Path: """ diff --git a/pykokkos/core/parsers/parser.py b/pykokkos/core/parsers/parser.py index 6a971dc9..b762c7b7 100644 --- a/pykokkos/core/parsers/parser.py +++ b/pykokkos/core/parsers/parser.py @@ -40,30 +40,44 @@ class Parser: Parse a PyKokkos workload and its dependencies """ - def __init__(self, path: str): + def __init__(self, path: Optional[str], pk_import: Optional[str] = None): """ Parse the file and find all entities - :param path: the path to the file + :param path: the path to the file (None for fused workunits) + :param pk_import: the pykokkos import identifier (required when path is None) """ self.lines: List[str] self.tree: ast.Module - with open(path, "r") as f: - self.lines = f.readlines() - self.tree = ast.parse("".join(self.lines)) - - self.path: str = path - self.pk_import: str = self.get_import() - self.workloads: Dict[str, PyKokkosEntity] = {} - self.classtypes: Dict[str, PyKokkosEntity] = {} - self.functors: Dict[str, PyKokkosEntity] = {} - self.workunits: Dict[str, PyKokkosEntity] = {} - - self.workloads = self.get_entities(PyKokkosStyles.workload) - self.classtypes = self.get_entities(PyKokkosStyles.classtype) - self.functors = self.get_entities(PyKokkosStyles.functor) - self.workunits = self.get_entities(PyKokkosStyles.workunit) + if path is not None: + with open(path, "r") as f: + self.lines = f.readlines() + self.tree = ast.parse("".join(self.lines)) + self.path: Optional[str] = path + self.pk_import: str = self.get_import() + self.workloads: Dict[str, PyKokkosEntity] = {} + self.classtypes: Dict[str, PyKokkosEntity] = {} + self.functors: Dict[str, PyKokkosEntity] = {} + self.workunits: Dict[str, PyKokkosEntity] = {} + + self.workloads = self.get_entities(PyKokkosStyles.workload) + self.classtypes = self.get_entities(PyKokkosStyles.classtype) + self.functors = self.get_entities(PyKokkosStyles.functor) + self.workunits = self.get_entities(PyKokkosStyles.workunit) + else: + # For fused workunits, we don't have a file to parse + # but we still need a parser instance for helper methods + if pk_import is None: + raise ValueError("pk_import must be provided when path is None") + self.lines = [] + self.tree = ast.Module(body=[]) + self.path = None + self.pk_import = pk_import + self.workloads: Dict[str, PyKokkosEntity] = {} + self.classtypes: Dict[str, PyKokkosEntity] = {} + self.functors: Dict[str, PyKokkosEntity] = {} + self.workunits: Dict[str, PyKokkosEntity] = {} def get_import(self) -> str: """ @@ -151,6 +165,21 @@ def get_entities(self, style: PyKokkosStyles) -> Dict[str, PyKokkosEntity]: return entities + def _apply_inferred_types_to_args( + self, args: List[ast.arg], inferred_types: Dict[str, str] + ) -> None: + """ + Helper method to apply inferred types to function arguments. + Used by both fix_types and fix_function_types. + + :param args: List of argument AST nodes + :param inferred_types: Dictionary mapping parameter names to type strings + """ + for arg in args: + if arg.annotation is None and arg.arg in inferred_types: + type_str = inferred_types[arg.arg] + arg.annotation = self.get_annotation_node(type_str) + def fix_types(self, entity: PyKokkosEntity, updated_types: UpdatedTypes) -> ast.AST: """ Inject (into the entity AST) the missing annotations for datatypes that have been inferred. @@ -172,17 +201,28 @@ def fix_types(self, entity: PyKokkosEntity, updated_types: UpdatedTypes) -> ast. if needs_reset: entity_tree = self.reset_entity_tree(entity_tree, updated_types) - for arg_obj in entity_tree.args.args: - # Type already provided by the user - if arg_obj.arg not in updated_types.inferred_types: - continue - - update_type = updated_types.inferred_types[arg_obj.arg] - arg_obj.annotation = self.get_annotation_node(update_type) + # Reuse the shared logic + self._apply_inferred_types_to_args( + entity_tree.args.args, updated_types.inferred_types + ) assert entity_tree is not None return entity_tree + def fix_function_types( + self, function: ast.FunctionDef, inferred_types: Dict[str, str] + ) -> ast.FunctionDef: + """ + Apply inferred types to a Kokkos function AST. + Reuses the same logic as fix_types. + + :param function: The function AST to modify + :param inferred_types: Dictionary mapping parameter names to type strings + :returns: The modified function AST + """ + self._apply_inferred_types_to_args(function.args.args, inferred_types) + return function + def check_self(self, entity_tree: ast.AST) -> bool: """ Check if self args exists in the AST, which implies this AST was already @@ -295,6 +335,12 @@ def get_annotation_node(self, type: str) -> ast.AST: attr="TeamMember", ctx=ast.Load(), ) + elif type in ("double", "float"): + annotation_node = ast.Attribute( + value=ast.Name(id=self.pk_import, ctx=ast.Load()), + attr=type, + ctx=ast.Load(), + ) else: raise ValueError(f"Type inference for {type} is not supported") diff --git a/pykokkos/core/translators/members.py b/pykokkos/core/translators/members.py index dd6b08ac..d21138ef 100644 --- a/pykokkos/core/translators/members.py +++ b/pykokkos/core/translators/members.py @@ -212,6 +212,8 @@ def get_real_views(self): views: Set[cppast.DeclRefExpr] = set() for n, t in self.views.items(): + if not t.template_params: + continue dtype: cppast.PrimitiveType = t.template_params[0] if isinstance(dtype, cppast.PrimitiveType): if ( diff --git a/pykokkos/core/translators/static.py b/pykokkos/core/translators/static.py index 7711691c..dced1afe 100644 --- a/pykokkos/core/translators/static.py +++ b/pykokkos/core/translators/static.py @@ -2,12 +2,13 @@ import copy import os import sys -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple, Union from pykokkos.core import cppast from pykokkos.core.keywords import Keywords from pykokkos.core.optimizations import add_restrict_views -from pykokkos.core.parsers import PyKokkosEntity, PyKokkosStyles +from pykokkos.core.parsers import Parser, PyKokkosEntity, PyKokkosStyles +from pykokkos.core.type_inference.args_type_inference import _infer_type_from_value from pykokkos.core.visitors import ( ClasstypeVisitor, KokkosFunctionVisitor, @@ -69,6 +70,12 @@ def translate( """ self.pk_import = entity.pk_import + # Create parser instance to reuse its methods + # For fused workunits, path is None, so we pass pk_import explicitly + if entity.path is not None: + self.parser = Parser(entity.path) + else: + self.parser = Parser(None, pk_import=entity.pk_import) entity.AST = self.add_parent_refs(entity.AST) for c in classtypes: @@ -219,6 +226,137 @@ def translate_classtypes( return declarations + definitions + def _extract_type_from_annotation( + self, annotation: Union[ast.Name, ast.Attribute] + ) -> Optional[str]: + """Extract type string from an annotation AST node.""" + if isinstance(annotation, ast.Name): + return annotation.id + if ( + isinstance(annotation, ast.Attribute) + and isinstance(annotation.value, ast.Name) + and annotation.value.id == self.pk_import + ): + return annotation.attr + return None + + def _find_type_in_workunits(self, arg_name: str) -> Optional[str]: + """Find type of an argument by searching workunit parameters.""" + for workunit in self.pk_members.pk_workunits.values(): + for workunit_arg in workunit.args.args: + if workunit_arg.arg == arg_name and workunit_arg.annotation: + return self._extract_type_from_annotation(workunit_arg.annotation) + return None + + def _find_type_in_functions(self, arg_name: str) -> Optional[str]: + """Find type of an argument by searching other function parameters.""" + for function in self.pk_members.pk_functions.values(): + for func_arg in function.args.args: + if func_arg.arg == arg_name and func_arg.annotation: + return self._extract_type_from_annotation(func_arg.annotation) + return None + + def _infer_type_from_argument(self, arg_node: ast.expr) -> Optional[str]: + """Infer type from a single argument node.""" + if isinstance(arg_node, ast.Constant): + return _infer_type_from_value(arg_node.value) + + if isinstance(arg_node, ast.Name): + type_str = self._find_type_in_workunits(arg_node.id) + if type_str: + return type_str + return self._find_type_in_functions(arg_node.id) + + return None + + def _infer_function_parameter_types( + self, functiondef: ast.FunctionDef, call_sites: List[ast.Call] + ) -> Dict[str, str]: + """Infer types for function parameters from call sites.""" + inferred_types: Dict[str, str] = {} + all_params = functiondef.args.args + + for call_site in call_sites: + for arg_idx, arg_node in enumerate(call_site.args): + # Map call argument index to parameter index + param_idx = arg_idx + if all_params and all_params[0].arg == "self": + param_idx = arg_idx + 1 + + if param_idx >= len(all_params): + continue + param = all_params[param_idx] + + # already infered + if param.annotation is not None: + continue + + inferred_type = self._infer_type_from_argument(arg_node) + if inferred_type: + inferred_types[param.arg] = inferred_type + + return inferred_types + + def _find_all_call_sites(self, function_name: str) -> List[ast.Call]: + """Find all call sites for a given function.""" + + class CallSiteFinder(ast.NodeVisitor): + def __init__(self, target_name: str, pk_import: str): + self.target_name = target_name + self.pk_import = pk_import + self.call_sites: List[ast.Call] = [] + + def visit_Call(self, node: ast.Call): + # Direct function call: function_name() + if isinstance(node.func, ast.Name) and node.func.id == self.target_name: + self.call_sites.append(node) + # Method call: self.function_name() or pk.function_name() + elif ( + isinstance(node.func, ast.Attribute) + and node.func.attr == self.target_name + ): + if isinstance(node.func.value, ast.Name): + # Accept both self.function_name and pk.function_name + if node.func.value.id in ("self", self.pk_import): + self.call_sites.append(node) + + self.generic_visit(node) + + finder = CallSiteFinder(function_name, self.pk_import) + + # Search in workunits + for workunit in self.pk_members.pk_workunits.values(): + finder.visit(workunit) + + # Search in other functions for nested calls + for other_func in self.pk_members.pk_functions.values(): + finder.visit(other_func) + + return finder.call_sites + + def _infer_function_types_iteratively(self) -> None: + """Iteratively infer types for all functions until convergence.""" + max_iterations = 20 + + for _ in range(max_iterations): + types_changed = False + + for functiondef in self.pk_members.pk_functions.values(): + call_sites = self._find_all_call_sites(functiondef.name) + inferred_types = self._infer_function_parameter_types( + functiondef, call_sites + ) + + if inferred_types: + self.parser.fix_function_types(functiondef, inferred_types) + types_changed = True + if not types_changed: + break + else: + print( + f"Warning: Type inference did not converge after {max_iterations} iterations" + ) + def translate_functions( self, source: Tuple[List[str], int], restrict_views: Set[str] ) -> List[cppast.MethodDecl]: @@ -230,6 +368,8 @@ def translate_functions( :returns: a list of method declarations """ + self._infer_function_types_iteratively() + # The visitor might add views declared as parameters views = copy.deepcopy(self.pk_members.views) diff --git a/pykokkos/core/type_inference/args_type_inference.py b/pykokkos/core/type_inference/args_type_inference.py index f64710a1..7abd06fb 100644 --- a/pykokkos/core/type_inference/args_type_inference.py +++ b/pykokkos/core/type_inference/args_type_inference.py @@ -56,10 +56,53 @@ class UpdatedDecorator: "float32", ] -# Cache for original argument nodes: Maps stringified workunit reference, e.g str(workunit_name), to the original ast.arguments node +# Supported array libraries for type inference +SUPPORTED_ARRAY_LIBRARIES = ("numpy", "cupy", "torch", "jax", "jaxlib") + ORIGINAL_PARAMS: Dict[str, ast.arguments] = {} +def _infer_type_from_value(value) -> str: + """ + Infer the type string from a Python value, reusing the same logic as infer_other_args. + Uses DataType enum and SUPPORTED_NP_DTYPES for consistency. + + :param value: The Python value to infer type from + :returns: Type string in the format used by type inference (e.g., "int", "double", "numpy:int64") + """ + param_type = type(value).__name__ + + if param_type == "int": + return "int" + elif param_type == "float": + return DataType.double.name + elif param_type == "bool": + return DataType.bool.name + else: + # Handle array library scalar types (numpy, cupy, torch, etc.) + pckg_name = type(value).__module__ + + if any(pckg_name.startswith(pkg) for pkg in SUPPORTED_ARRAY_LIBRARIES): + if param_type not in SUPPORTED_NP_DTYPES: + raise TypeError( + f"Array type {param_type} from {pckg_name} is unsupported" + ) + + if param_type == DataType.float64.name or param_type == "float64": + param_type = DataType.double.name + elif param_type == DataType.float32.name or param_type == "float32": + param_type = DataType.float.name + + return f"numpy:{param_type}" + else: + supported_libs = ", ".join(SUPPORTED_ARRAY_LIBRARIES) + raise TypeError( + f"Unsupported type for type inference: {type(value)} from module {pckg_name}. " + f"Only Python primitives (int, float, bool) and array library types " + f"({supported_libs}) are supported." + ) + + def check_missing_annotations(param_list: List[ast.arg]) -> bool: """ Check if any annotation node for parent argument node is none @@ -280,33 +323,14 @@ def infer_other_args( if param.annotation is not None: continue - param_type = type(value).__name__ - - # switch integer values over 31 bits (signed positive value) to numpy:int64 - if param_type == "int" and value.bit_length() > 31: - param_type = "numpy:int64" - - # check if package name is numpy (handling numpy primitives) - pckg_name = type(value).__module__ - - if pckg_name == "numpy": - if param_type not in SUPPORTED_NP_DTYPES: - err_str = f"Numpy type {param_type} is unsupported" - raise TypeError(err_str) - - if param_type == "float64": - param_type = "double" - if param_type == "float32": - param_type = "float" - # numpy:, Will switch to pk. in parser.fix_types - param_type = pckg_name + ":" + param_type - if isinstance(value, ViewType): view_dtype = get_pk_datatype(value.dtype) if not view_dtype: raise TypeError("Cannot infer datatype for view:", param.arg) param_type = "View" + str(len(value.shape)) + "D:" + view_dtype + else: + param_type = _infer_type_from_value(value) updated_types.inferred_types[param.arg] = param_type diff --git a/pykokkos/core/visitors/kokkosfunction_visitor.py b/pykokkos/core/visitors/kokkosfunction_visitor.py index 4ada9c6d..e8f3b649 100644 --- a/pykokkos/core/visitors/kokkosfunction_visitor.py +++ b/pykokkos/core/visitors/kokkosfunction_visitor.py @@ -15,8 +15,7 @@ class KokkosFunctionVisitor(PyKokkosVisitor): def visit_FunctionDef(self, node: ast.FunctionDef) -> cppast.MethodDecl: - if not self.is_valid_kokkos_function(node): - self.error(node, "Invalid Kokkos function") + self.is_valid_kokkos_function(node) return_type: cppast.ClassType if self.is_void_function(node): @@ -114,10 +113,7 @@ def visit_arguments(self, node: ast.arguments) -> None: def is_valid_kokkos_function(self, node) -> bool: # Is the return type annotation missing if node.returns is None: - return False - - # Is the type annotation for any argument missing (excluding self) - if any(arg.annotation is None and arg.arg != "self" for arg in node.args.args): - return False - - return True + self.error( + node.returns, + f"Return type annotation missing in function `{node.name}`.", + ) diff --git a/pykokkos/core/visitors/visitors_util.py b/pykokkos/core/visitors/visitors_util.py index 03a8bc5c..ecb10bcc 100644 --- a/pykokkos/core/visitors/visitors_util.py +++ b/pykokkos/core/visitors/visitors_util.py @@ -36,7 +36,9 @@ def pretty_print(node): "uint32": cppast.BuiltinType.UINT32, "uint64": cppast.BuiltinType.UINT64, "float": cppast.BuiltinType.FLOAT, + "float32": cppast.BuiltinType.FLOAT, # Alias for float "double": cppast.BuiltinType.DOUBLE, + "float64": cppast.BuiltinType.DOUBLE, # Alias for double "int": cppast.BuiltinType.INT32, "real": Keywords.RealPrecision.value, } @@ -136,7 +138,7 @@ def error(src, debug: bool, node, message) -> None: else: print(f"\n\033[31m\033[01mError\033[0m: {message}") - if debug: + if debug and node is not None: print("DEBUG AST:") pretty_print(node) @@ -297,8 +299,28 @@ def parse_view_template_params( py_type: str = view_type.typename is_scratch_view: bool = py_type.startswith("ScratchView") + # Check if this is actually a view type (starts with "View" or "ScratchView") + # If not, this might be a dtype that was incorrectly passed as a view type + if not (py_type.startswith("View") or py_type.startswith("ScratchView")): + raise ValueError( + f"Expected a view type (e.g., 'View1D', 'View2D', 'ScratchView1D'), " + f"but got '{py_type}'. This might be a dtype that was incorrectly treated as a view type." + ) + if rank is None: - rank = int(re.search(r"\d+", py_type).group()) + # Match the rank number that comes after "View" or "ScratchView" and before "D" + # This prevents matching numbers from dtype names like "float32" or "float64" + match = re.search(r"(?:View|ScratchView)(\d+)D", py_type) + if match: + rank = int(match.group(1)) + else: + # If pattern doesn't match, this is likely not a valid view type name + # or the typename format is unexpected - raise an error instead of + # using a fallback that could match wrong numbers from dtype names + raise ValueError( + f"Could not extract view rank from typename '{py_type}'. " + f"Expected format: 'ViewD' or 'ScratchViewD' (e.g., 'View1D', 'View2D')" + ) if not 0 < rank < 8: raise ValueError(f"View rank {rank} is not allowed") diff --git a/tests/test_kokkosfunctions_translator.py b/tests/test_kokkosfunctions_translator.py index 5108e6f4..fbe9d849 100644 --- a/tests/test_kokkosfunctions_translator.py +++ b/tests/test_kokkosfunctions_translator.py @@ -97,15 +97,15 @@ def views(self, tid: int, acc: pk.Acc[pk.double]) -> None: acc += self.use_views(tid) @pk.function - def nested_views_1(self, tid: int) -> int: + def nested_views_1(self, tid) -> int: return self.view1D[tid] @pk.function - def nested_views_2(self, tid: int) -> int: + def nested_views_2(self, tid) -> int: return self.nested_views_1(tid) + self.view2D[tid][0] @pk.function - def nested_views_3(self, tid: int) -> int: + def nested_views_3(self, tid) -> int: return self.nested_views_2(tid) + self.view3D[tid][0][0] @pk.workunit