Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pykokkos/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
)

from .ext_module import compile_into_module
from .interface_util import generic_error

def fence():
pass
Expand Down
20 changes: 20 additions & 0 deletions pykokkos/interface/interface_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import sys


def get_lineno(frame):
"""
Return current line number `inspect` frame
"""
return frame.f_lineno


def get_filename(frame):
"""
Return current line number `inspect` frame
"""
return frame.f_code.co_filename


def generic_error(filename: str, lineno: str | int, error: str, exit_message: str):
print(f"\n\033[31m\033[01mError {filename}:{lineno}\033[0m: {error}")
sys.exit(f"PyKokkos: {exit_message}")
26 changes: 24 additions & 2 deletions pykokkos/interface/parallel_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from .execution_space import ExecutionSpace
from .views import ViewType, array

from .interface_util import generic_error, get_filename, get_lineno

import inspect

workunit_cache: Dict[int, Callable] = {}


Expand Down Expand Up @@ -115,24 +119,42 @@ def check_workunit(workunit: Any) -> None:

def convert_arrays(kwargs: Dict[str, Any]) -> None:
"""
Convert all numpy and cupy ndarray objects into pk Views
Convert all numpy, cupy and pytorch ndarray objects into pk Views

:param kwargs: the list of keyword arguments passed to the workunit
"""

cp_available: bool
torch_available: bool

try:
import cupy as cp
cp_available = True
except ImportError:
cp_available = False

try:
import torch
torch_available = True
except ImportError:
torch_available = False

for k, v in kwargs.items():
if isinstance(v, np.ndarray):
if isinstance(v, ViewType) or isinstance(v, np.generic):
continue
elif isinstance(v, np.ndarray):
kwargs[k] = array(v)
elif cp_available and isinstance(v, cp.ndarray):
kwargs[k] = array(v)
elif torch_available and torch.is_tensor(v):
kwargs[k] = array(v)
elif hasattr(v, '__array__') or hasattr(v, '__cuda_array_interface__') or hasattr(v, '__array_interface__'):
# This is some array-like object we don't support
caller_frame = inspect.currentframe().f_back.f_back
filename = get_filename(caller_frame)
lineno = get_lineno(caller_frame)
msg = f"Type {type(v)} is not supported. Only numpy arrays, cupy arrays, and torch tensors are supported."
generic_error(filename, lineno, msg, "Conversion failed")


def parallel_for(*args, **kwargs) -> None:
Expand Down