
RFC: add scalar support to xp.searchsorted #967

@ogrisel

Description

```python
import array_api_strict as xp

a = xp.arange(10, device=xp.Device("device1"))
xp.searchsorted(a, 42)
```

raises:

```
Traceback (most recent call last):
  Cell In[5], line 5
    xp.searchsorted(a, 42)
  File ~/miniforge3/envs/dev/lib/python3.13/site-packages/array_api_strict/_flags.py:395 in wrapper
    return func(*args, **kwargs)
  File ~/miniforge3/envs/dev/lib/python3.13/site-packages/array_api_strict/_searching_functions.py:78 in searchsorted
    if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
AttributeError: 'int' object has no attribute 'dtype'
```

This is a bit annoying, as it requires writing `xp.searchsorted(a, xp.asarray(42, device=a.device))` instead, which feels unnecessarily verbose.
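Until the spec settles the question, the `xp.asarray(42, device=a.device)` workaround can be hidden behind a small helper. The following is a minimal sketch under stated assumptions: `searchsorted_with_scalar` is a hypothetical name, not part of `array_api_strict` or any other library; it assumes `xp` is an array API namespace whose `asarray` accepts `dtype=` and `device=` keywords, and that the scalar is representable in `x1.dtype`.

```python
def searchsorted_with_scalar(xp, x1, x2, /, *, side="left", sorter=None):
    """Hypothetical helper: like ``xp.searchsorted`` but also accepts a Python
    ``int`` or ``float`` as ``x2``.

    Assumption: the scalar is representable in ``x1.dtype``.
    """
    # bool is a subclass of int, but it is not a real numeric type, so it is
    # deliberately not wrapped and is passed through to xp.searchsorted as-is.
    if isinstance(x2, (int, float)) and not isinstance(x2, bool):
        # Wrap the scalar as a 0-D array with x1's dtype, on x1's device,
        # which is exactly the verbose call users currently write by hand.
        x2 = xp.asarray(x2, dtype=x1.dtype, device=x1.device)
    # Arrays (and anything else) are forwarded unchanged.
    return xp.searchsorted(x1, x2, side=side, sorter=sorter)
```

With `a` from the first snippet, `searchsorted_with_scalar(xp, a, 42)` wraps `42` as a 0-D array on `a.device` before calling `xp.searchsorted`. Using `x1.dtype` rather than the namespace's default dtype mirrors how the standard treats Python scalars in array/scalar binary operations, where the scalar takes the array operand's dtype.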

For instance, even PyTorch accepts the following without complaint:

```python
import array_api_compat.torch as xp

a = xp.arange(10, device="mps")
xp.searchsorted(a, 42)
```

However, the spec does not explicitly mention Python scalar support for `x2`, so maybe it would need to be updated first?

https://data-apis.org/array-api/latest/API_specification/generated/array_api.searchsorted.html#searchsorted

Metadata


Assignees: No one assigned

Labels: API change (changes to existing functions or objects in the API); RFC (request for comments: feature requests and proposed changes)


Status: Stage 3

