@@ -96,12 +96,60 @@ def test_promoted_scalar_inherits_device():
96
96
97
97
assert y .device == device1
98
98
99
+
100
+ BIG_INT = int (1e30 )
101
+
102
+ def _check_op_array_scalar (dtypes , a , s , func , func_name , BIG_INT = BIG_INT ):
103
+ # Test array op scalar. From the spec, the following combinations
104
+ # are supported:
105
+
106
+ # - Python bool for a bool array dtype,
107
+ # - a Python int within the bounds of the given dtype for integer array dtypes,
108
+ # - a Python int or float for real floating-point array dtypes
109
+ # - a Python int, float, or complex for complex floating-point array dtypes
110
+
111
+ if ((dtypes == "all"
112
+ or dtypes == "numeric" and a .dtype in _numeric_dtypes
113
+ or dtypes == "real numeric" and a .dtype in _real_numeric_dtypes
114
+ or dtypes == "integer" and a .dtype in _integer_dtypes
115
+ or dtypes == "integer or boolean" and a .dtype in _integer_or_boolean_dtypes
116
+ or dtypes == "boolean" and a .dtype in _boolean_dtypes
117
+ or dtypes == "floating-point" and a .dtype in _floating_dtypes
118
+ or dtypes == "real floating-point" and a .dtype in _real_floating_dtypes
119
+ )
120
+ # bool is a subtype of int, which is why we avoid
121
+ # isinstance here.
122
+ and (a .dtype in _boolean_dtypes and type (s ) == bool
123
+ or a .dtype in _integer_dtypes and type (s ) == int
124
+ or a .dtype in _real_floating_dtypes and type (s ) in [float , int ]
125
+ or a .dtype in _complex_floating_dtypes and type (s ) in [complex , float , int ]
126
+ )):
127
+ if a .dtype in _integer_dtypes and s == BIG_INT :
128
+ with assert_raises (OverflowError ):
129
+ func (s )
130
+ return False
131
+
132
+ else :
133
+ # Only test for no error
134
+ with suppress_warnings () as sup :
135
+ # ignore warnings from pow(BIG_INT)
136
+ sup .filter (RuntimeWarning ,
137
+ "invalid value encountered in power" )
138
+ func (s )
139
+ return True
140
+
141
+ else :
142
+ with assert_raises (TypeError ):
143
+ func (s )
144
+ return False
145
+
146
+
99
147
def test_operators ():
100
148
# For every operator, we test that it works for the required type
101
149
# combinations and raises TypeError otherwise
102
150
binary_op_dtypes = {
103
151
"__add__" : "numeric" ,
104
- "__and__" : "integer_or_boolean " ,
152
+ "__and__" : "integer or boolean " ,
105
153
"__eq__" : "all" ,
106
154
"__floordiv__" : "real numeric" ,
107
155
"__ge__" : "real numeric" ,
@@ -112,12 +160,12 @@ def test_operators():
112
160
"__mod__" : "real numeric" ,
113
161
"__mul__" : "numeric" ,
114
162
"__ne__" : "all" ,
115
- "__or__" : "integer_or_boolean " ,
163
+ "__or__" : "integer or boolean " ,
116
164
"__pow__" : "numeric" ,
117
165
"__rshift__" : "integer" ,
118
166
"__sub__" : "numeric" ,
119
- "__truediv__" : "floating" ,
120
- "__xor__" : "integer_or_boolean " ,
167
+ "__truediv__" : "floating-point " ,
168
+ "__xor__" : "integer or boolean " ,
121
169
}
122
170
# Recompute each time because of in-place ops
123
171
def _array_vals ():
@@ -128,8 +176,6 @@ def _array_vals():
128
176
for d in _floating_dtypes :
129
177
yield asarray (1.0 , dtype = d )
130
178
131
-
132
- BIG_INT = int (1e30 )
133
179
for op , dtypes in binary_op_dtypes .items ():
134
180
ops = [op ]
135
181
if op not in ["__eq__" , "__ne__" , "__le__" , "__ge__" , "__lt__" , "__gt__" ]:
@@ -139,40 +185,7 @@ def _array_vals():
139
185
for s in [1 , 1.0 , 1j , BIG_INT , False ]:
140
186
for _op in ops :
141
187
for a in _array_vals ():
142
- # Test array op scalar. From the spec, the following combinations
143
- # are supported:
144
-
145
- # - Python bool for a bool array dtype,
146
- # - a Python int within the bounds of the given dtype for integer array dtypes,
147
- # - a Python int or float for real floating-point array dtypes
148
- # - a Python int, float, or complex for complex floating-point array dtypes
149
-
150
- if ((dtypes == "all"
151
- or dtypes == "numeric" and a .dtype in _numeric_dtypes
152
- or dtypes == "real numeric" and a .dtype in _real_numeric_dtypes
153
- or dtypes == "integer" and a .dtype in _integer_dtypes
154
- or dtypes == "integer_or_boolean" and a .dtype in _integer_or_boolean_dtypes
155
- or dtypes == "boolean" and a .dtype in _boolean_dtypes
156
- or dtypes == "floating" and a .dtype in _floating_dtypes
157
- )
158
- # bool is a subtype of int, which is why we avoid
159
- # isinstance here.
160
- and (a .dtype in _boolean_dtypes and type (s ) == bool
161
- or a .dtype in _integer_dtypes and type (s ) == int
162
- or a .dtype in _real_floating_dtypes and type (s ) in [float , int ]
163
- or a .dtype in _complex_floating_dtypes and type (s ) in [complex , float , int ]
164
- )):
165
- if a .dtype in _integer_dtypes and s == BIG_INT :
166
- assert_raises (OverflowError , lambda : getattr (a , _op )(s ))
167
- else :
168
- # Only test for no error
169
- with suppress_warnings () as sup :
170
- # ignore warnings from pow(BIG_INT)
171
- sup .filter (RuntimeWarning ,
172
- "invalid value encountered in power" )
173
- getattr (a , _op )(s )
174
- else :
175
- assert_raises (TypeError , lambda : getattr (a , _op )(s ))
188
+ _check_op_array_scalar (dtypes , a , s , getattr (a , _op ), _op )
176
189
177
190
# Test array op array.
178
191
for _op in ops :
@@ -203,18 +216,18 @@ def _array_vals():
203
216
or (dtypes == "real numeric" and x .dtype in _real_numeric_dtypes and y .dtype in _real_numeric_dtypes )
204
217
or (dtypes == "numeric" and x .dtype in _numeric_dtypes and y .dtype in _numeric_dtypes )
205
218
or dtypes == "integer" and x .dtype in _integer_dtypes and y .dtype in _integer_dtypes
206
- or dtypes == "integer_or_boolean " and (x .dtype in _integer_dtypes and y .dtype in _integer_dtypes
219
+ or dtypes == "integer or boolean " and (x .dtype in _integer_dtypes and y .dtype in _integer_dtypes
207
220
or x .dtype in _boolean_dtypes and y .dtype in _boolean_dtypes )
208
221
or dtypes == "boolean" and x .dtype in _boolean_dtypes and y .dtype in _boolean_dtypes
209
- or dtypes == "floating" and x .dtype in _floating_dtypes and y .dtype in _floating_dtypes
222
+ or dtypes == "floating-point " and x .dtype in _floating_dtypes and y .dtype in _floating_dtypes
210
223
):
211
224
getattr (x , _op )(y )
212
225
else :
213
226
assert_raises (TypeError , lambda : getattr (x , _op )(y ))
214
227
215
228
unary_op_dtypes = {
216
229
"__abs__" : "numeric" ,
217
- "__invert__" : "integer_or_boolean " ,
230
+ "__invert__" : "integer or boolean " ,
218
231
"__neg__" : "numeric" ,
219
232
"__pos__" : "numeric" ,
220
233
}
@@ -223,7 +236,7 @@ def _array_vals():
223
236
if (
224
237
dtypes == "numeric"
225
238
and a .dtype in _numeric_dtypes
226
- or dtypes == "integer_or_boolean "
239
+ or dtypes == "integer or boolean "
227
240
and a .dtype in _integer_or_boolean_dtypes
228
241
):
229
242
# Only test for no error
0 commit comments