Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor onedal interaction with backend and policies #2168

Open
wants to merge 46 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
6daef75
feat: Introduce BackendManager and PolicyManger
ahuber21 Nov 14, 2024
ec58015
fixup
ahuber21 Nov 23, 2024
1ddcc34
Remove _get_policy() and use queue directly instead
ahuber21 Nov 25, 2024
57155d1
Only use a global queue, which is set by user-facing functions
ahuber21 Nov 26, 2024
f1ec967
remove policy manager
ahuber21 Nov 26, 2024
a360979
wip: fixes related to global queue
ahuber21 Nov 27, 2024
86d1fbd
fixup is_cpu
ahuber21 Nov 27, 2024
6c3b242
fixup dispatch
ahuber21 Nov 27, 2024
2001b48
fixup queue as kwarg, is_gpu
ahuber21 Nov 28, 2024
089d23d
handle SUA interface errors
ahuber21 Nov 28, 2024
aadbd72
fix BackendFucntion in kernel_functions.py
ahuber21 Nov 29, 2024
4239025
undo accidential changes to tests
ahuber21 Nov 29, 2024
44ba5e0
fixup delete _policy.py; fix assert_all_finite from latest main
ahuber21 Nov 29, 2024
1a7dadc
remove utils/__init__.py
ahuber21 Dec 10, 2024
c67fb22
Merge remote-tracking branch 'origin/main' into dev/ahuber/refactor-o…
ahuber21 Dec 10, 2024
81f8285
fix some errors after validation cleanup
ahuber21 Dec 10, 2024
e040cdd
compare only non-cpu devices
ahuber21 Dec 10, 2024
511b44f
fix after merging main
ahuber21 Dec 11, 2024
b39b852
simplify SyclQueue
ahuber21 Dec 11, 2024
a5cb819
further simplify and align SyclQueue handling
ahuber21 Dec 11, 2024
891700f
fix missing return
ahuber21 Dec 11, 2024
a23ea0d
remove intermediate SyclQueue class
ahuber21 Dec 11, 2024
c7eef38
introduce manage_global_queue context manager
ahuber21 Dec 12, 2024
50f6d1b
bring back underscore methods
ahuber21 Dec 12, 2024
d5e9425
kmeans init compute_raw does not support queue
ahuber21 Dec 12, 2024
35a344c
cleanup @support_input_format
ahuber21 Dec 12, 2024
7168aa2
allow for onedal prefix in patching check
ahuber21 Dec 13, 2024
f0ca14b
add missing wraps(func)
ahuber21 Dec 13, 2024
e304dae
Merge remote-tracking branch 'origin/main' into dev/ahuber/refactor-o…
ahuber21 Dec 13, 2024
720e447
fixup
ahuber21 Dec 13, 2024
ff20f76
fix neighbors
ahuber21 Dec 16, 2024
554eb38
Merge remote-tracking branch 'origin/main' into dev/ahuber/refactor-o…
ahuber21 Dec 16, 2024
6307b24
debug print
ahuber21 Dec 17, 2024
e5d109c
debug output
ahuber21 Dec 17, 2024
039d89e
Add new logic for sparse matrix in _copy_to_usm
ahuber21 Dec 17, 2024
2f37e08
only _copy_to_usm with usm_iface
ahuber21 Dec 17, 2024
e40cb3f
lint
ahuber21 Dec 17, 2024
ab0dea4
Merge remote-tracking branch 'origin/main' into dev/ahuber/refactor-o…
ahuber21 Dec 17, 2024
cbf0d35
remove accidental import
ahuber21 Dec 17, 2024
e674f40
Merge remote-tracking branch 'origin/main' into dev/ahuber/refactor-o…
ahuber21 Feb 20, 2025
c80c972
fix some tests
ahuber21 Feb 21, 2025
2920ae4
properly use queue from data in sklearnex
ahuber21 Feb 21, 2025
fd58490
fixup
ahuber21 Feb 21, 2025
0ca7107
fixup
ahuber21 Feb 21, 2025
07ae0d2
use utf-8 encoding
ahuber21 Feb 24, 2025
f4fc506
use queue from data for finiteness check
ahuber21 Feb 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
only _copy_to_usm with usm_iface
ahuber21 committed Dec 17, 2024
commit 2f37e085b6e97139d2576b00910f3c636e7dd14b
15 changes: 2 additions & 13 deletions onedal/_device_offload.py
Original file line number Diff line number Diff line change
@@ -187,21 +187,11 @@ def wrapper(self, *args, **kwargs):


def _copy_to_usm(queue, array):
print(f"_copy_to_usm: {type(array)=}")
if shape := getattr(array, "shape", None):
print(f"_copy_to_usm: {shape=}")
print(f"_copy_to_usm: array=<{array}>")
if not dpctl_available:
raise RuntimeError(
"dpctl need to be installed to work " "with __sycl_usm_array_interface__"
)

if sp.issparse(array):
data = _copy_to_usm(queue, array.data)
indices = _copy_to_usm(queue, array.indices)
indptr = _copy_to_usm(queue, array.indptr)
return array.__class__((data, indices, indptr), shape=array.shape)

if hasattr(array, "__array__"):

try:
@@ -306,9 +296,8 @@ def wrapper_impl(*args, **kwargs):
hostkwargs["queue"] = queue
result = invoke_func(self, *hostargs, **hostkwargs)

# Is this even required?
# wrap_output_data does copy to device with usm, so are we copying back here?
if queue is not None:
usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None)
if queue is not None and usm_iface is not None:
result = _copy_to_usm(queue, result)
if dpnp_available and isinstance(data[0], dpnp.ndarray):
result = _convert_to_dpnp(result)