Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions pykokkos/core/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,21 @@ def get_entities(self, style: PyKokkosStyles) -> Dict[str, PyKokkosEntity]:

return entities

def _apply_inferred_types_to_args(
self, args: List[ast.arg], inferred_types: Dict[str, str]
) -> None:
"""
Helper method to apply inferred types to function arguments.
Used by both fix_types and fix_function_types.

:param args: List of argument AST nodes
:param inferred_types: Dictionary mapping parameter names to type strings
"""
for arg in args:
if arg.annotation is None and arg.arg in inferred_types:
type_str = inferred_types[arg.arg]
arg.annotation = self.get_annotation_node(type_str)

def fix_types(self, entity: PyKokkosEntity, updated_types: UpdatedTypes) -> ast.AST:
"""
Inject (into the entity AST) the missing annotations for datatypes that have been inferred.
Expand All @@ -172,17 +187,28 @@ def fix_types(self, entity: PyKokkosEntity, updated_types: UpdatedTypes) -> ast.
if needs_reset:
entity_tree = self.reset_entity_tree(entity_tree, updated_types)

for arg_obj in entity_tree.args.args:
# Type already provided by the user
if arg_obj.arg not in updated_types.inferred_types:
continue

update_type = updated_types.inferred_types[arg_obj.arg]
arg_obj.annotation = self.get_annotation_node(update_type)
# Reuse the shared logic
self._apply_inferred_types_to_args(
entity_tree.args.args, updated_types.inferred_types
)

assert entity_tree is not None
return entity_tree

def fix_function_types(
self, function: ast.FunctionDef, inferred_types: Dict[str, str]
) -> ast.FunctionDef:
"""
Apply inferred types to a Kokkos function AST.
Reuses the same logic as fix_types.

:param function: The function AST to modify
:param inferred_types: Dictionary mapping parameter names to type strings
:returns: The modified function AST
"""
self._apply_inferred_types_to_args(function.args.args, inferred_types)
return function

def check_self(self, entity_tree: ast.AST) -> bool:
"""
Check if self args exists in the AST, which implies this AST was already
Expand Down Expand Up @@ -295,6 +321,12 @@ def get_annotation_node(self, type: str) -> ast.AST:
attr="TeamMember",
ctx=ast.Load(),
)
elif type in ("double", "float"):
annotation_node = ast.Attribute(
value=ast.Name(id=self.pk_import, ctx=ast.Load()),
attr=type,
ctx=ast.Load(),
)
else:
raise ValueError(f"Type inference for {type} is not supported")

