Skip to content

Commit

Permalink
Fix DataFrame/Series.rank for int and null data in mode.pandas_compatible (#17954)
Browse files Browse the repository at this point in the history

closes #17948

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #17954
  • Loading branch information
mroeschke authored Feb 8, 2025
1 parent fdb7e7d commit 61e47bb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
22 changes: 12 additions & 10 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5027,7 +5027,7 @@ def repeat(self, repeats, axis=None):

def astype(
self,
dtype: dict[Any, Dtype],
dtype: Dtype | dict[Any, Dtype],
copy: bool = False,
errors: Literal["raise", "ignore"] = "raise",
) -> Self:
Expand Down Expand Up @@ -6340,13 +6340,13 @@ def _preprocess_subset(self, subset) -> set[abc.Hashable]:
@_performance_tracking
def rank(
self,
axis=0,
method="average",
numeric_only=False,
na_option="keep",
ascending=True,
pct=False,
):
axis: Literal[0, "index"] = 0,
method: Literal["average", "min", "max", "first", "dense"] = "average",
numeric_only: bool = False,
na_option: Literal["keep", "top", "bottom"] = "keep",
ascending: bool = True,
pct: bool = False,
) -> Self:
"""
Compute numerical data ranks (1 through n) along axis.
Expand Down Expand Up @@ -6404,7 +6404,7 @@ def rank(
if numeric_only:
if isinstance(
source, cudf.Series
) and not _is_non_decimal_numeric_dtype(self.dtype):
) and not _is_non_decimal_numeric_dtype(self.dtype): # type: ignore[attr-defined]
raise TypeError(
"Series.rank does not allow numeric_only=True with "
"non-numeric dtype."
Expand All @@ -6416,7 +6416,7 @@ def rank(
)
source = self._get_columns_by_label(numeric_cols)
if source.empty:
return source.astype("float64")
return source.astype(np.dtype(np.float64))
elif source._num_columns != num_cols:
dropped_cols = True

Expand Down Expand Up @@ -6449,6 +6449,8 @@ def rank(
else plc.types.NullPolicy.INCLUDE
)

if cudf.get_option("mode.pandas_compatible"):
source = source.nans_to_nulls()
with acquire_spill_lock():
result_columns = [
libcudf.column.Column.from_pylibcudf(
Expand Down
14 changes: 13 additions & 1 deletion python/cudf/cudf/tests/test_rank.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.

from itertools import chain, combinations_with_replacement, product

import numpy as np
import pandas as pd
import pytest

import cudf
from cudf import DataFrame
from cudf.testing import assert_eq
from cudf.testing._utils import assert_exceptions_equal
Expand Down Expand Up @@ -151,3 +152,14 @@ def test_series_rank_combinations(elem, dtype):
ranked_ps = df["a"].rank(method="first")
# Check
assert_eq(ranked_ps, ranked_gs)


@pytest.mark.parametrize("klass", ["Series", "DataFrame"])
def test_int_nan_pandas_compatible(klass):
    """Check that ranking integer data with a null matches pandas.

    The same Python list (integers plus one ``None``) is used to build
    both a pandas object and a cuDF object of the parametrized class.
    The cuDF rank is computed with ``mode.pandas_compatible`` enabled and
    compared against the pandas rank.
    """
    values = [3, 6, 1, 1, None, 6]
    pandas_input = getattr(pd, klass)(values)
    gpu_input = getattr(cudf, klass)(values)

    # Only the cuDF computation needs the compatibility option switched on.
    with cudf.option_context("mode.pandas_compatible", True):
        got = gpu_input.rank()

    want = pandas_input.rank()
    assert_eq(got, want)

0 comments on commit 61e47bb

Please sign in to comment.