Skip to content

Commit

Permalink
[BACKPORT] Fix incorrect result for df.sort_values when specifying multiple ascending (#2984) (#3006)
Browse files Browse the repository at this point in the history

Co-authored-by: He Kaisheng <[email protected]>
  • Loading branch information
wjsi and hekaisheng authored May 7, 2022
1 parent 85331e8 commit e550ae4
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 169 deletions.
4 changes: 2 additions & 2 deletions mars/core/operand/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def __init__(self: OperandType, *args, **kwargs):
extra_names = (
set(kwargs) - set(self._FIELDS) - set(SchedulingHint.all_hint_names)
)
extras = AttributeDict((k, kwargs.pop(k)) for k in extra_names)
kwargs["extra_params"] = kwargs.pop("extra_params", extras)
extras = dict((k, kwargs.pop(k)) for k in extra_names)
kwargs["extra_params"] = AttributeDict(kwargs.pop("extra_params", extras))
self._extract_scheduling_hint(kwargs)
super().__init__(*args, **kwargs)

Expand Down
2 changes: 2 additions & 0 deletions mars/core/operand/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class MyOperand5(MyOperand4):


def test_execute():
op = MyOperand(extra_params={"my_extra_params": 1})
assert op.extra_params["my_extra_params"] == 1
MyOperand.register_executor(lambda *_: 2)
assert execute(dict(), MyOperand(_key="1")) == 2
assert execute(dict(), MyOperand2(_key="1")) == 2
Expand Down
7 changes: 4 additions & 3 deletions mars/dataframe/sort/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
from ...core.operand import OperandStage
from ...serialization.serializables import (
FieldTypes,
AnyField,
BoolField,
Int32Field,
Int64Field,
StringField,
ListField,
BoolField,
StringField,
)
from ...utils import ceildiv
from ..operands import DataFrameOperand
Expand All @@ -32,7 +33,7 @@

class DataFrameSortOperand(DataFrameOperand):
_axis = Int32Field("axis")
_ascending = BoolField("ascending")
_ascending = AnyField("ascending")
_inplace = BoolField("inplace")
_kind = StringField("kind")
_na_position = StringField("na_position")
Expand Down
224 changes: 65 additions & 159 deletions mars/dataframe/sort/psrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
from ... import opcodes as OperandDef
from ...core.operand import OperandStage, MapReduceOperand
from ...utils import lazy_import
from ...serialization.serializables import Int32Field, ListField, StringField, BoolField
from ...serialization.serializables import (
AnyField,
Int32Field,
ListField,
StringField,
BoolField,
)
from ...tensor.base.psrs import PSRSOperandMixin
from ..core import IndexValue, OutputType
from ..utils import standardize_range_index, parse_index, is_cudf
Expand Down Expand Up @@ -48,6 +54,23 @@ def __gt__(self, other):
_largest = _Largest()


class _ReversedValue:
def __init__(self, value):
self._value = value

def __lt__(self, other):
if type(other) is _ReversedValue:
# may happen when call searchsorted
return self._value >= other._value
return self._value >= other

def __gt__(self, other):
return self._value <= other

def __repr__(self):
return repr(self._value)


class DataFramePSRSOperandMixin(DataFrameOperandMixin, PSRSOperandMixin):
@classmethod
def _collect_op_properties(cls, op):
Expand Down Expand Up @@ -377,90 +400,23 @@ def execute_sort_index(data, op, inplace=None):

class DataFramePSRSChunkOperand(DataFrameOperand):
    """Base operand for the per-chunk stages of PSRS sorting.

    Parameters are declared as serializable field descriptors, so the
    boilerplate ``__init__`` keyword plumbing and per-field ``@property``
    accessors of the old implementation are unnecessary.
    """

    # sort type could be 'sort_values' or 'sort_index'
    sort_type = StringField("sort_type")

    axis = Int32Field("axis")
    by = ListField("by", default=None)
    # bool, or a list of bool when a per-key sort direction is specified;
    # AnyField (not BoolField) so the list form serializes correctly
    ascending = AnyField("ascending")
    inplace = BoolField("inplace")
    kind = StringField("kind")
    na_position = StringField("na_position")

    # for sort_index
    level = ListField("level")
    sort_remaining = BoolField("sort_remaining")

    n_partition = Int32Field("n_partition")

    def __init__(self, output_types=None, **kw):
        # map the public output_types kwarg onto the private field name
        super().__init__(_output_types=output_types, **kw)


class DataFramePSRSSortRegularSample(DataFramePSRSChunkOperand, DataFrameOperandMixin):
Expand Down Expand Up @@ -564,99 +520,49 @@ def execute(cls, ctx, op):
class DataFramePSRSShuffle(MapReduceOperand, DataFrameOperandMixin):
    """Shuffle operand of the PSRS sort (map: partition by pivots,
    reduce: merge sorted partitions).

    Fields are declarative descriptors; no per-field properties or
    ``__init__`` keyword plumbing needed.
    """

    _op_type_ = OperandDef.PSRS_SHUFFLE

    sort_type = StringField("sort_type")

    # for shuffle map
    axis = Int32Field("axis")
    by = ListField("by")
    # bool, or a list of bool for mixed per-key sort directions
    ascending = AnyField("ascending")
    inplace = BoolField("inplace")
    na_position = StringField("na_position")
    n_partition = Int32Field("n_partition")

    # for sort_index
    level = ListField("level")
    sort_remaining = BoolField("sort_remaining")

    # for shuffle reduce
    kind = StringField("kind")

    def __init__(self, output_types=None, **kw):
        # map the public output_types kwarg onto the private field name
        super().__init__(_output_types=output_types, **kw)

    @property
    def output_limit(self):
        # shuffle produces exactly one logical output
        return 1

@staticmethod
def _calc_poses(src_cols, pivots, ascending=True):
if isinstance(ascending, list):
for asc, col in zip(ascending, pivots.columns):
# Make pivots available to use ascending order when mixed order specified
if not asc:
if pd.api.types.is_numeric_dtype(pivots.dtypes[col]):
# for numeric dtypes, convert to negative is more efficient
pivots[col] = -pivots[col]
src_cols[col] = -src_cols[col]
else:
# for other types, convert to ReversedValue
pivots[col] = pivots[col].map(
lambda x: x
if type(x) is _ReversedValue
else _ReversedValue(x)
)
ascending = True

records = src_cols.to_records(index=False)
p_records = pivots.to_records(index=False)
if ascending:
Expand Down
7 changes: 7 additions & 0 deletions mars/dataframe/sort/sort_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,13 @@ def dataframe_sort_values(
raise NotImplementedError("Only support sort on axis 0")
psrs_kinds = _validate_sort_psrs_kinds(psrs_kinds)
by = by if isinstance(by, (list, tuple)) else [by]
if isinstance(ascending, list): # pragma: no cover
if all(ascending):
# all are True, convert to True
ascending = True
elif not any(ascending):
# all are False, convert to False
ascending = False
op = DataFrameSortValues(
by=by,
axis=axis,
Expand Down
37 changes: 34 additions & 3 deletions mars/dataframe/sort/tests/test_sort_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@
"distinct_opt", ["0"] if sys.platform.lower().startswith("win") else ["0", "1"]
)
def test_sort_values_execution(setup, distinct_opt):
ns = np.random.RandomState(0)
os.environ["PSRS_DISTINCT_COL"] = distinct_opt
df = pd.DataFrame(
np.random.rand(100, 10), columns=["a" + str(i) for i in range(10)]
)
df = pd.DataFrame(ns.rand(100, 10), columns=["a" + str(i) for i in range(10)])

# test one chunk
mdf = DataFrame(df)
Expand Down Expand Up @@ -67,6 +66,38 @@ def test_sort_values_execution(setup, distinct_opt):

pd.testing.assert_frame_equal(result, expected)

# test ascending is a list
result = (
mdf.sort_values(["a3", "a4", "a5", "a6"], ascending=[False, True, True, False])
.execute()
.fetch()
)
expected = df.sort_values(
["a3", "a4", "a5", "a6"], ascending=[False, True, True, False]
)
pd.testing.assert_frame_equal(result, expected)

in_df = pd.DataFrame(
{
"col1": ns.choice([f"a{i}" for i in range(5)], size=(100,)),
"col2": ns.choice([f"b{i}" for i in range(5)], size=(100,)),
"col3": ns.choice([f"c{i}" for i in range(5)], size=(100,)),
"col4": ns.randint(10, 20, size=(100,)),
}
)
mdf = DataFrame(in_df, chunk_size=10)
result = (
mdf.sort_values(
["col1", "col4", "col3", "col2"], ascending=[False, False, True, False]
)
.execute()
.fetch()
)
expected = in_df.sort_values(
["col1", "col4", "col3", "col2"], ascending=[False, False, True, False]
)
pd.testing.assert_frame_equal(result, expected)

# test multiindex
df2 = df.copy(deep=True)
df2.columns = pd.MultiIndex.from_product([list("AB"), list("CDEFG")])
Expand Down
9 changes: 8 additions & 1 deletion mars/oscar/backends/mars/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,14 @@ async def kill_sub_pool(
await asyncio.to_thread(process.join, 5)

async def is_sub_pool_alive(self, process: multiprocessing.Process):
    """Return True if the sub-pool *process* is still alive.

    The liveness check is off-loaded to a worker thread via
    ``asyncio.to_thread`` so the event loop is not blocked.

    The original span carried an unreachable duplicate
    ``return await asyncio.to_thread(...)`` line above the try-block
    (diff residue); it is removed here so the fallback path is live.
    """
    try:
        return await asyncio.to_thread(process.is_alive)
    except RuntimeError as ex:  # pragma: no cover
        if "shutdown" not in str(ex):
            raise
        # when atexit is triggered, the default executor might be shut
        # down and to_thread will fail; probe synchronously instead
        return process.is_alive()

async def recover_sub_pool(self, address: str):
process_index = self._config.get_process_index(address)
Expand Down
Loading

0 comments on commit e550ae4

Please sign in to comment.