Skip to content

Commit 6b4223e

Browse files
committed
Require NumPy >= 2.1
Fixes #21
1 parent 93201f1 commit 6b4223e

File tree

9 files changed

+37
-94
lines changed

9 files changed

+37
-94
lines changed

.github/workflows/array-api-tests.yml

+2-5
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@ jobs:
1212
strategy:
1313
matrix:
1414
python-version: ['3.9', '3.10', '3.11', '3.12']
15-
numpy-version: ['1.26', 'dev']
16-
exclude:
17-
- python-version: '3.8'
18-
numpy-version: 'dev'
15+
numpy-version: ['2.1', 'dev']
1916

2017
steps:
2118
- name: Checkout array-api-strict
@@ -38,7 +35,7 @@ jobs:
3835
if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then
3936
python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy;
4037
else
41-
python -m pip install 'numpy>=1.26,<2.0';
38+
python -m pip install 'numpy==${{ matrix.numpy-version }}';
4239
fi
4340
python -m pip install ${GITHUB_WORKSPACE}/array-api-strict
4441
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt

.github/workflows/tests.yml

+2-5
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,7 @@ jobs:
66
strategy:
77
matrix:
88
python-version: ['3.9', '3.10', '3.11', '3.12']
9-
numpy-version: ['1.26', 'dev']
10-
exclude:
11-
- python-version: '3.8'
12-
numpy-version: 'dev'
9+
numpy-version: ['2.1', 'dev']
1310
fail-fast: true
1411
steps:
1512
- uses: actions/checkout@v4
@@ -22,7 +19,7 @@ jobs:
2219
if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then
2320
python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy;
2421
else
25-
python -m pip install 'numpy>=1.26,<2.0';
22+
python -m pip install 'numpy==${{ matrix.numpy-version }}';
2623
fi
2724
python -m pip install -r requirements-dev.txt
2825
- name: Run Tests

array_api_strict/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@
1616
1717
"""
1818

19+
import numpy as np
20+
from numpy.lib import NumpyVersion
21+
22+
if NumpyVersion(np.__version__) < NumpyVersion('2.1.0'):
23+
raise ImportError("array-api-strict requires NumPy >= 2.1.0")
24+
1925
__all__ = []
2026

2127
# Warning: __array_api_version__ could change globally with

array_api_strict/_array_object.py

+9-31
Original file line numberDiff line numberDiff line change
@@ -162,19 +162,7 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
162162
if _allow_array:
163163
if self._device != CPU_DEVICE:
164164
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
165-
# copy keyword is new in 2.0.0; for older versions don't use it
166-
# retry without that keyword.
167-
if np.__version__[0] < '2':
168-
return np.asarray(self._array, dtype=dtype)
169-
elif np.__version__.startswith('2.0.0-dev0'):
170-
# Handle dev version for which we can't know based on version
171-
# number whether or not the copy keyword is supported.
172-
try:
173-
return np.asarray(self._array, dtype=dtype, copy=copy)
174-
except TypeError:
175-
return np.asarray(self._array, dtype=dtype)
176-
else:
177-
return np.asarray(self._array, dtype=dtype, copy=copy)
165+
return np.asarray(self._array, dtype=dtype, copy=copy)
178166
raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported")
179167

180168
# These are various helper functions to make the array behavior match the
@@ -586,24 +574,14 @@ def __dlpack__(
586574
if copy is not _default:
587575
raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API")
588576

589-
if np.__version__[0] < '2.1':
590-
if max_version not in [_default, None]:
591-
raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented")
592-
if dl_device not in [_default, None]:
593-
raise NotImplementedError("The device argument to __dlpack__ is not yet implemented")
594-
if copy not in [_default, None]:
595-
raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented")
596-
597-
return self._array.__dlpack__(stream=stream)
598-
else:
599-
kwargs = {'stream': stream}
600-
if max_version is not _default:
601-
kwargs['max_version'] = max_version
602-
if dl_device is not _default:
603-
kwargs['dl_device'] = dl_device
604-
if copy is not _default:
605-
kwargs['copy'] = copy
606-
return self._array.__dlpack__(**kwargs)
577+
kwargs = {'stream': stream}
578+
if max_version is not _default:
579+
kwargs['max_version'] = max_version
580+
if dl_device is not _default:
581+
kwargs['dl_device'] = dl_device
582+
if copy is not _default:
583+
kwargs['copy'] = copy
584+
return self._array.__dlpack__(**kwargs)
607585

608586
def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
609587
"""

