Skip to content

Commit a319fc4

Browse files
authored
Merge pull request #6 from siyi-hu/reduce-api-p1
Merge recent fix according to review comments
2 parents 4e9f269 + 0102974 commit a319fc4

File tree

3 files changed

+14
-60
lines changed

3 files changed

+14
-60
lines changed

src/finch/algebra/algebra.py

Lines changed: 4 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ def return_type(op: Any, *args: Any) -> Any:
179179
operator.and_: ("__and__", "__rand__"),
180180
operator.xor: ("__xor__", "__rxor__"),
181181
operator.or_: ("__or__", "__ror__"),
182+
min: ("__add__", "__radd__"),
183+
max: ("__sub__", "__rsub__"),
182184
}
183185

184186

@@ -211,21 +213,6 @@ def _return_type_closure(a, b):
211213
register_property(T, meth, "return_type", _return_type_reflexive(meth))
212214
register_property(T, rmeth, "return_type", _return_type_reflexive(rmeth))
213215

214-
register_property(
215-
min,
216-
"__call__",
217-
"return_type",
218-
lambda op, a, b: query_property(a, "__add__", "return_type", b),
219-
)
220-
register_property(
221-
max,
222-
"__call__",
223-
"return_type",
224-
lambda op, a, b: query_property(a, "__add__", "return_type", b),
225-
)
226-
register_property(any, "__call__", "return_type", lambda op, a, b: bool)
227-
register_property(all, "__call__", "return_type", lambda op, a, b: bool)
228-
229216

230217
_unary_operators: dict[Callable, str] = {
231218
operator.abs: "__abs__",
@@ -421,34 +408,6 @@ def init_value(op, arg) -> Any:
421408
register_property(T, "__or__", "init_value", lambda a, b: a(False))
422409

423410

424-
def _min_init(arg):
425-
dtype = np.dtype(arg) if isinstance(arg, np.ndarray) else np.dtype(arg)
426-
if np.issubdtype(dtype, np.floating):
427-
return math.inf
428-
if np.issubdtype(dtype, np.integer):
429-
return np.iinfo(dtype).max
430-
if np.issubdtype(dtype, np.bool_):
431-
return True
432-
raise TypeError("Unsupported dtype for min")
433-
434-
435-
def _max_init(arg):
436-
dtype = np.dtype(arg) if isinstance(arg, np.ndarray) else np.dtype(arg)
437-
if np.issubdtype(dtype, np.floating):
438-
return -math.inf
439-
if np.issubdtype(dtype, np.integer):
440-
return np.iinfo(dtype).min
441-
if np.issubdtype(dtype, np.bool_):
442-
return False
443-
raise TypeError("Unsupported dtype for max")
444-
445-
446-
register_property(
447-
min, "__call__", "init_value", lambda op, arg: _min_init(element_type(arg))
448-
)
449-
register_property(
450-
max, "__call__", "init_value", lambda op, arg: _max_init(element_type(arg))
451-
)
452-
register_property(any, "__call__", "init_value", lambda op, arg: False)
453-
register_property(all, "__call__", "init_value", lambda op, arg: True)
411+
register_property(min, "__call__", "init_value", lambda op, arg: math.inf)
412+
register_property(max, "__call__", "init_value", lambda op, arg: -math.inf)
454413
register_property(operator.truth, "__call__", "init_value", lambda op, arg: True)

src/finch/interface/lazy.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,7 @@ def any(
496496
"""
497497
Test whether any element of input array ``x`` along given axis is True.
498498
"""
499-
if init is None:
500-
init = init_value(builtins.any, x)
499+
x = defer(x)
501500
return reduce(
502501
operator.or_,
503502
elementwise(operator.truth, x),
@@ -518,8 +517,7 @@ def all(
518517
"""
519518
Test whether all elements of input array ``x`` along given axis are True.
520519
"""
521-
if init is None:
522-
init = init_value(builtins.all, x)
520+
x = defer(x)
523521
return reduce(
524522
operator.and_,
525523
elementwise(operator.truth, x),
@@ -540,8 +538,7 @@ def min(
540538
"""
541539
Return the minimum of input array ``arr`` along given axis.
542540
"""
543-
if init is None:
544-
init = init_value(builtins.min, x)
541+
x = defer(x)
545542
return reduce(builtins.min, x, axis=axis, keepdims=keepdims, init=init)
546543

547544

@@ -556,6 +553,5 @@ def max(
556553
"""
557554
Return the maximum of input array ``arr`` along given axis.
558555
"""
559-
if init is None:
560-
init = init_value(builtins.max, x)
556+
x = defer(x)
561557
return reduce(builtins.max, x, axis=axis, keepdims=keepdims, init=init)

tests/test_reduce_api_funcs.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ def output_term(args):
2626
@pytest.mark.parametrize(
2727
"finch_op, np_op",
2828
[
29-
(finch.prod, np.prod),
30-
# (finch.sum, np.sum),
3129
(finch.any, np.any),
3230
(finch.all, np.all),
3331
(finch.min, np.min),
@@ -43,7 +41,7 @@ def output_term(args):
4341
(0, 1),
4442
],
4543
)
46-
def test_reduction_api_bitwise(a, finch_op, np_op, axis):
44+
def test_reduction_api_boolean(a, finch_op, np_op, axis):
4745
result = finch_op(a, axis=axis)
4846
expected = np_op(a, axis=axis)
4947
assert_equal(result, expected)
@@ -52,10 +50,10 @@ def test_reduction_api_bitwise(a, finch_op, np_op, axis):
5250
@pytest.mark.parametrize(
5351
"a",
5452
[
55-
(np.array([[2, 0], [1, 3]])),
56-
(np.array([[2, 3, 4], [5, 6, 7]])),
53+
(np.array([[2, 0], [-1, 3]])),
54+
(np.array([[2, 3, 4], [5, -6, 7]])),
5755
(np.array([[1, 0, 3, 8], [0, 0, 10, 0]])),
58-
(np.array([[100, 14, 9, 78], [44, 3, 5, 10]])),
56+
(np.array([[100, -14, 9, 78], [-44, 3, 5, 10]])),
5957
],
6058
)
6159
@pytest.mark.parametrize(
@@ -82,8 +80,9 @@ def test_reduction_api_integer(a, finch_op, np_op, axis):
8280
@pytest.mark.parametrize(
8381
"a",
8482
[
83+
(np.array([[1.00002, -12.618, 0, 0.001], [-1.414, -5.01, 0, 0]])),
8584
(np.array([[0, 0.618, 0, 0.001], [0, 0.01, 0, 0]])),
86-
(np.array([[10000.0, 1.0, 89.0, 78], [401.0, 3, 5, 10.2]])),
85+
(np.array([[10000.0, 1.0, -89.0, 78], [401.0, 3, 5, 10.2]])),
8786
],
8887
)
8988
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)