diff --git a/doc/api.rst b/doc/api.rst index 67c81aaf601..b67eafecc76 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -1329,6 +1329,8 @@ Grouper Objects groupers.BinGrouper groupers.UniqueGrouper groupers.TimeResampler + groupers.SeasonGrouper + groupers.SeasonResampler Rolling objects diff --git a/doc/conf.py b/doc/conf.py index d4328dbf1b0..43afc1253e5 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -187,6 +187,8 @@ "pd.NaT": "~pandas.NaT", } +autodoc_type_aliases = napoleon_type_aliases # Keep both in sync + # mermaid config mermaid_version = "10.9.1" diff --git a/doc/user-guide/groupby.rst b/doc/user-guide/groupby.rst index 7cb4e883347..673e23d75ac 100644 --- a/doc/user-guide/groupby.rst +++ b/doc/user-guide/groupby.rst @@ -332,6 +332,14 @@ Different groupers can be combined to construct sophisticated GroupBy operations ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum() +Time Grouping and Resampling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. seealso:: + + See :ref:`resampling`. + + Shuffling ~~~~~~~~~ diff --git a/doc/user-guide/time-series.rst b/doc/user-guide/time-series.rst index d131ae74b9f..cb3e94e3645 100644 --- a/doc/user-guide/time-series.rst +++ b/doc/user-guide/time-series.rst @@ -1,3 +1,5 @@ +.. currentmodule:: xarray + .. _time-series: ================ @@ -21,12 +23,12 @@ core functionality. Creating datetime64 data ------------------------ -Xarray uses the numpy dtypes ``datetime64[unit]`` and ``timedelta64[unit]`` -(where unit is one of ``"s"``, ``"ms"``, ``"us"`` and ``"ns"``) to represent datetime +Xarray uses the numpy dtypes :py:class:`numpy.datetime64` and :py:class:`numpy.timedelta64` +with specified units (one of ``"s"``, ``"ms"``, ``"us"`` and ``"ns"``) to represent datetime data, which offer vectorized operations with numpy and smooth integration with pandas. -To convert to or create regular arrays of ``datetime64`` data, we recommend -using :py:func:`pandas.to_datetime` and :py:func:`pandas.date_range`: +To convert to or create regular arrays of :py:class:`numpy.datetime64` data, we recommend +using :py:func:`pandas.to_datetime`, :py:class:`pandas.DatetimeIndex`, or :py:func:`xarray.date_range`: .. ipython:: python @@ -34,13 +36,6 @@ using :py:func:`pandas.to_datetime` and :py:func:`pandas.date_range`: pd.DatetimeIndex( ["2000-01-01 00:00:00", "2000-02-02 00:00:00"], dtype="datetime64[s]" ) - pd.date_range("2000-01-01", periods=365) - pd.date_range("2000-01-01", periods=365, unit="s") - -It is also possible to use corresponding :py:func:`xarray.date_range`: - -.. ipython:: python - xr.date_range("2000-01-01", periods=365) xr.date_range("2000-01-01", periods=365, unit="s") @@ -81,7 +76,7 @@ attribute like ``'days since 2000-01-01'``). You can manual decode arrays in this form by passing a dataset to -:py:func:`~xarray.decode_cf`: +:py:func:`decode_cf`: .. ipython:: python @@ -93,8 +88,8 @@ You can manual decode arrays in this form by passing a dataset to coder = xr.coders.CFDatetimeCoder(time_unit="s") xr.decode_cf(ds, decode_times=coder) -From xarray 2025.01.2 the resolution of the dates can be one of ``"s"``, ``"ms"``, ``"us"`` or ``"ns"``. One limitation of using ``datetime64[ns]`` is that it limits the native representation of dates to those that fall between the years 1678 and 2262, which gets increased significantly with lower resolutions. 
When a store contains dates outside of these bounds (or dates < `1582-10-15`_ with a Gregorian, also known as standard, calendar), dates will be returned as arrays of :py:class:`cftime.datetime` objects and a :py:class:`~xarray.CFTimeIndex` will be used for indexing. -:py:class:`~xarray.CFTimeIndex` enables most of the indexing functionality of a :py:class:`pandas.DatetimeIndex`. +Starting with xarray 2025.01.2, the resolution of the dates can be one of ``"s"``, ``"ms"``, ``"us"`` or ``"ns"``. One limitation of using ``datetime64[ns]`` is that it limits the native representation of dates to those that fall between the years 1678 and 2262; this range widens significantly at coarser resolutions. When a store contains dates outside of these bounds (or dates < `1582-10-15`_ with a Gregorian, also known as standard, calendar), dates will be returned as arrays of :py:class:`cftime.datetime` objects and a :py:class:`CFTimeIndex` will be used for indexing. +:py:class:`CFTimeIndex` enables most of the indexing functionality of a :py:class:`pandas.DatetimeIndex`. See :ref:`CFTimeIndex` for more information. Datetime indexing @@ -205,35 +200,37 @@ You can also search for multiple months (in this case January through March), us Resampling and grouped operations --------------------------------- -Datetime components couple particularly well with grouped operations (see -:ref:`groupby`) for analyzing features that repeat over time. Here's how to -calculate the mean by time of day: + +.. seealso:: + + For more generic documentation on grouping, see :ref:`groupby`. + + +Datetime components couple particularly well with grouped operations for analyzing features that repeat over time. +Here's how to calculate the mean by time of day: .. ipython:: python - :okwarning: ds.groupby("time.hour").mean() For upsampling or downsampling temporal resolutions, xarray offers a -:py:meth:`~xarray.Dataset.resample` method building on the core functionality +:py:meth:`Dataset.resample` method building on the core functionality offered by the pandas method of the same name. Resample uses essentially the -same api as ``resample`` `in pandas`_. +same API as :py:meth:`pandas.DataFrame.resample` `in pandas`_. .. _in pandas: https://pandas.pydata.org/pandas-docs/stable/timeseries.html#up-and-downsampling For example, we can downsample our dataset from hourly to 6-hourly: .. ipython:: python - :okwarning: ds.resample(time="6h") -This will create a specialized ``Resample`` object which saves information -necessary for resampling. All of the reduction methods which work with -``Resample`` objects can also be used for resampling: +This will create a specialized :py:class:`~xarray.core.resample.DatasetResample` or :py:class:`~xarray.core.resample.DataArrayResample` +object that stores the information necessary for resampling. All of the reduction methods that work with +:py:class:`Dataset` or :py:class:`DataArray` objects can also be used for resampling: .. ipython:: python - :okwarning: ds.resample(time="6h").mean() @@ -252,7 +249,7 @@ by specifying the ``dim`` keyword argument ds.resample(time="6h").mean(dim=["time", "latitude", "longitude"]) For upsampling, xarray provides six methods: ``asfreq``, ``ffill``, ``bfill``, ``pad``, -``nearest`` and ``interpolate``. ``interpolate`` extends ``scipy.interpolate.interp1d`` +``nearest`` and ``interpolate``. ``interpolate`` extends :py:func:`scipy.interpolate.interp1d` and supports all of its schemes.
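+A brief, hypothetical illustration of upsampling (this assumes the hourly ``ds`` used in the examples above; the ``"30min"`` target frequency and the choice of ``ffill`` are arbitrary, and any of the six methods can stand in for it): + +.. ipython:: python + + ds.resample(time="30min").ffill() +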
All of these resampling operations work on both Dataset and DataArray objects with an arbitrary number of dimensions. @@ -266,9 +263,7 @@ Data that has indices outside of the given ``tolerance`` are set to ``NaN``. It is often desirable to center the time values after a resampling operation. That can be accomplished by updating the resampled dataset time coordinate values -using time offset arithmetic via the `pandas.tseries.frequencies.to_offset`_ function. - -.. _pandas.tseries.frequencies.to_offset: https://pandas.pydata.org/docs/reference/api/pandas.tseries.frequencies.to_offset.html +using time offset arithmetic via the :py:func:`pandas.tseries.frequencies.to_offset` function. .. ipython:: python @@ -277,5 +272,80 @@ using time offset arithmetic via the `pandas.tseries.frequencies.to_offset`_ fun resampled_ds["time"] = resampled_ds.get_index("time") + offset resampled_ds -For more examples of using grouped operations on a time dimension, see -:doc:`../examples/weather-data`. + +.. seealso:: + + For more examples of using grouped operations on a time dimension, see :doc:`../examples/weather-data`. + + +Handling Seasons +~~~~~~~~~~~~~~~~ + +Two extremely common time series operations are grouping by seasons and resampling to a seasonal frequency. +Xarray has historically supported simple versions of these computations: +for example, ``.groupby("time.season")`` groups by the fixed seasons DJF, MAM, JJA, SON, +and ``.resample(time="QS-DEC")`` resamples to a seasonal frequency using pandas syntax. + +Quite commonly, one wants more flexibility in defining seasons. For these use cases, Xarray provides +:py:class:`groupers.SeasonGrouper` and :py:class:`groupers.SeasonResampler`. + + +.. currentmodule:: xarray.groupers + +.. ipython:: python + + from xarray.groupers import SeasonGrouper + + ds.groupby(time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"])).mean() + + +Note how the seasons are in the specified order, unlike ``.groupby("time.season")`` where the +seasons are sorted alphabetically. + +.. ipython:: python + + ds.groupby("time.season").mean() + + +:py:class:`SeasonGrouper` supports overlapping seasons: + +.. ipython:: python + + ds.groupby(time=SeasonGrouper(["DJFM", "MAMJ", "JJAS", "SOND"])).mean() + + +Skipping months is allowed: + +.. ipython:: python + + ds.groupby(time=SeasonGrouper(["JJAS"])).mean() + + +Use :py:class:`SeasonResampler` to resample to a seasonal frequency with custom seasons. + +.. ipython:: python + + from xarray.groupers import SeasonResampler + + ds.resample(time=SeasonResampler(["DJF", "MAM", "JJA", "SON"])).mean() + + +:py:class:`SeasonResampler` is smart enough to correctly handle years for seasons that +span the end of the year (e.g. DJF). By default :py:class:`SeasonResampler` will skip any +season that is incomplete (e.g. the first DJF season for a time series that starts in January). +Pass the ``drop_incomplete=False`` kwarg to :py:class:`SeasonResampler` to disable this behaviour. + +.. ipython:: python + + from xarray.groupers import SeasonResampler + + ds.resample( + time=SeasonResampler(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False) + ).mean() + + +Seasons need not be of the same length: + +.. 
ipython:: python + + ds.resample(time=SeasonResampler(["JF", "MAM", "JJAS", "OND"])).mean() diff --git a/properties/test_properties.py b/properties/test_properties.py index fc0a1955539..2ae91a15801 100644 --- a/properties/test_properties.py +++ b/properties/test_properties.py @@ -1,11 +1,15 @@ +import itertools + import pytest pytest.importorskip("hypothesis") -from hypothesis import given +import hypothesis.strategies as st +from hypothesis import given, note import xarray as xr import xarray.testing.strategies as xrst +from xarray.groupers import find_independent_seasons, season_to_month_tuple @given(attrs=xrst.simple_attrs) @@ -15,3 +19,45 @@ def test_assert_identical(attrs): ds = xr.Dataset(attrs=attrs) xr.testing.assert_identical(ds, ds.copy(deep=True)) + + +@given( + roll=st.integers(min_value=0, max_value=12), + breaks=st.lists( + st.integers(min_value=0, max_value=11), min_size=1, max_size=12, unique=True + ), +) +def test_property_season_month_tuple(roll, breaks): + chars = list("JFMAMJJASOND") + months = tuple(range(1, 13)) + + rolled_chars = chars[roll:] + chars[:roll] + rolled_months = months[roll:] + months[:roll] + breaks = sorted(breaks) + if breaks[0] != 0: + breaks = [0] + breaks + if breaks[-1] != 12: + breaks = breaks + [12] + seasons = tuple( + "".join(rolled_chars[start:stop]) for start, stop in itertools.pairwise(breaks) + ) + actual = season_to_month_tuple(seasons) + expected = tuple( + rolled_months[start:stop] for start, stop in itertools.pairwise(breaks) + ) + assert expected == actual + + +@given(data=st.data(), nmonths=st.integers(min_value=1, max_value=11)) +def test_property_find_independent_seasons(data, nmonths): + chars = "JFMAMJJASOND" + # if stride > nmonths, then we can't infer season order + stride = data.draw(st.integers(min_value=1, max_value=nmonths)) + chars = chars + chars[:nmonths] + seasons = [list(chars[i : i + nmonths]) for i in range(0, 12, stride)] + note(seasons) + groups = find_independent_seasons(seasons) + for group in groups: + inds = tuple(itertools.chain(*group.inds)) + assert len(inds) == len(set(inds)) + assert len(group.codes) == len(set(group.codes)) diff --git a/pyproject.toml b/pyproject.toml index 5494d4ab484..461e951fcad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -391,6 +391,8 @@ extend-ignore-identifiers-re = [ [tool.typos.default.extend-words] # NumPy function names arange = "arange" +ond = "ond" +aso = "aso" # Technical terms nd = "nd" diff --git a/xarray/compat/toolzcompat.py b/xarray/compat/toolzcompat.py new file mode 100644 index 00000000000..4632419a845 --- /dev/null +++ b/xarray/compat/toolzcompat.py @@ -0,0 +1,56 @@ +# This file contains functions copied from the toolz library in accordance +# with its license. The original copyright notice is duplicated below. + +# Copyright (c) 2013 Matthew Rocklin + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# a. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# b. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# c. Neither the name of toolz nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. 
+ + # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH # DAMAGE. + +def sliding_window(n, seq): + """A sequence of overlapping subsequences + + >>> list(sliding_window(2, [1, 2, 3, 4])) + [(1, 2), (2, 3), (3, 4)] + + This function creates a sliding window suitable for transformations like + sliding means / smoothing + + >>> mean = lambda seq: float(sum(seq)) / len(seq) + >>> list(map(mean, sliding_window(2, [1, 2, 3, 4]))) + [1.5, 2.5, 3.5] + """ + import collections + import itertools + + return zip( + *( + collections.deque(itertools.islice(it, i), 0) or it + for i, it in enumerate(itertools.tee(seq, n)) + ), + strict=False, + ) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index f41e0eea8cb..5ddec186e7f 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6886,7 +6886,7 @@ def groupby( >>> da.groupby("letters") <DataArrayGroupBy, grouped over 1 grouper(s), 2 groups in total: - 'letters': 2/2 groups present with labels 'a', 'b'> + 'letters': UniqueGrouper('letters'), 2/2 groups with labels 'a', 'b'> Execute a reduction @@ -6902,8 +6902,8 @@ >>> da.groupby(["letters", "x"]) <DataArrayGroupBy, grouped over 2 grouper(s), 8 groups in total: - 'letters': 2/2 groups present with labels 'a', 'b' - 'x': 4/4 groups present with labels 10, 20, 30, 40> + 'letters': UniqueGrouper('letters'), 2/2 groups with labels 'a', 'b' + 'x': UniqueGrouper('x'), 4/4 groups with labels 10, 20, 30, 40> Use Grouper objects to express more complicated GroupBy operations diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bf2858c1b18..498ceded87f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -9861,7 +9861,7 @@ def groupby( >>> ds.groupby("letters") <DatasetGroupBy, grouped over 1 grouper(s), 2 groups in total: - 'letters': 2/2 groups present with labels 'a', 'b'> + 'letters': UniqueGrouper('letters'), 2/2 groups with labels 'a', 'b'> Execute a reduction @@ -9878,8 +9878,8 @@ >>> ds.groupby(["letters", "x"]) <DatasetGroupBy, grouped over 2 grouper(s), 8 groups in total: - 'letters': 2/2 groups present with labels 'a', 'b' - 'x': 4/4 groups present with labels 10, 20, 30, 40> + 'letters': UniqueGrouper('letters'), 2/2 groups with labels 'a', 'b' + 'x': UniqueGrouper('x'), 4/4 groups with labels 10, 20, 30, 40> Use Grouper objects to express more complicated GroupBy operations diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 6f5472a014a..ec88e7feef3 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -262,6 +262,8 @@ def _ensure_1d( from xarray.core.dataarray import DataArray if isinstance(group, DataArray): + for dim in set(group.dims) - set(obj.dims): + obj = obj.expand_dims(dim) # try to stack the dims of the group into a single dim orig_dims = group.dims stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) @@ -834,7 +836,10 @@ def __repr__(self) -> str: for grouper in self.groupers: coord = grouper.unique_coord labels = ", ".join(format_array_flat(coord, 30).split()) - text += f"\n {grouper.name!r}: {coord.size}/{grouper.full_index.size} groups present with labels {labels}" + text += ( + f"\n {grouper.name!r}: {type(grouper.grouper).__name__}({grouper.group.name!r}), " + f"{coord.size}/{grouper.full_index.size} groups with labels {labels}" + ) return text + ">" def _iter_grouped(self) -> Iterator[T_Xarray]: @@ -1072,7 +1077,7 @@ def 
_flox_reduce( parsed_dim_list = list() # preserve order for dim_ in itertools.chain( - *(grouper.group.dims for grouper in self.groupers) + *(grouper.codes.dims for grouper in self.groupers) ): if dim_ not in parsed_dim_list: parsed_dim_list.append(dim_) @@ -1086,7 +1091,7 @@ def _flox_reduce( # Better to control it here than in flox. for grouper in self.groupers: if any( - d not in grouper.group.dims and d not in obj.dims for d in parsed_dim + d not in grouper.codes.dims and d not in obj.dims for d in parsed_dim ): raise ValueError(f"cannot reduce over dimensions {dim}.") @@ -1331,9 +1336,6 @@ def quantile( "Sample quantiles in statistical packages," The American Statistician, 50(4), pp. 361-365, 1996 """ - if dim is None: - dim = (self._group_dim,) - # Dataset.quantile does this, do it for flox to ensure same output. q = np.asarray(q, dtype=np.float64) @@ -1352,7 +1354,7 @@ def quantile( self._obj.__class__.quantile, shortcut=False, q=q, - dim=dim, + dim=dim or self._group_dim, method=method, keep_attrs=keep_attrs, skipna=skipna, diff --git a/xarray/groupers.py b/xarray/groupers.py index 025f8fae486..0551b02ae91 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -7,9 +7,14 @@ from __future__ import annotations import datetime +import functools +import itertools +import operator from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field -from itertools import pairwise +from itertools import chain, pairwise from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np @@ -17,10 +22,17 @@ from numpy.typing import ArrayLike from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq +from xarray.coding.cftimeindex import CFTimeIndex +from xarray.compat.toolzcompat import sliding_window from xarray.computation.apply_ufunc import apply_ufunc +from xarray.core.common import ( + _contains_cftime_datetimes, + _contains_datetime_like_objects, +) from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.dataarray import DataArray from xarray.core.duck_array_ops import array_all, isnull +from xarray.core.formatting import first_n_items from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper @@ -69,9 +81,9 @@ class EncodedGroups: codes: DataArray full_index: pd.Index - group_indices: GroupIndices - unique_coord: Variable | _DummyGroup - coords: Coordinates + group_indices: GroupIndices = field(init=False, repr=False) + unique_coord: Variable | _DummyGroup = field(init=False, repr=False) + coords: Coordinates = field(init=False, repr=False) def __init__( self, @@ -106,7 +118,10 @@ def __init__( self.group_indices = group_indices if unique_coord is None: - unique_values = full_index[np.unique(codes)] + unique_codes = np.sort(pd.unique(codes.data)) + # Skip the -1 sentinel + unique_codes = unique_codes[unique_codes >= 0] + unique_values = full_index[unique_codes] self.unique_coord = Variable( dims=codes.name, data=unique_values, attrs=codes.attrs ) @@ -586,3 +601,371 @@ def unique_value_groups( if isinstance(values, pd.MultiIndex): values.names = ar.names return values, inverse + + +def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]: + """ + >>> season_to_month_tuple(["DJF", "MAM", "JJA", "SON"]) + ((12, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)) + >>> season_to_month_tuple(["DJFM", "MAMJ", "JJAS", 
"SOND"]) + ((12, 1, 2, 3), (3, 4, 5, 6), (6, 7, 8, 9), (9, 10, 11, 12)) + >>> season_to_month_tuple(["DJFM", "SOND"]) + ((12, 1, 2, 3), (9, 10, 11, 12)) + """ + initials = "JFMAMJJASOND" + starts = dict( + ("".join(s), i + 1) + for s, i in zip(sliding_window(2, initials + "J"), range(12), strict=True) + ) + result: list[tuple[int, ...]] = [] + for i, season in enumerate(seasons): + if len(season) == 1: + if i < len(seasons) - 1: + suffix = seasons[i + 1][0] + else: + suffix = seasons[0][0] + else: + suffix = season[1] + + start = starts[season[0] + suffix] + + month_append = [] + for i in range(len(season[1:])): + elem = start + i + 1 + month_append.append(elem - 12 * (elem > 12)) + result.append((start,) + tuple(month_append)) + return tuple(result) + + +def inds_to_season_string(asints: tuple[tuple[int, ...], ...]) -> tuple[str, ...]: + inits = "JFMAMJJASOND" + return tuple("".join([inits[i_ - 1] for i_ in t]) for t in asints) + + +def is_sorted_periodic(lst): + """Used to verify that seasons provided to SeasonResampler are in order.""" + n = len(lst) + + # Find the wraparound point where the list decreases + wrap_point = -1 + for i in range(1, n): + if lst[i] < lst[i - 1]: + wrap_point = i + break + + # If no wraparound point is found, the list is already sorted + if wrap_point == -1: + return True + + # Check if both parts around the wrap point are sorted + for i in range(1, wrap_point): + if lst[i] < lst[i - 1]: + return False + for i in range(wrap_point + 1, n): + if lst[i] < lst[i - 1]: + return False + + # Check wraparound condition + return lst[-1] <= lst[0] + + +@dataclass(kw_only=True, frozen=True) +class SeasonsGroup: + seasons: tuple[str, ...] + # tuple[integer months] corresponding to each season + inds: tuple[tuple[int, ...], ...] + # integer code for each season, this is not simply range(len(seasons)) + # when the seasons have overlaps + codes: Sequence[int] + + +def find_independent_seasons(seasons: Sequence[str]) -> Sequence[SeasonsGroup]: + """ + Iterates though a list of seasons e.g. ["DJF", "FMA", ...], + and splits that into multiple sequences of non-overlapping seasons. + + >>> find_independent_seasons( + ... ["DJF", "FMA", "AMJ", "JJA", "ASO", "OND"] + ... 
) # doctest: +NORMALIZE_WHITESPACE + [SeasonsGroup(seasons=('DJF', 'AMJ', 'ASO'), inds=((12, 1, 2), (4, 5, 6), (8, 9, 10)), codes=[0, 2, 4]), + SeasonsGroup(seasons=('FMA', 'JJA', 'OND'), inds=((2, 3, 4), (6, 7, 8), (10, 11, 12)), codes=[1, 3, 5])] + + >>> find_independent_seasons(["DJF", "MAM", "JJA", "SON"]) + [SeasonsGroup(seasons=('DJF', 'MAM', 'JJA', 'SON'), inds=((12, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)), codes=[0, 1, 2, 3])] + """ + season_inds = season_to_month_tuple(seasons) + grouped = defaultdict(list) + codes = defaultdict(list) + seen: set[tuple[int, ...]] = set() + idx = 0 + # This is quadratic, but the number of seasons is at most 12 + for i, current in enumerate(season_inds): + # Start with a group + if current not in seen: + grouped[idx].append(current) + codes[idx].append(i) + seen.add(current) + + # Loop through remaining groups, and look for overlaps + for j, second in enumerate(season_inds[i:]): + if not (set(chain(*grouped[idx])) & set(second)): + if second not in seen: + grouped[idx].append(second) + codes[idx].append(j + i) + seen.add(second) + if len(seen) == len(seasons): + break + # found all non-overlapping groups for this row, increment and start over + idx += 1 + + grouped_ints = tuple(tuple(idx) for idx in grouped.values() if idx) + return [ + SeasonsGroup(seasons=inds_to_season_string(inds), inds=inds, codes=codes) + for inds, codes in zip(grouped_ints, codes.values(), strict=False) + ] + + +@dataclass +class SeasonGrouper(Grouper): + """Allows grouping using a custom definition of seasons. + + Parameters + ---------- + seasons: sequence of str + List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc. + Overlapping seasons are allowed (e.g. ``["DJFM", "MAMJ", "JJAS", "SOND"]``) + + Examples + -------- + >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"]) + SeasonGrouper(seasons=['JF', 'MAM', 'JJAS', 'OND']) + + The ordering is preserved + + >>> SeasonGrouper(["MAM", "JJAS", "OND", "JF"]) + SeasonGrouper(seasons=['MAM', 'JJAS', 'OND', 'JF']) + + Overlapping seasons are allowed + + >>> SeasonGrouper(["DJFM", "MAMJ", "JJAS", "SOND"]) + SeasonGrouper(seasons=['DJFM', 'MAMJ', 'JJAS', 'SOND']) + """ + + seasons: Sequence[str] + # drop_incomplete: bool = field(default=True) # TODO + + def factorize(self, group: T_Group) -> EncodedGroups: + if TYPE_CHECKING: + assert not isinstance(group, _DummyGroup) + if not _contains_datetime_like_objects(group.variable): + raise ValueError( + "SeasonGrouper can only be used to group by datetime-like arrays." + ) + months = group.dt.month.data + seasons_groups = find_independent_seasons(self.seasons) + codes_ = np.full((len(seasons_groups),) + group.shape, -1, dtype=np.int8) + group_indices: list[list[int]] = [[]] * len(self.seasons) + for axis_index, seasgroup in enumerate(seasons_groups): + for season_tuple, code in zip( + seasgroup.inds, seasgroup.codes, strict=False + ): + mask = np.isin(months, season_tuple) + codes_[axis_index, mask] = code + (indices,) = mask.nonzero() + group_indices[code] = indices.tolist() + + if np.all(codes_ == -1): + raise ValueError( + "Failed to group data. Are you grouping by a variable that is all NaN?" 
+ ) + needs_dummy_dim = len(seasons_groups) > 1 + codes = DataArray( + dims=(("__season_dim__",) if needs_dummy_dim else tuple()) + group.dims, + data=codes_ if needs_dummy_dim else codes_.squeeze(), + attrs=group.attrs, + name="season", + ) + unique_coord = Variable("season", self.seasons, attrs=group.attrs) + full_index = pd.Index(self.seasons) + return EncodedGroups( + codes=codes, + group_indices=tuple(group_indices), + unique_coord=unique_coord, + full_index=full_index, + ) + + def reset(self) -> Self: + return type(self)(self.seasons) + + +@dataclass +class SeasonResampler(Resampler): + """Allows resampling using a custom definition of seasons. + + Parameters + ---------- + seasons: Sequence[str] + An ordered list of seasons. + drop_incomplete: bool + Whether to drop seasons that are not completely included in the data. + For example, if a time series starts in Jan-2001 and ``seasons`` includes ``"DJF"``, + then observations from Jan-2001 and Feb-2001 are ignored in the grouping + since Dec-2000 isn't present. + + Examples + -------- + >>> SeasonResampler(["JF", "MAM", "JJAS", "OND"]) + SeasonResampler(seasons=['JF', 'MAM', 'JJAS', 'OND'], drop_incomplete=True) + + >>> SeasonResampler(["DJFM", "AM", "JJA", "SON"]) + SeasonResampler(seasons=['DJFM', 'AM', 'JJA', 'SON'], drop_incomplete=True) + """ + + seasons: Sequence[str] + drop_incomplete: bool = field(default=True, kw_only=True) + season_inds: Sequence[Sequence[int]] = field(init=False, repr=False) + season_tuples: Mapping[str, Sequence[int]] = field(init=False, repr=False) + + def __post_init__(self): + self.season_inds = season_to_month_tuple(self.seasons) + all_inds = functools.reduce(operator.add, self.season_inds) + if len(all_inds) > len(set(all_inds)): + raise ValueError( + f"Overlapping seasons are not allowed. Received {self.seasons!r}" + ) + self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=True)) + + if not is_sorted_periodic(list(itertools.chain(*self.season_inds))): + raise ValueError( + "Resampling is only supported with sorted seasons. " + f"Provided seasons {self.seasons!r} are not sorted." + ) + + def factorize(self, group: T_Group) -> EncodedGroups: + if group.ndim != 1: + raise ValueError( + "SeasonResampler can only be used to resample by 1D arrays." + ) + if not isinstance(group, DataArray) or not _contains_datetime_like_objects( + group.variable + ): + raise ValueError( + "SeasonResampler can only be used to group by datetime-like DataArrays." + ) + + seasons = self.seasons + season_inds = self.season_inds + season_tuples = self.season_tuples + + nstr = max(len(s) for s in seasons) + year = group.dt.year.astype(int) + month = group.dt.month.astype(int) + season_label = np.full(group.shape, "", dtype=f"U{nstr}") + + # offset years for seasons with December and January + for season_str, season_ind in zip(seasons, season_inds, strict=True): + season_label[month.isin(season_ind)] = season_str + if "DJ" in season_str: + after_dec = season_ind[season_str.index("D") + 1 :] + # important: this is assuming non-overlapping seasons + year[month.isin(after_dec)] -= 1 + + # Allow users to skip one or more months? + # present_seasons is a mask that is True for months that are requested in the output + present_seasons = season_label != "" + if present_seasons.all(): + # avoid copies if we can. 
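+ # indexing with slice(None) returns the whole array as-is, + # whereas indexing with an all-True boolean mask would make a copy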
+ present_seasons = slice(None) + frame = pd.DataFrame( + data={ + "index": np.arange(group[present_seasons].size), + "month": month[present_seasons], + }, + index=pd.MultiIndex.from_arrays( + [year.data[present_seasons], season_label[present_seasons]], + names=["year", "season"], + ), + ) + + agged = ( + frame["index"] + .groupby(["year", "season"], sort=False) + .agg(["first", "count"]) + ) + first_items = agged["first"] + counts = agged["count"] + + index_class: type[CFTimeIndex] | type[pd.DatetimeIndex] + if _contains_cftime_datetimes(group.data): + index_class = CFTimeIndex + datetime_class = type(first_n_items(group.data, 1).item()) + else: + index_class = pd.DatetimeIndex + datetime_class = datetime.datetime + + # these are the seasons that are present + unique_coord = index_class( + [ + datetime_class(year=year, month=season_tuples[season][0], day=1) + for year, season in first_items.index + ] + ) + + # This sorted call is a hack. It's hard to figure out how + # to start the iteration for arbitrary season ordering + # for example "DJF" as first entry or last entry + # So we construct the largest possible index and slice it to the + # range present in the data. + complete_index = index_class( + sorted( + [ + datetime_class(year=y, month=m, day=1) + for y, m in itertools.product( + range(year[0].item(), year[-1].item() + 1), + [s[0] for s in season_inds], + ) + ] + ) + ) + + # all years and seasons + def get_label(year, season): + month, *_ = season_tuples[season] + return f"{year}-{month:02d}-01" + + unique_codes = np.arange(len(unique_coord)) + valid_season_mask = season_label != "" + first_valid_season, last_valid_season = season_label[valid_season_mask][[0, -1]] + first_year, last_year = year.data[[0, -1]] + if self.drop_incomplete: + if month.data[valid_season_mask][0] != season_tuples[first_valid_season][0]: + if "DJ" in first_valid_season: + first_year += 1 + first_valid_season = seasons[ + (seasons.index(first_valid_season) + 1) % len(seasons) + ] + unique_codes -= 1 + + if ( + month.data[valid_season_mask][-1] + != season_tuples[last_valid_season][-1] + ): + last_valid_season = seasons[seasons.index(last_valid_season) - 1] + if "DJ" in last_valid_season: + last_year -= 1 + unique_codes[-1] = -1 + + first_label = get_label(first_year, first_valid_season) + last_label = get_label(last_year, last_valid_season) + + slicer = complete_index.slice_indexer(first_label, last_label) + full_index = complete_index[slicer] + + final_codes = np.full(group.data.size, -1) + final_codes[present_seasons] = np.repeat(unique_codes, counts) + codes = group.copy(data=final_codes, deep=False) + + return EncodedGroups(codes=codes, full_index=full_index) + + def reset(self) -> Self: + return type(self)(seasons=self.seasons, drop_incomplete=self.drop_incomplete) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 31024d72e60..37a7509b820 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -384,7 +384,10 @@ def create_test_data( pytest.param(cal, marks=requires_cftime) for cal in sorted(_NON_STANDARD_CALENDAR_NAMES) ] -_STANDARD_CALENDARS = [pytest.param(cal) for cal in _STANDARD_CALENDAR_NAMES] +_STANDARD_CALENDARS = [ + pytest.param(cal, marks=requires_cftime if cal != "standard" else ()) + for cal in _STANDARD_CALENDAR_NAMES +] _ALL_CALENDARS = sorted(_STANDARD_CALENDARS + _NON_STANDARD_CALENDARS) _CFTIME_CALENDARS = [ pytest.param(*p.values, marks=requires_cftime) for p in _ALL_CALENDARS diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py 
index 1c351f0ee62..d80eaa0b6e8 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -13,19 +13,23 @@ from packaging.version import Version import xarray as xr -from xarray import DataArray, Dataset, Variable +from xarray import DataArray, Dataset, Variable, date_range from xarray.core.groupby import _consolidate_slices from xarray.core.types import InterpOptions, ResampleCompatible from xarray.groupers import ( BinGrouper, EncodedGroups, Grouper, + SeasonGrouper, + SeasonResampler, TimeResampler, UniqueGrouper, + season_to_month_tuple, ) from xarray.namedarray.pycompat import is_chunked_array from xarray.structure.alignment import broadcast from xarray.tests import ( + _ALL_CALENDARS, InaccessibleArray, assert_allclose, assert_equal, @@ -615,7 +619,7 @@ def test_groupby_repr(obj, dim) -> None: N = len(np.unique(obj[dim])) expected = f"<{obj.__class__.__name__}GroupBy" expected += f", grouped over 1 grouper(s), {N} groups in total:" - expected += f"\n {dim!r}: {N}/{N} groups present with labels " + expected += f"\n {dim!r}: UniqueGrouper({dim!r}), {N}/{N} groups with labels " if dim == "x": expected += "1, 2, 3, 4, 5>" elif dim == "y": @@ -632,7 +636,7 @@ def test_groupby_repr_datetime(obj) -> None: actual = repr(obj.groupby("t.month")) expected = f"<{obj.__class__.__name__}GroupBy" expected += ", grouped over 1 grouper(s), 12 groups in total:\n" - expected += " 'month': 12/12 groups present with labels " + expected += " 'month': UniqueGrouper('month'), 12/12 groups with labels " expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>" assert actual == expected @@ -3292,6 +3296,308 @@ def test_groupby_dask_eager_load_warnings() -> None: ds.groupby_bins("x", bins=[1, 2, 3], eagerly_compute_group=False) +class TestSeasonGrouperAndResampler: + def test_season_to_month_tuple(self): + assert season_to_month_tuple(["JF", "MAM", "JJAS", "OND"]) == ( + (1, 2), + (3, 4, 5), + (6, 7, 8, 9), + (10, 11, 12), + ) + assert season_to_month_tuple(["DJFM", "AM", "JJAS", "ON"]) == ( + (12, 1, 2, 3), + (4, 5), + (6, 7, 8, 9), + (10, 11), + ) + + def test_season_grouper_raises_error_if_months_are_not_valid_or_not_continuous( + self, + ): + calendar = "standard" + time = date_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar) + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + + with pytest.raises(KeyError, match="IN"): + da.groupby(time=SeasonGrouper(["INVALID_SEASON"])) + + with pytest.raises(KeyError, match="MD"): + da.groupby(time=SeasonGrouper(["MDF"])) + + @pytest.mark.parametrize("calendar", _ALL_CALENDARS) + def test_season_grouper_with_months_spanning_calendar_year_using_same_year( + self, calendar + ): + time = date_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) + # fmt: off + data = np.array( + [ + 1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7, + 1.95, 1.05, 1.3, 1.55, 1.8, 1.15, 1.4, 1.65, 1.9, 1.25, 1.5, 1.75, + ] + + ) + # fmt: on + da = DataArray(data, dims="time", coords={"time": time}) + da["year"] = da.time.dt.year + + actual = da.groupby( + year=UniqueGrouper(), time=SeasonGrouper(["NDJFM", "AMJ"]) + ).mean() + + # Expected if the same year "ND" is used for seasonal grouping + expected = xr.DataArray( + data=np.array([[1.38, 1.616667], [1.51, 1.5]]), + dims=["year", "season"], + coords={"year": [2001, 2002], "season": ["NDJFM", "AMJ"]}, + ) + + assert_allclose(expected, actual) + + @pytest.mark.parametrize("calendar", _ALL_CALENDARS) + def test_season_grouper_with_partial_years(self, calendar): + time 
= date_range("2001-01-01", "2002-06-30", freq="MS", calendar=calendar) + # fmt: off + data = np.array( + [ + 1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7, + 1.95, 1.05, 1.3, 1.55, 1.8, 1.15, + ] + ) + # fmt: on + da = DataArray(data, dims="time", coords={"time": time}) + da["year"] = da.time.dt.year + + actual = da.groupby( + year=UniqueGrouper(), time=SeasonGrouper(["NDJFM", "AMJ"]) + ).mean() + + # Expected if partial years are handled correctly + expected = xr.DataArray( + data=np.array([[1.38, 1.616667], [1.43333333, 1.5]]), + dims=["year", "season"], + coords={"year": [2001, 2002], "season": ["NDJFM", "AMJ"]}, + ) + + assert_allclose(expected, actual) + + @pytest.mark.parametrize("calendar", ["standard"]) + def test_season_grouper_with_single_month_seasons(self, calendar): + time = date_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) + # fmt: off + data = np.array( + [ + 1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7, + 1.95, 1.05, 1.3, 1.55, 1.8, 1.15, 1.4, 1.65, 1.9, 1.25, 1.5, 1.75, + ] + ) + # fmt: on + da = DataArray(data, dims="time", coords={"time": time}) + da["year"] = da.time.dt.year + + # TODO: Consider supporting this if needed + # It does not work without flox, because the group labels are not unique, + # and so the stack/unstack approach does not work. + with pytest.raises(ValueError): + da.groupby( + year=UniqueGrouper(), + time=SeasonGrouper( + ["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"] + ), + ).mean() + + # Expected if single month seasons are handled correctly + # expected = xr.DataArray( + # data=np.array( + # [ + # [1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7], + # [1.95, 1.05, 1.3, 1.55, 1.8, 1.15, 1.4, 1.65, 1.9, 1.25, 1.5, 1.75], + # ] + # ), + # dims=["year", "season"], + # coords={ + # "year": [2001, 2002], + # "season": ["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"], + # }, + # ) + # assert_allclose(expected, actual) + + @pytest.mark.parametrize("calendar", _ALL_CALENDARS) + def test_season_grouper_with_months_spanning_calendar_year_using_previous_year( + self, calendar + ): + time = date_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) + # fmt: off + data = np.array( + [ + 1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7, + 1.95, 1.05, 1.3, 1.55, 1.8, 1.15, 1.4, 1.65, 1.9, 1.25, 1.5, 1.75, + ] + ) + # fmt: on + da = DataArray(data, dims="time", coords={"time": time}) + + gb = da.resample(time=SeasonResampler(["NDJFM", "AMJ"], drop_incomplete=False)) + actual = gb.mean() + + # fmt: off + new_time_da = xr.DataArray( + dims="time", + data=pd.DatetimeIndex( + [ + "2000-11-01", "2001-04-01", "2001-11-01", "2002-04-01", "2002-11-01" + ] + ), + ) + # fmt: on + if calendar != "standard": + new_time_da = new_time_da.convert_calendar( + calendar=calendar, align_on="date" + ) + new_time = new_time_da.time.variable + + # Expected if the previous "ND" is used for seasonal grouping + expected = xr.DataArray( + data=np.array([1.25, 1.616667, 1.49, 1.5, 1.625]), + dims="time", + coords={"time": new_time}, + ) + assert_allclose(expected, actual) + + @pytest.mark.parametrize("calendar", _ALL_CALENDARS) + def test_season_grouper_simple(self, calendar) -> None: + time = date_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar) + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + expected = da.groupby("time.season").mean() + # note season order matches expected + actual = da.groupby( + time=SeasonGrouper( + ["DJF", 
"JJA", "MAM", "SON"], # drop_incomplete=False + ) + ).mean() + assert_identical(expected, actual) + + @pytest.mark.parametrize("seasons", [["JJA", "MAM", "SON", "DJF"]]) + def test_season_resampling_raises_unsorted_seasons(self, seasons): + calendar = "standard" + time = date_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar) + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + with pytest.raises(ValueError, match="sort"): + da.resample(time=SeasonResampler(seasons)) + + @pytest.mark.parametrize( + "use_cftime", [pytest.param(True, marks=requires_cftime), False] + ) + @pytest.mark.parametrize("drop_incomplete", [True, False]) + @pytest.mark.parametrize( + "seasons", + [ + pytest.param(["DJF", "MAM", "JJA", "SON"], id="standard"), + pytest.param(["NDJ", "FMA", "MJJ", "ASO"], id="nov-first"), + pytest.param(["MAM", "JJA", "SON", "DJF"], id="standard-diff-order"), + pytest.param(["JFM", "AMJ", "JAS", "OND"], id="december-same-year"), + pytest.param(["DJF", "MAM", "JJA", "ON"], id="skip-september"), + pytest.param(["JJAS"], id="jjas-only"), + ], + ) + def test_season_resampler( + self, seasons: list[str], drop_incomplete: bool, use_cftime: bool + ) -> None: + calendar = "standard" + time = date_range( + "2001-01-01", + "2002-12-30", + freq="D", + calendar=calendar, + use_cftime=use_cftime, + ) + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + counts = da.resample(time="ME").count() + + seasons_as_ints = season_to_month_tuple(seasons) + month = counts.time.dt.month.data + year = counts.time.dt.year.data + for season, as_ints in zip(seasons, seasons_as_ints, strict=True): + if "DJ" in season: + for imonth in as_ints[season.index("D") + 1 :]: + year[month == imonth] -= 1 + counts["time"] = ( + "time", + [pd.Timestamp(f"{y}-{m}-01") for y, m in zip(year, month, strict=True)], + ) + if has_cftime: + counts = counts.convert_calendar(calendar, "time", align_on="date") + + expected_vals = [] + expected_time = [] + for year in [2001, 2002, 2003]: + for season, as_ints in zip(seasons, seasons_as_ints, strict=True): + out_year = year + if "DJ" in season: + out_year = year - 1 + if out_year == 2003: + # this is a dummy year added to make sure we cover 2002-DJF + continue + available = [ + counts.sel(time=f"{out_year}-{month:02d}").data for month in as_ints + ] + if any(len(a) == 0 for a in available) and drop_incomplete: + continue + output_label = pd.Timestamp(f"{out_year}-{as_ints[0]:02d}-01") + expected_time.append(output_label) + # use concatenate to handle empty array when dec value does not exist + expected_vals.append(np.concatenate(available).sum()) + + expected = ( + # we construct expected in the standard calendar + xr.DataArray(expected_vals, dims="time", coords={"time": expected_time}) + ) + if has_cftime: + # and then convert to the expected calendar, + expected = expected.convert_calendar( + calendar, align_on="date", use_cftime=use_cftime + ) + # and finally sort since DJF will be out-of-order + expected = expected.sortby("time") + + rs = SeasonResampler(seasons, drop_incomplete=drop_incomplete) + # through resample + actual = da.resample(time=rs).sum() + assert_identical(actual, expected) + + @requires_cftime + def test_season_resampler_errors(self): + time = date_range("2001-01-01", "2002-12-30", freq="D", calendar="360_day") + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + + # non-datetime array + with pytest.raises(ValueError): + DataArray(np.ones(5), dims="time").groupby(time=SeasonResampler(["DJF"])) 
+ + # ndim > 1 array + with pytest.raises(ValueError): + DataArray( + np.ones((5, 5)), dims=("t", "x"), coords={"x": np.arange(5)} + ).groupby(x=SeasonResampler(["DJF"])) + + # overlapping seasons + with pytest.raises(ValueError): + da.groupby(time=SeasonResampler(["DJFM", "MAMJ", "JJAS", "SOND"])).sum() + + @requires_cftime + def test_season_resampler_groupby_identical(self): + time = date_range("2001-01-01", "2002-12-30", freq="D") + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + + # through resample + resampler = SeasonResampler(["DJF", "MAM", "JJA", "SON"]) + rs = da.resample(time=resampler).sum() + + # through groupby + gb = da.groupby(time=resampler).sum() + assert_identical(rs, gb) + + # TODO: Possible property tests to add to this module # 1. lambda x: x # 2. grouped-reduce on unique coords is identical to array
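As a possible starting point for TODO item 2 above, here is a hypothetical hypothesis-based sketch (not part of this diff; the test name and strategy bounds are assumptions): when every coordinate label is unique, each group contains exactly one element, so a grouped reduction should round-trip to the original array.

import hypothesis.strategies as st
import numpy as np
from hypothesis import given

import xarray as xr


@given(
    data=st.lists(
        st.floats(allow_nan=False, allow_infinity=False), min_size=1, max_size=10
    )
)
def test_grouped_reduce_on_unique_coords_is_identity(data):
    # all-unique labels -> every group has size one -> mean() reduces nothing
    da = xr.DataArray(data, dims="x", coords={"x": np.arange(len(data))})
    xr.testing.assert_allclose(da.groupby("x").mean(), da)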