Skip to content

Commit 73f6426

Browse files
ENH: astype: add device kwarg (#240)
Co-authored-by: Lucas Colley <[email protected]>
1 parent d6f431d commit 73f6426

12 files changed

+47
-24
lines changed

Diff for: array_api_compat/common/_aliases.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,6 @@ def unique_values(x: ndarray, /, xp) -> ndarray:
233233
**kwargs,
234234
)
235235

236-
def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
237-
if not copy and dtype == x.dtype:
238-
return x
239-
return x.astype(dtype=dtype, copy=copy)
240-
241236
# These functions have different keyword argument names
242237

243238
def std(
@@ -549,7 +544,7 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
549544
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
550545
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
551546
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
552-
'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
547+
'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
553548
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
554549
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
555550
'unstack', 'sign']

Diff for: array_api_compat/cupy/_aliases.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import cupy as cp
44

5-
from ..common import _aliases
5+
from ..common import _aliases, _helpers
66
from .._internal import get_xp
77

88
from ._info import __array_namespace_info__
@@ -46,7 +46,6 @@
4646
unique_counts = get_xp(cp)(_aliases.unique_counts)
4747
unique_inverse = get_xp(cp)(_aliases.unique_inverse)
4848
unique_values = get_xp(cp)(_aliases.unique_values)
49-
astype = _aliases.astype
5049
std = get_xp(cp)(_aliases.std)
5150
var = get_xp(cp)(_aliases.var)
5251
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
@@ -110,6 +109,21 @@ def asarray(
110109

111110
return cp.array(obj, dtype=dtype, **kwargs)
112111

112+
113+
def astype(
114+
x: ndarray,
115+
dtype: Dtype,
116+
/,
117+
*,
118+
copy: bool = True,
119+
device: Optional[Device] = None,
120+
) -> ndarray:
121+
if device is None:
122+
return x.astype(dtype=dtype, copy=copy)
123+
out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device)
124+
return out.copy() if copy and out is x else out
125+
126+
113127
# These functions are completely new here. If the library already has them
114128
# (i.e., numpy 2.0), use the library version instead of our wrapper.
115129
if hasattr(cp, 'vecdot'):
@@ -127,10 +141,10 @@ def asarray(
127141
else:
128142
unstack = get_xp(cp)(_aliases.unstack)
129143

130-
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
144+
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
131145
'acos', 'acosh', 'asin', 'asinh', 'atan',
132146
'atan2', 'atanh', 'bitwise_left_shift',
133147
'bitwise_invert', 'bitwise_right_shift',
134-
'concat', 'pow', 'sign']
148+
'bool', 'concat', 'pow', 'sign']
135149

136150
_all_ignore = ['cp', 'get_xp']

Diff for: array_api_compat/dask/array/_aliases.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def _isscalar(a):
233233

234234
_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
235235

236-
__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'acos',
236+
__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'astype', 'acos',
237237
'acosh', 'asin', 'asinh', 'atan', 'atan2',
238238
'atanh', 'bitwise_left_shift', 'bitwise_invert',
239239
'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast',

Diff for: array_api_compat/numpy/_aliases.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
unique_counts = get_xp(np)(_aliases.unique_counts)
4747
unique_inverse = get_xp(np)(_aliases.unique_inverse)
4848
unique_values = get_xp(np)(_aliases.unique_values)
49-
astype = _aliases.astype
5049
std = get_xp(np)(_aliases.std)
5150
var = get_xp(np)(_aliases.var)
5251
cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
@@ -115,6 +114,18 @@ def asarray(
115114

116115
return np.array(obj, copy=copy, dtype=dtype, **kwargs)
117116

117+
118+
def astype(
119+
x: ndarray,
120+
dtype: Dtype,
121+
/,
122+
*,
123+
copy: bool = True,
124+
device: Optional[Device] = None,
125+
) -> ndarray:
126+
return x.astype(dtype=dtype, copy=copy)
127+
128+
118129
# These functions are completely new here. If the library already has them
119130
# (i.e., numpy 2.0), use the library version instead of our wrapper.
120131
if hasattr(np, 'vecdot'):
@@ -132,10 +143,10 @@ def asarray(
132143
else:
133144
unstack = get_xp(np)(_aliases.unstack)
134145

135-
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
146+
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
136147
'acos', 'acosh', 'asin', 'asinh', 'atan',
137148
'atan2', 'atanh', 'bitwise_left_shift',
138149
'bitwise_invert', 'bitwise_right_shift',
139-
'concat', 'pow']
150+
'bool', 'concat', 'pow']
140151