array_api_strict/_creation_functions.py

-23
Original file line numberDiff line numberDiff line change
@@ -83,29 +83,6 @@ def asarray(
8383
if isinstance(obj, Array) and device is None:
8484
device = obj.device
8585

86-
if np.__version__[0] < '2':
87-
if copy is False:
88-
# Note: copy=False is not yet implemented in np.asarray for
89-
# NumPy 1
90-
91-
# Work around it by creating the new array and seeing if NumPy
92-
# copies it.
93-
if isinstance(obj, Array):
94-
new_array = np.array(obj._array, copy=copy, dtype=_np_dtype)
95-
if new_array is not obj._array:
96-
raise ValueError("Unable to avoid copy while creating an array from given array.")
97-
return Array._new(new_array, device=device)
98-
elif _supports_buffer_protocol(obj):
99-
# Buffer protocol will always support no-copy
100-
return Array._new(np.array(obj, copy=copy, dtype=_np_dtype), device=device)
101-
else:
102-
# No-copy is unsupported for Python built-in types.
103-
raise ValueError("Unable to avoid copy while creating an array from given object.")
104-
105-
if copy is None:
106-
# NumPy 1 treats copy=False the same as the standard copy=None
107-
copy = False
108-
10986
if isinstance(obj, Array):
11087
return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype), device=device)
11188
if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):

array_api_strict/tests/test_array_object.py

+15-27
Original file line numberDiff line numberDiff line change
@@ -456,31 +456,19 @@ def dlpack_2023_12(api_version):
456456
set_array_api_strict_flags(api_version=api_version)
457457

458458
a = asarray([1, 2, 3], dtype=int8)
459-
# Never an error
460-
a.__dlpack__()
461-
462459

463-
if np.__version__ < '2.1':
464-
exception = NotImplementedError if api_version >= '2023.12' else ValueError
465-
pytest.raises(exception, lambda:
466-
a.__dlpack__(dl_device=CPU_DEVICE))
467-
pytest.raises(exception, lambda:
468-
a.__dlpack__(dl_device=None))
469-
pytest.raises(exception, lambda:
470-
a.__dlpack__(max_version=(1, 0)))
471-
pytest.raises(exception, lambda:
472-
a.__dlpack__(max_version=None))
473-
pytest.raises(exception, lambda:
474-
a.__dlpack__(copy=False))
475-
pytest.raises(exception, lambda:
476-
a.__dlpack__(copy=True))
477-
pytest.raises(exception, lambda:
478-
a.__dlpack__(copy=None))
479-
else:
480-
a.__dlpack__(dl_device=CPU_DEVICE)
481-
a.__dlpack__(dl_device=None)
482-
a.__dlpack__(max_version=(1, 0))
483-
a.__dlpack__(max_version=None)
484-
a.__dlpack__(copy=False)
485-
a.__dlpack__(copy=True)
486-
a.__dlpack__(copy=None)
460+
# Do not error
461+
a.__dlpack__()
462+
a.__dlpack__(dl_device=CPU_DEVICE)
463+
a.__dlpack__(dl_device=None)
464+
a.__dlpack__(max_version=(1, 0))
465+
a.__dlpack__(max_version=None)
466+
a.__dlpack__(copy=False)
467+
a.__dlpack__(copy=True)
468+
a.__dlpack__(copy=None)
469+
470+
x = np.from_dlpack(a)
471+
assert isinstance(x, np.ndarray)
472+
assert x.dtype == np.int8
473+
assert x.shape == (3,)
474+
assert np.all(x == np.asarray([1, 2, 3]))

requirements-dev.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
pytest
22
hypothesis
3-
numpy
3+
numpy>=2.1

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
numpy
1+
numpy>=2.1

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
url="https://data-apis.org/array-api-strict/",
1717
license="MIT",
1818
python_requires=">=3.9",
19-
install_requires=["numpy"],
19+
install_requires=["numpy>=2.1"],
2020
classifiers=[
2121
"Programming Language :: Python :: 3",
2222
"Programming Language :: Python :: 3.9",

0 commit comments

Comments
 (0)