|
4 | 4 |
|
5 | 5 | from collections.abc import Collection
|
6 | 6 |
|
| 7 | +import awkward as ak |
7 | 8 | from awkward._backends.backend import Backend
|
8 | 9 | from awkward._nplikes.numpy import Numpy
|
9 | 10 | from awkward._nplikes.numpy_like import NumpyLike, NumpyMetadata
|
| 11 | +from awkward._nplikes.virtual import VirtualArray |
10 | 12 | from awkward._typing import Callable, TypeAlias, TypeVar, cast
|
11 | 13 | from awkward._util import UNSET, Sentinel
|
12 | 14 |
|
@@ -70,6 +72,7 @@ def common_backend(backends: Collection[Backend]) -> Backend:
|
70 | 72 |
|
71 | 73 |
|
72 | 74 | def backend_of_obj(obj, default: D | Sentinel = UNSET) -> Backend | D:
|
| 75 | + # the backend of virtual arrays will be determined via the `find_virtual_backend` lookup |
73 | 76 | cls = type(obj)
|
74 | 77 | try:
|
75 | 78 | lookup = _type_to_backend_lookup[cls]
|
@@ -129,3 +132,22 @@ def regularize_backend(backend: str | Backend) -> Backend:
|
129 | 132 | return _name_to_backend_cls[backend].instance()
|
130 | 133 | else:
|
131 | 134 | raise ValueError(f"No such backend {backend!r} exists.")
|
| 135 | + |
| 136 | + |
| 137 | +@register_backend_lookup_factory |
| 138 | +def find_virtual_backend(obj: type): |
| 139 | + """ |
| 140 | + Implements a lookup for finding the backends of virtual arrays. |
| 141 | + This is necessary to avoid calling `isinstance` inside `backend_of_obj` which may cause slowdown. |
| 142 | + """ |
| 143 | + if issubclass(obj, VirtualArray): |
| 144 | + |
| 145 | + def finder(obj: VirtualArray): |
| 146 | + if isinstance(obj.nplike, ak._nplikes.numpy.Numpy): |
| 147 | + return _name_to_backend_cls["cpu"].instance() |
| 148 | + elif isinstance(obj.nplike, ak._nplikes.cupy.Cupy): |
| 149 | + return _name_to_backend_cls["cuda"].instance() |
| 150 | + else: |
| 151 | + raise TypeError("A virtual array can only have numpy or cupy backends") |
| 152 | + |
| 153 | + return finder |
0 commit comments