Fix mt.{cumsum, cumprod} when the first chunk is empty (#3134)
Xuye (Chris) Qin committed Jun 12, 2022
1 parent 221e4b3 commit 424cfb9
Showing 5 changed files with 29 additions and 3 deletions.
7 changes: 6 additions & 1 deletion mars/tensor/arithmetic/add.py
@@ -18,6 +18,7 @@
 from functools import reduce
 
 from ... import opcodes as OperandDef
+from ...serialization.serializables import BoolField
 from ..array_utils import device, as_same_device
 from ..datasource import scalar
 from ..utils import infer_dtype
@@ -89,17 +90,21 @@ class TensorTreeAdd(TensorMultiOp):
     _op_type_ = OperandDef.TREE_ADD
     _func_name = "add"
 
+    ignore_empty_input = BoolField("ignore_empty_input", default=False)
+
     @classmethod
     def _is_sparse(cls, *args):
         if args and all(hasattr(x, "issparse") and x.issparse() for x in args):
             return True
         return False
 
     @classmethod
-    def execute(cls, ctx, op):
+    def execute(cls, ctx, op: "TensorTreeAdd"):
         inputs, device_id, xp = as_same_device(
             [ctx[c.key] for c in op.inputs], device=op.device, ret_extra=True
         )
+        if op.ignore_empty_input:
+            inputs = [inp for inp in inputs if not hasattr(inp, "size") or inp.size > 0]
 
         with device(device_id):
             ctx[op.outputs[0].key] = reduce(xp.add, inputs)
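A minimal plain-NumPy sketch (illustrative only, not part of this commit) of the failure mode the new ignore_empty_input flag guards against: reducing over chunk data that contains an empty array fails to broadcast, while dropping empty operands first, as TensorTreeAdd.execute now does, gives the expected result. The same filter is applied to TensorTreeMultiply below.

from functools import reduce

import numpy as np

chunk_data = [np.empty(0), np.arange(3.0), np.arange(3.0)]

try:
    reduce(np.add, chunk_data)
except ValueError as err:
    # shapes (0,) and (3,) cannot be broadcast together
    print("unfiltered reduce failed:", err)

# Same filter as the commit: keep only operands that actually hold data.
filtered = [x for x in chunk_data if not hasattr(x, "size") or x.size > 0]
print(reduce(np.add, filtered))  # [0. 2. 4.]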
5 changes: 5 additions & 0 deletions mars/tensor/arithmetic/multiply.py
@@ -18,6 +18,7 @@
 from functools import reduce
 
 from ... import opcodes as OperandDef
+from ...serialization.serializables import BoolField
 from ..array_utils import device, as_same_device
 from ..datasource import scalar
 from ..utils import infer_dtype
@@ -88,6 +89,8 @@ class TensorTreeMultiply(TensorMultiOp):
     _op_type_ = OperandDef.TREE_MULTIPLY
     _func_name = "multiply"
 
+    ignore_empty_input = BoolField("ignore_empty_input", default=False)
+
     def __init__(self, sparse=False, **kw):
         super().__init__(sparse=sparse, **kw)
 
@@ -106,6 +109,8 @@ def execute(cls, ctx, op):
         inputs, device_id, xp = as_same_device(
             [ctx[c.key] for c in op.inputs], device=op.device, ret_extra=True
         )
+        if op.ignore_empty_input:
+            inputs = [inp for inp in inputs if not hasattr(inp, "size") or inp.size > 0]
 
         with device(device_id):
             ctx[op.outputs[0].key] = reduce(xp.multiply, inputs)
6 changes: 5 additions & 1 deletion mars/tensor/reduction/core.py
@@ -586,7 +586,11 @@ def tile(cls, op):
                 to_cum_chunks.append(sliced_chunk)
             to_cum_chunks.append(chunk)
 
-            bin_op = bin_op_type(args=to_cum_chunks, dtype=chunk.dtype)
+            # GH#3132: some chunks of to_cum_chunks may be empty,
+            # so we tell tree_add & tree_multiply to ignore them
+            bin_op = bin_op_type(
+                args=to_cum_chunks, dtype=chunk.dtype, ignore_empty_input=True
+            )
             output_chunk = bin_op.new_chunk(
                 to_cum_chunks,
                 shape=chunk.shape,
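A rough plain-NumPy sketch (illustrative, not the actual Mars tiling code; the variable names are made up) of the chunked cumulative-sum scheme that tile() builds: each output chunk combines the chunk's own cumsum with the running totals (last cumsum element) of every preceding chunk. When the first chunk is empty, as in GH#3132, its tail slice has size zero, and that is exactly the operand ignore_empty_input tells the tree operator to drop.

from functools import reduce

import numpy as np

chunks = [np.empty(0), np.random.rand(100)]   # first chunk is empty, as in GH#3132
partials = [np.cumsum(c) for c in chunks]     # per-chunk cumulative sums

out_chunks = []
for i, part in enumerate(partials):
    tails = [p[-1:] for p in partials[:i]]              # totals of preceding chunks
    terms = [t for t in tails + [part] if t.size > 0]   # drop empty operands
    out_chunks.append(reduce(np.add, terms) if terms else part)

np.testing.assert_allclose(
    np.concatenate(out_chunks), np.cumsum(np.concatenate(chunks))
)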
10 changes: 10 additions & 0 deletions mars/tensor/reduction/tests/test_reduction_execution.py
@@ -497,6 +497,16 @@ def test_cum_reduction(setup):
         np.cumsum(np.array(list("abcdefghi"), dtype=object)),
     )
 
+    # test empty chunks
+    raw = np.random.rand(100)
+    arr = tensor(raw, chunk_size=((0, 100),))
+    res = arr.cumsum().execute().fetch()
+    expected = raw.cumsum()
+    np.testing.assert_allclose(res, expected)
+    res = arr.cumprod().execute().fetch()
+    expected = raw.cumprod()
+    np.testing.assert_allclose(res, expected)
+
 
 def test_nan_cum_reduction(setup):
     raw = np.random.randint(5, size=(8, 8, 8)).astype(float)
4 changes: 3 additions & 1 deletion mars/tensor/reshape/reshape.py
@@ -603,7 +603,9 @@ def reshape(a, newshape, order="C"):

     tensor_order = get_order(order, a.order, available_options="CFA")
 
-    if a.shape == newshape and tensor_order == a.order:
+    if a.shape == newshape and (
+        a.ndim <= 1 or (a.ndim > 1 and tensor_order == a.order)
+    ):
         # does not need to reshape
         return a
     return _reshape(
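A small aside in plain NumPy (not part of the commit) on why the order comparison can be skipped when ndim <= 1: a 1-D or 0-D array is simultaneously C- and F-contiguous, so a same-shape reshape with any requested order is a no-op and the tensor can be returned as is.

import numpy as np

x = np.arange(5)
print(x.flags["C_CONTIGUOUS"], x.flags["F_CONTIGUOUS"])  # True True
y = x.reshape(x.shape, order="F")                         # same shape, "different" order
print(np.shares_memory(x, y))                             # True: nothing was copied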