141152
_all_ignore = ['np', 'get_xp']

Diff for: array_api_compat/torch/_aliases.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -613,8 +613,19 @@ def triu(x: array, /, *, k: int = 0) -> array:
613613
def expand_dims(x: array, /, *, axis: int = 0) -> array:
614614
return torch.unsqueeze(x, axis)
615615

616-
def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
617-
return x.to(dtype, copy=copy)
616+
617+
def astype(
618+
x: array,
619+
dtype: Dtype,
620+
/,
621+
*,
622+
copy: bool = True,
623+
device: Optional[Device] = None,
624+
) -> array:
625+
if device is not None:
626+
return x.to(device, dtype=dtype, copy=copy)
627+
return x.to(dtype=dtype, copy=copy)
628+
618629

619630
def broadcast_arrays(*arrays: array) -> List[array]:
620631
shape = torch.broadcast_shapes(*[a.shape for a in arrays])

Diff for: cupy-xfails.txt

-1
Original file line numberDiff line numberDiff line change
@@ -181,5 +181,4 @@ array_api_tests/test_fft.py::test_irfftn
181181
# cupy.ndaray cannot be specified as `repeats` argument.
182182
array_api_tests/test_manipulation_functions.py::test_repeat
183183
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
184-
array_api_tests/test_signatures.py::test_func_signature[astype]
185184
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]

Diff for: dask-xfails.txt

-1
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,3 @@ array_api_tests/test_statistical_functions.py::test_prod
154154
# 2023.12 support
155155
array_api_tests/test_manipulation_functions.py::test_repeat
156156
array_api_tests/test_searching_functions.py::test_searchsorted
157-
array_api_tests/test_signatures.py::test_func_signature[astype]

Diff for: numpy-1-21-xfails.txt

-1
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,6 @@ array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is
254254
# 2023.12 support
255255
array_api_tests/test_searching_functions.py::test_searchsorted
256256
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
257-
array_api_tests/test_signatures.py::test_func_signature[astype]
258257
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
259258
# uint64 repeats not supported
260259
array_api_tests/test_manipulation_functions.py::test_repeat

Diff for: numpy-1-26-xfails.txt

-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ array_api_tests/test_statistical_functions.py::test_prod
4949
# 2023.12 support
5050
array_api_tests/test_searching_functions.py::test_searchsorted
5151
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
52-
array_api_tests/test_signatures.py::test_func_signature[astype]
5352
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
5453
# uint64 repeats not supported
5554
array_api_tests/test_manipulation_functions.py::test_repeat

Diff for: numpy-dev-xfails.txt

-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot]
2020

2121
# 2023.12 support
2222
# Argument 'device' missing from signature
23-
array_api_tests/test_signatures.py::test_func_signature[astype]
2423
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
2524
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
2625
# uint64 repeats not supported

Diff for: numpy-xfails.txt

-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot]
4141
# 2023.12 support
4242
array_api_tests/test_searching_functions.py::test_searchsorted
4343
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
44-
array_api_tests/test_signatures.py::test_func_signature[astype]
4544
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
4645
# uint64 repeats not supported
4746
array_api_tests/test_manipulation_functions.py::test_repeat

Diff for: torch-xfails.txt

-2
Original file line numberDiff line numberDiff line change
@@ -202,5 +202,3 @@ array_api_tests/test_signatures.py::test_func_signature[repeat]
202202
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
203203
# Argument 'max_version' missing from signature
204204
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
205-
# Argument 'device' missing from signature
206-
array_api_tests/test_signatures.py::test_func_signature[astype]

0 commit comments

Comments
 (0)