Skip to content

Commit b42c805

Browse files
Add support for passing pytorch tensors directly to kernel calls (#277)
* parallel_dispatch: allow passin pytorch tensors directly to kernel calls * add generic errors for interfaces * fix error message for views and other types * dispatch: update checking types * fix doc * formatting --------- Co-authored-by: Ivan Grigorik <givan502@gmail.com>
1 parent 9720e2f commit b42c805

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

pykokkos/interface/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
)
5858

5959
from .ext_module import compile_into_module
60+
from .interface_util import generic_error
6061

6162
def fence():
6263
pass
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import sys
2+
3+
4+
def get_lineno(frame):
5+
"""
6+
Return current line number `inspect` frame
7+
"""
8+
return frame.f_lineno
9+
10+
11+
def get_filename(frame):
12+
"""
13+
Return current line number `inspect` frame
14+
"""
15+
return frame.f_code.co_filename
16+
17+
18+
def generic_error(filename: str, lineno: str | int, error: str, exit_message: str):
19+
print(f"\n\033[31m\033[01mError {filename}:{lineno}\033[0m: {error}")
20+
sys.exit(f"PyKokkos: {exit_message}")

pykokkos/interface/parallel_dispatch.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from .execution_space import ExecutionSpace
1212
from .views import ViewType, array
1313

14+
from .interface_util import generic_error, get_filename, get_lineno
15+
16+
import inspect
17+
1418
workunit_cache: Dict[int, Callable] = {}
1519

1620

@@ -115,24 +119,42 @@ def check_workunit(workunit: Any) -> None:
115119

116120
def convert_arrays(kwargs: Dict[str, Any]) -> None:
117121
"""
118-
Convert all numpy and cupy ndarray objects into pk Views
122+
Convert all numpy, cupy and pytorch ndarray objects into pk Views
119123
120124
:param kwargs: the list of keyword arguments passed to the workunit
121125
"""
122126

123127
cp_available: bool
128+
torch_available: bool
124129

125130
try:
126131
import cupy as cp
127132
cp_available = True
128133
except ImportError:
129134
cp_available = False
130135

136+
try:
137+
import torch
138+
torch_available = True
139+
except ImportError:
140+
torch_available = False
141+
131142
for k, v in kwargs.items():
132-
if isinstance(v, np.ndarray):
143+
if isinstance(v, ViewType) or isinstance(v, np.generic):
144+
continue
145+
elif isinstance(v, np.ndarray):
133146
kwargs[k] = array(v)
134147
elif cp_available and isinstance(v, cp.ndarray):
135148
kwargs[k] = array(v)
149+
elif torch_available and torch.is_tensor(v):
150+
kwargs[k] = array(v)
151+
elif hasattr(v, '__array__') or hasattr(v, '__cuda_array_interface__') or hasattr(v, '__array_interface__'):
152+
# This is some array-like object we don't support
153+
caller_frame = inspect.currentframe().f_back.f_back
154+
filename = get_filename(caller_frame)
155+
lineno = get_lineno(caller_frame)
156+
msg = f"Type {type(v)} is not supported. Only numpy arrays, cupy arrays, and torch tensors are supported."
157+
generic_error(filename, lineno, msg, "Conversion failed")
136158

137159

138160
def parallel_for(*args, **kwargs) -> None:

0 commit comments

Comments
 (0)