Skip to content

Commit db57900

Browse files
authored
Ensure min_periods=0 is passed through rolling aggregations (rapidsai#20653)
Contributes to rapidsai#18659 Authors: - Matthew Murray (https://github.com/Matt711) Approvers: - Tom Augspurger (https://github.com/TomAugspurger) URL: rapidsai#20653
1 parent 918189f commit db57900

File tree

3 files changed

+59
-40
lines changed

3 files changed

+59
-40
lines changed

python/cudf/cudf/core/window/rolling.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def _apply_agg_column(
373373
source_column.plc_column,
374374
pre,
375375
fwd,
376-
self.min_periods or 1,
376+
1 if self.min_periods is None else self.min_periods,
377377
rolling_agg,
378378
)
379379
)
@@ -442,9 +442,7 @@ def count(self) -> DataFrame | Series:
442442
return self._apply_agg("count")
443443

444444
def median(self, **kwargs):
445-
raise NotImplementedError(
446-
"groupby().rolling().median() is not yet implemented"
447-
)
445+
raise NotImplementedError("Rolling.median() is not yet implemented")
448446

449447
def apply(self, func, *args, **kwargs) -> DataFrame | Series:
450448
"""

python/cudf/cudf/pandas/scripts/conftest-patch.py

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8034,42 +8034,10 @@ def pytest_unconfigure(config):
80348034
"tests/window/moments/test_moments_consistency_expanding.py::test_expanding_apply_consistency_sum_nans[all_data6-2-sum]",
80358035
"tests/window/moments/test_moments_consistency_expanding.py::test_expanding_apply_consistency_sum_nans[all_data7-0-sum]",
80368036
"tests/window/moments/test_moments_consistency_expanding.py::test_expanding_apply_consistency_sum_nans[all_data7-2-sum]",
8037-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data1-rolling_consistency_cases0-False-<lambda>]",
8038-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data1-rolling_consistency_cases0-False-nansum]",
8039-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data1-rolling_consistency_cases0-True-<lambda>]",
8040-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data1-rolling_consistency_cases0-True-nansum]",
8041-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data11-rolling_consistency_cases0-False-<lambda>]",
8042-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data11-rolling_consistency_cases0-False-nansum]",
8043-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data11-rolling_consistency_cases0-True-<lambda>]",
8044-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data11-rolling_consistency_cases0-True-nansum]",
8045-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data15-rolling_consistency_cases0-False-<lambda>]",
8046-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data15-rolling_consistency_cases0-False-nansum]",
8047-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data15-rolling_consistency_cases0-True-<lambda>]",
8048-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data15-rolling_consistency_cases0-True-nansum]",
8049-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data16-rolling_consistency_cases0-False-<lambda>]",
8050-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data16-rolling_consistency_cases0-False-nansum]",
8051-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data16-rolling_consistency_cases0-True-<lambda>]",
8052-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data16-rolling_consistency_cases0-True-nansum]",
8053-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data17-rolling_consistency_cases0-False-<lambda>]",
8054-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data17-rolling_consistency_cases0-False-nansum]",
8055-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data17-rolling_consistency_cases0-True-<lambda>]",
8056-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data17-rolling_consistency_cases0-True-nansum]",
8057-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data5-rolling_consistency_cases0-False-<lambda>]",
8058-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data5-rolling_consistency_cases0-False-nansum]",
8059-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data5-rolling_consistency_cases0-True-<lambda>]",
8060-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data5-rolling_consistency_cases0-True-nansum]",
80618037
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data5-rolling_consistency_cases1-False-sum]",
80628038
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data5-rolling_consistency_cases1-True-sum]",
8063-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data6-rolling_consistency_cases0-False-<lambda>]",
8064-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data6-rolling_consistency_cases0-False-nansum]",
8065-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data6-rolling_consistency_cases0-True-<lambda>]",
8066-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data6-rolling_consistency_cases0-True-nansum]",
80678039
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data6-rolling_consistency_cases1-False-sum]",
80688040
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data6-rolling_consistency_cases1-True-sum]",
8069-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data7-rolling_consistency_cases0-False-<lambda>]",
8070-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data7-rolling_consistency_cases0-False-nansum]",
8071-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data7-rolling_consistency_cases0-True-<lambda>]",
8072-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data7-rolling_consistency_cases0-True-nansum]",
80738041
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data7-rolling_consistency_cases1-False-sum]",
80748042
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data7-rolling_consistency_cases1-True-sum]",
80758043
"tests/window/test_api.py::test_rolling_max_min_periods[None]",
@@ -8146,7 +8114,6 @@ def pytest_unconfigure(config):
81468114
"tests/window/test_rolling.py::test_closed_fixed_binary_col[True-10]",
81478115
"tests/window/test_rolling.py::test_closed_fixed_binary_col[True-2]",
81488116
"tests/window/test_rolling.py::test_closed_fixed_binary_col[True-5]",
8149-
"tests/window/test_rolling.py::test_missing_minp_zero",
81508117
"tests/window/test_rolling.py::test_rolling_non_monotonic[mean-expected1]",
81518118
"tests/window/test_rolling.py::test_rolling_non_monotonic[sum-expected2]",
81528119
"tests/window/test_rolling.py::test_rolling_non_monotonic[var-expected0]",
@@ -8158,10 +8125,12 @@ def pytest_unconfigure(config):
81588125
"tests/window/test_rolling.py::test_rolling_numerical_accuracy_kahan_mean[s-2.0]",
81598126
"tests/window/test_rolling.py::test_rolling_numerical_accuracy_kahan_mean[us-0.0]",
81608127
"tests/window/test_rolling.py::test_rolling_numerical_accuracy_kahan_mean[us-2.0]",
8161-
"tests/window/test_rolling.py::test_rolling_sum_all_nan_window_floating_artifacts",
81628128
"tests/window/test_rolling.py::test_rolling_var_same_value_count_logic[values0-3-1-expected0]",
81638129
"tests/window/test_rolling.py::test_variable_window_nonunique[DataFrame-right-expected2]",
81648130
"tests/window/test_rolling.py::test_variable_window_nonunique[Series-right-expected2]",
8131+
"tests/window/test_rolling_functions.py::test_nans[mean-mean-kwargs0]",
8132+
"tests/window/test_rolling_functions.py::test_nans[min-min-kwargs3]",
8133+
"tests/window/test_rolling_functions.py::test_nans[max-max-kwargs4]",
81658134
"tests/window/test_rolling_functions.py::test_rolling_max_gh6297[10]",
81668135
"tests/window/test_rolling_functions.py::test_rolling_max_gh6297[1]",
81678136
"tests/window/test_rolling_functions.py::test_rolling_max_gh6297[2]",
@@ -8195,8 +8164,8 @@ def pytest_unconfigure(config):
81958164
"tests/window/test_timeseries_window.py::TestRollingTS::test_rolling_on_decreasing_index[s]",
81968165
"tests/window/test_timeseries_window.py::TestRollingTS::test_rolling_on_decreasing_index[us]",
81978166
"tests/window/test_timeseries_window.py::TestRollingTS::test_rolling_on_empty",
8198-
"tests/window/test_win_type.py::test_cmov_window_corner[None]",
81998167
"tests/window/test_win_type.py::test_invalid_scipy_arg",
8168+
"tests/window/test_win_type.py::test_cmov_window_corner[None]",
82008169
"tests/window/test_win_type.py::test_win_type_not_implemented",
82018170
# dtype mismatch: cudf.pandas: pd.Index, pandas: pd.CategoricalIndex
82028171
r"""tests/io/json/test_pandas.py::TestPandasContainer::test_json_indent_all_orients[table-{\n "schema":{\n "fields":[\n {\n "name":"index",\n "type":"integer"\n },\n {\n "name":"a",\n "type":"string"\n },\n {\n "name":"b",\n "type":"string"\n }\n ],\n "primaryKey":[\n "index"\n ],\n "pandas_version":"1.4.0"\n },\n "data":[\n {\n "index":0,\n "a":"foo",\n "b":"bar"\n },\n {\n "index":1,\n "a":"baz",\n "b":"qux"\n }\n ]\n}]""",
@@ -8552,7 +8521,51 @@ def pytest_unconfigure(config):
85528521
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data6-rolling_consistency_cases0-False-sum]",
85538522
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data6-rolling_consistency_cases0-True-sum]",
85548523
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data7-rolling_consistency_cases0-False-sum]",
8555-
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_apply_consistency_sum[all_data7-rolling_consistency_cases0-True-sum]",
8524+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data1-rolling_consistency_cases0-True-0]",
8525+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data1-rolling_consistency_cases0-False-0]",
8526+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data5-rolling_consistency_cases0-True-0]",
8527+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data5-rolling_consistency_cases0-False-0]",
8528+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data6-rolling_consistency_cases0-True-0]",
8529+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data6-rolling_consistency_cases0-False-0]",
8530+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data7-rolling_consistency_cases0-True-0]",
8531+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data7-rolling_consistency_cases0-False-0]",
8532+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data11-rolling_consistency_cases0-True-0]",
8533+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data11-rolling_consistency_cases0-False-0]",
8534+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data15-rolling_consistency_cases0-True-0]",
8535+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data15-rolling_consistency_cases0-False-0]",
8536+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data16-rolling_consistency_cases0-True-0]",
8537+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data16-rolling_consistency_cases0-False-0]",
8538+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data17-rolling_consistency_cases0-True-0]",
8539+
"tests/window/moments/test_moments_consistency_rolling.py::test_moments_consistency_var[all_data17-rolling_consistency_cases0-False-0]",
8540+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_series_cov_corr[series_data1-rolling_consistency_cases0-True-0]",
8541+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_series_cov_corr[series_data5-rolling_consistency_cases0-True-0]",
8542+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_series_cov_corr[series_data5-rolling_consistency_cases0-False-0]",
8543+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_series_cov_corr[series_data6-rolling_consistency_cases0-True-0]",
8544+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_series_cov_corr[series_data7-rolling_consistency_cases0-True-0]",
8545+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_series_cov_corr[series_data6-rolling_consistency_cases0-False-0]",
8546+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_series_cov_corr[series_data7-rolling_consistency_cases0-False-0]",
8547+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data1-rolling_consistency_cases0-True]",
8548+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data1-rolling_consistency_cases0-False]",
8549+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data5-rolling_consistency_cases0-True]",
8550+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data5-rolling_consistency_cases0-False]",
8551+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data6-rolling_consistency_cases0-True]",
8552+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data6-rolling_consistency_cases0-False]",
8553+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data7-rolling_consistency_cases0-True]",
8554+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data7-rolling_consistency_cases0-False]",
8555+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data11-rolling_consistency_cases0-True]",
8556+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data11-rolling_consistency_cases0-False]",
8557+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data15-rolling_consistency_cases0-True]",
8558+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data15-rolling_consistency_cases0-False]",
8559+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data16-rolling_consistency_cases0-True]",
8560+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data16-rolling_consistency_cases0-False]",
8561+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data17-rolling_consistency_cases0-True]",
8562+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_mean[all_data17-rolling_consistency_cases0-False]",
8563+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_constant[consistent_data1-rolling_consistency_cases0-True]",
8564+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_constant[consistent_data1-rolling_consistency_cases0-False]",
8565+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_constant[consistent_data3-rolling_consistency_cases0-True]",
8566+
"tests/window/moments/test_moments_consistency_rolling.py::test_rolling_consistency_constant[consistent_data3-rolling_consistency_cases0-False]",
8567+
"tests/window/test_rolling.py::test_rolling_mean_all_nan_window_floating_artifacts[1-exp_values0]",
8568+
"tests/window/test_rolling.py::test_rolling_mean_all_nan_window_floating_artifacts[2-exp_values1]",
85568569
}
85578570

85588571
# TODO: Investigate why sometimes these fail

python/cudf/cudf/tests/window/test_rolling.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,3 +536,11 @@ def test_groupby_rolling_pickleable():
536536
df = cudf.DataFrame({"a": [1, 1, 2], "b": [1, 2, 3]})
537537
gb_rolling = pickle.loads(pickle.dumps(df.groupby("a").rolling(2)))
538538
assert_eq(gb_rolling.obj, cudf.DataFrame({"b": [1, 2, 3]}))
539+
540+
541+
def test_rolling_min_periods_zero():
542+
s = cudf.Series([np.nan, 1.0, 2.0, 3.0])
543+
ps = s.to_pandas()
544+
result = s.rolling(2, min_periods=0).sum()
545+
expected = ps.rolling(2, min_periods=0).sum()
546+
assert_eq(result, expected)

0 commit comments

Comments
 (0)