Expand Down
77 changes: 76 additions & 1 deletion pykokkos/core/translators/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from pykokkos.core import cppast
from pykokkos.core.keywords import Keywords
from pykokkos.core.optimizations import add_restrict_views
from pykokkos.core.parsers import PyKokkosEntity, PyKokkosStyles
from pykokkos.core.parsers import Parser, PyKokkosEntity, PyKokkosStyles
from pykokkos.core.type_inference.args_type_inference import _infer_type_from_value
from pykokkos.core.visitors import (
ClasstypeVisitor,
KokkosFunctionVisitor,
Expand Down Expand Up @@ -69,6 +70,8 @@ def translate(
"""

self.pk_import = entity.pk_import
# Create parser instance to reuse its methods
self.parser = Parser(entity.path)

entity.AST = self.add_parent_refs(entity.AST)
for c in classtypes:
Expand Down Expand Up @@ -230,6 +233,78 @@ def translate_functions(
:returns: a list of method declarations
"""

# Infer types for Kokkos function parameters from call sites in workunits
for functiondef in self.pk_members.pk_functions.values():
inferred_types: Dict[str, str] = {}
function_name = functiondef.name
param_names = [arg.arg for arg in functiondef.args.args]

class CallSiteFinder(ast.NodeVisitor):
def __init__(self, target_name: str, pk_import: str):
self.target_name = target_name
self.pk_import = pk_import
self.call_sites: List[ast.Call] = []

def visit_Call(self, node: ast.Call):
if (
isinstance(node.func, ast.Name)
and node.func.id == self.target_name
):
self.call_sites.append(node)
elif isinstance(node.func, ast.Attribute):
if (
isinstance(node.func.value, ast.Name)
and node.func.value.id == self.pk_import
and node.func.attr == self.target_name
):
self.call_sites.append(node)
self.generic_visit(node)

finder = CallSiteFinder(function_name, self.pk_import)
for workunit in self.pk_members.pk_workunits.values():
finder.visit(workunit)

# Check return type for float preference (same logic as infer_other_args pattern)
prefer_float = False
if functiondef.returns and isinstance(functiondef.returns, ast.Attribute):
if (
isinstance(functiondef.returns.value, ast.Name)
and functiondef.returns.value.id == self.pk_import
):
if functiondef.returns.attr in ("double", "float"):
prefer_float = True

# extract values from call sites and infer types
for call_site in finder.call_sites:
for i, arg_node in enumerate(call_site.args):
if i >= len(param_names) or param_names[i] in inferred_types:
continue
param_name = param_names[i]
if isinstance(arg_node, ast.Constant):
value = arg_node.value
inferred_types[param_name] = _infer_type_from_value(
value, prefer_float
)
elif isinstance(arg_node, ast.Name):
for workunit in self.pk_members.pk_workunits.values():
for workunit_arg in workunit.args.args:
if (
workunit_arg.arg == arg_node.id
and workunit_arg.annotation
):
ann = workunit_arg.annotation
if isinstance(ann, ast.Name):
inferred_types[param_name] = ann.id
elif (
isinstance(ann, ast.Attribute)
and isinstance(ann.value, ast.Name)
and ann.value.id == self.pk_import
):
inferred_types[param_name] = ann.attr
break
if inferred_types:
self.parser.fix_function_types(functiondef, inferred_types)

# The visitor might add views declared as parameters
views = copy.deepcopy(self.pk_members.views)

Expand Down
66 changes: 45 additions & 21 deletions pykokkos/core/type_inference/args_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,53 @@ class UpdatedDecorator:
"float32",
]

# Cache for original argument nodes: Maps stringified workunit reference, e.g str(workunit_name), to the original ast.arguments node
# Map from Python type names to type strings, using DataType enum for validation
PYTHON_TO_TYPE_STR = {
"int": "int",
"float": DataType.double.name,
"bool": DataType.bool.name,
}

ORIGINAL_PARAMS: Dict[str, ast.arguments] = {}


def _infer_type_from_value(value, prefer_float: bool = False) -> str:
"""
Infer the type string from a Python value, reusing the same logic as infer_other_args.
Uses DataType enum and SUPPORTED_NP_DTYPES for consistency.

:param value: The Python value to infer type from
:param prefer_float: If True and value is int, prefer double over int
:returns: Type string in the format used by type inference (e.g., "int", "double", "numpy:int64")
"""
param_type = type(value).__name__

if param_type == "int":
if value.bit_length() > 31:
param_type = f"numpy:{DataType.int64.name}"
elif prefer_float:
param_type = DataType.double.name
else:
param_type = PYTHON_TO_TYPE_STR["int"]
elif param_type == "float":
param_type = PYTHON_TO_TYPE_STR["float"]
elif param_type == "bool":
param_type = PYTHON_TO_TYPE_STR["bool"]
else:
pckg_name = type(value).__module__
if pckg_name == "numpy":
if param_type not in SUPPORTED_NP_DTYPES:
err_str = f"Numpy type {param_type} is unsupported"
raise TypeError(err_str)
if param_type == DataType.float64.name:
param_type = DataType.double.name
elif param_type == DataType.float32.name:
param_type = DataType.float.name
param_type = pckg_name + ":" + param_type

return param_type


def check_missing_annotations(param_list: List[ast.arg]) -> bool:
"""
Check if any annotation node for parent argument node is none
Expand Down Expand Up @@ -280,26 +323,7 @@ def infer_other_args(
if param.annotation is not None:
continue

param_type = type(value).__name__

# switch integer values over 31 bits (signed positive value) to numpy:int64
if param_type == "int" and value.bit_length() > 31:
param_type = "numpy:int64"

# check if package name is numpy (handling numpy primitives)
pckg_name = type(value).__module__

if pckg_name == "numpy":
if param_type not in SUPPORTED_NP_DTYPES:
err_str = f"Numpy type {param_type} is unsupported"
raise TypeError(err_str)

if param_type == "float64":
param_type = "double"
if param_type == "float32":
param_type = "float"
# numpy:<type>, Will switch to pk.<type> in parser.fix_types
param_type = pckg_name + ":" + param_type
param_type = _infer_type_from_value(value)

if isinstance(value, ViewType):
view_dtype = get_pk_datatype(value.dtype)
Expand Down
5 changes: 0 additions & 5 deletions pykokkos/core/visitors/kokkosfunction_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,4 @@ def is_valid_kokkos_function(self, node) -> bool:
# Is the return type annotation missing
if node.returns is None:
return False

# Is the type annotation for any argument missing (excluding self)
if any(arg.annotation is None and arg.arg != "self" for arg in node.args.args):
return False

return True
Loading