Labels: API change (Changes to existing functions or objects in the API.), RFC (Request for comments. Feature requests and proposed changes.)
Description
import array_api_strict as xp
a = xp.arange(10, device=xp.Device("device1"))
xp.searchsorted(a, 42)
raises:
Traceback (most recent call last):
Cell In[5], line 5
xp.searchsorted(a, 42)
File ~/miniforge3/envs/dev/lib/python3.13/site-packages/array_api_strict/_flags.py:395 in wrapper
return func(*args, **kwargs)
File ~/miniforge3/envs/dev/lib/python3.13/site-packages/array_api_strict/_searching_functions.py:78 in searchsorted
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
AttributeError: 'int' object has no attribute 'dtype'
This is a bit annoying, as it requires writing the following instead:
xp.searchsorted(a, xp.asarray(42, device=a.device))
which feels unnecessarily verbose.
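Until this is settled, the verbose call can be hidden behind a small wrapper. The sketch below is hypothetical: `searchsorted_scalar_ok` is not part of array_api_strict or the standard. It takes the namespace as an argument and assumes only that `xp.asarray` accepts `device=` and that arrays expose `.device`, both of which the standard specifies. The scalar's dtype is left to the default chosen by `xp.asarray`.

```python
def searchsorted_scalar_ok(xp, x1, x2, /, *, side="left", sorter=None):
    """Hypothetical wrapper around xp.searchsorted that also accepts a Python scalar for x2.

    Assumes `xp` is an array API namespace whose `asarray` takes `device=`
    and whose arrays expose `.device`.
    """
    if isinstance(x2, (bool, int, float, complex)):
        # Wrap the scalar in a 0-D array on x1's device so that
        # xp.searchsorted receives two arrays on the same device.
        x2 = xp.asarray(x2, device=x1.device)
    # Arrays are passed through unchanged.
    return xp.searchsorted(x1, x2, side=side, sorter=sorter)
```

With the example above, `searchsorted_scalar_ok(xp, a, 42)` stands in for `xp.searchsorted(a, xp.asarray(42, device=a.device))`.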
Even PyTorch, for instance, accepts the following without complaint:
import array_api_compat.torch as xp
a = xp.arange(10, device="mps")
xp.searchsorted(a, 42)
However, the spec does not explicitly mention Python scalar support for searchsorted, so maybe it would need to be updated first?
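If the spec were updated, one possible rule would be to reuse the kind rules the standard already defines for mixing Python scalars with arrays in operators. This is an assumption, not something the standard currently says about `searchsorted`. Under those rules, a Python `int` is accepted for integer arrays when it fits within the dtype's bounds, and an `int` or `float` is accepted for real floating-point arrays. The scalar then takes the array's dtype. The hypothetical helper `scalar_to_array_like` below applies that rule using only `xp.isdtype`, `xp.iinfo` and `xp.asarray` from the standard. It rejects bool and complex scalars because `searchsorted` only accepts real numeric dtypes.

```python
def scalar_to_array_like(xp, value, ref):
    """Hypothetical helper: convert a Python scalar to a 0-D array with ref's dtype and device.

    Follows the standard's rules for mixing Python scalars with arrays in
    operators, assumed here to carry over to searchsorted.
    """
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        # searchsorted only accepts real numeric dtypes, so bool and complex
        # scalars (and non-scalars) have no compatible kind here.
        raise TypeError(f"unsupported scalar for searchsorted: {type(value).__name__}")
    if xp.isdtype(ref.dtype, "integral"):
        if not isinstance(value, int):
            # The standard leaves float-with-integer-array unspecified; reject it.
            raise TypeError("a float scalar cannot be combined with an integer array")
        info = xp.iinfo(ref.dtype)
        if not info.min <= value <= info.max:
            raise OverflowError(f"{value} is out of bounds for {ref.dtype}")
    elif not xp.isdtype(ref.dtype, "real floating"):
        raise TypeError(f"searchsorted requires a real numeric array, got {ref.dtype}")
    # The scalar adopts the array's dtype and is placed on the array's device.
    return xp.asarray(value, dtype=ref.dtype, device=ref.device)
```

Rejecting a float with an integer array is a deliberate choice in this sketch. Casting a query such as 3.5 to an integer dtype would silently change which insertion index is returned.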
Project status: Stage 3