diff --git a/.gitignore b/.gitignore index d313308..59cec8f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ __pycache__ .ipynb_checkpoints -/data/data +/data/* diff --git a/examples/MOM6_budget_examples_mass_heat_salt.ipynb b/examples/MOM6_budget_examples_mass_heat_salt.ipynb index 9316595..8a58220 100644 --- a/examples/MOM6_budget_examples_mass_heat_salt.ipynb +++ b/examples/MOM6_budget_examples_mass_heat_salt.ipynb @@ -1047,111 +1047,131 @@ " stroke-width: 0.8px;\n", "}\n", "
<xarray.DataArray 'heat_lhs_sum_advection' (time: 1, z_l: 35, yh: 180, xh: 240)> Size: 12MB\n",
-       "array([[[[0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         ...,\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.]],\n",
-       "\n",
-       "        [[0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         ...,\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.]],\n",
-       "\n",
-       "        [[0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         ...,\n",
-       "...\n",
-       "         ...,\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.]],\n",
-       "\n",
-       "        [[0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         ...,\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.]],\n",
-       "\n",
-       "        [[0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         ...,\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.]]]], shape=(1, 35, 180, 240))\n",
+       "dask.array<add, shape=(1, 35, 180, 240), dtype=float64, chunksize=(1, 35, 100, 100), chunktype=numpy.ndarray>\n",
        "Coordinates:\n",
        "  * time       (time) object 8B 2000-07-01 00:00:00\n",
        "  * z_l        (z_l) float64 280B 2.5 10.0 20.0 32.5 ... 5.5e+03 6e+03 6.5e+03\n",
        "  * yh         (yh) int64 1kB 0 1 2 3 4 5 6 7 ... 173 174 175 176 177 178 179\n",
        "  * xh         (xh) int64 2kB 0 1 2 3 4 5 6 7 ... 233 234 235 236 237 238 239\n",
-       "    geolon     (yh, xh) float64 346kB ...\n",
-       "    lon        (yh, xh) float64 346kB ...\n",
-       "    geolat     (yh, xh) float64 346kB ...\n",
-       "    lat        (yh, xh) float64 346kB ...\n",
-       "    deptho     (yh, xh) float32 173kB ...\n",
-       "    wet        (yh, xh) float32 173kB ...\n",
-       "    areacello  (yh, xh) float64 346kB ...\n",
+       "    geolon     (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+       "    lon        (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+       "    geolat     (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+       "    lat        (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+       "    deptho     (yh, xh) float32 173kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+       "    wet        (yh, xh) float32 173kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+       "    areacello  (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
        "Attributes:\n",
        "    cell_measures:  volume: volcello area: areacello\n",
        "    time_avg_info:  average_T1,average_T2,average_DT\n",
        "    standard_name:  cell_area\n",
        "    note:           We ignore land cells in partially wet cells when coarseni...\n",
-       "    provenance:     heat_lhs_sum_advection_sum
  • cell_measures :
    volume: volcello area: areacello
    time_avg_info :
    average_T1,average_T2,average_DT
    standard_name :
    cell_area
    note :
    We ignore land cells in partially wet cells when coarsening, so that tracer content can be accurately reconstructed by multiplying coarsened area-averaged tendencies by it. Fully wet (`wet==1.0`) and fully dry (`wet==0.0`) cells should be unaffected, and will just represent the total cell area. For the partially wet cells, total cell area can be derived from the ocean area by divding `areacello` by `wet`.
    provenance :
    heat_lhs_sum_advection_sum
  • " ], "text/plain": [ " Size: 12MB\n", - "array([[[[0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - "...\n", - " ...,\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.]]]], shape=(1, 35, 180, 240))\n", + "dask.array\n", "Coordinates:\n", " * time (time) object 8B 2000-07-01 00:00:00\n", " * z_l (z_l) float64 280B 2.5 10.0 20.0 32.5 ... 5.5e+03 6e+03 6.5e+03\n", " * yh (yh) int64 1kB 0 1 2 3 4 5 6 7 ... 173 174 175 176 177 178 179\n", " * xh (xh) int64 2kB 0 1 2 3 4 5 6 7 ... 233 234 235 236 237 238 239\n", - " geolon (yh, xh) float64 346kB ...\n", - " lon (yh, xh) float64 346kB ...\n", - " geolat (yh, xh) float64 346kB ...\n", - " lat (yh, xh) float64 346kB ...\n", - " deptho (yh, xh) float32 173kB ...\n", - " wet (yh, xh) float32 173kB ...\n", - " areacello (yh, xh) float64 346kB ...\n", + " geolon (yh, xh) float64 346kB dask.array\n", + " lon (yh, xh) float64 346kB dask.array\n", + " geolat (yh, xh) float64 346kB dask.array\n", + " lat (yh, xh) float64 346kB dask.array\n", + " deptho (yh, xh) float32 173kB dask.array\n", + " wet (yh, xh) float32 173kB dask.array\n", + " areacello (yh, xh) float64 346kB dask.array\n", "Attributes:\n", " cell_measures: volume: volcello area: areacello\n", " time_avg_info: average_T1,average_T2,average_DT\n", diff --git a/examples/load_example_model_grid.py b/examples/load_example_model_grid.py index c5eb9ac..9fd917f 100644 --- a/examples/load_example_model_grid.py +++ b/examples/load_example_model_grid.py @@ -25,6 +25,7 @@ def load_MOM6_example_grid(file_name): "z_l":xr.DataArray([3000], dims=("z_l",)), "z_i":xr.DataArray([0,6000], dims=("z_i",)) }) + ds = ds.chunk({"xh":100, "yh":100, "xq":100, "yq":100, "time":1}) # Chunk up the data to make it more like a user's typical dataset return construct_grid(ds) def load_MOM6_coarsened_diagnostics(): diff --git a/xbudget/collect.py b/xbudget/collect.py index 5c8feb4..509c3ee 100644 --- a/xbudget/collect.py +++ b/xbudget/collect.py @@ -158,7 +158,7 @@ def _deep_search(b, new_b={}, k_last=None): _deep_search(v, new_b=new_b, k_last=k) return new_b -def collect_budgets(ds, xbudget_dict): +def collect_budgets(ds, xbudget_dict, allow_rechunk = True): """Fills xbudget dictionary with all tracer content tendencies Parameters @@ -179,13 +179,17 @@ def collect_budgets(ds, xbudget_dict): } } } + allow_rechunk : bool (default: True) + Whether to temporarily rechunk when taking differences along a dimension, + e.g. to compute flux divergences on `center` from fluxes on `outer` or + tendencies on `center` from snapshots on `outer`. """ for eq, v in xbudget_dict.items(): for side in ["lhs", "rhs"]: if side in v: - budget_fill_dict(ds, v[side], f"{eq}_{side}") + budget_fill_dict(ds, v[side], f"{eq}_{side}", allow_rechunk = allow_rechunk) -def budget_fill_dict(data, xbudget_dict, namepath): +def budget_fill_dict(data, xbudget_dict, namepath, allow_rechunk = True): """Recursively fill xbudget dictionary Parameters @@ -193,6 +197,10 @@ def budget_fill_dict(data, xbudget_dict, namepath): data : xgcm.grid or xr.Dataset xbudget_dict : dictionary in xbudget-compatible format containing variable in namepath namepath : name of variable in dataset (data._ds or data) + allow_rechunk : bool (default: True) + Whether to temporarily rechunk when taking differences along a dimension, + e.g. to compute flux divergences on `center` from fluxes on `outer` or + tendencies on `center` from snapshots on `outer`. """ if type(data)==xgcm.grid.Grid: grid = data @@ -216,7 +224,7 @@ def budget_fill_dict(data, xbudget_dict, namepath): op_list = [] for k_term, v_term in v.items(): if isinstance(v_term, dict): # recursive call to get this variable - v_term_recursive = budget_fill_dict(data, v_term, f"{namepath}_{k}_{k_term}") + v_term_recursive = budget_fill_dict(data, v_term, f"{namepath}_{k}_{k_term}", allow_rechunk = allow_rechunk) if v_term_recursive is not None: op_list.append(v_term_recursive) elif v_term.get("var") is not None and v_term.get("var") not in ds: @@ -287,6 +295,7 @@ def budget_fill_dict(data, xbudget_dict, namepath): if var_pref is None: var_pref = var.copy() + if k == "difference": if grid is not None: staggered_axes = { @@ -294,26 +303,51 @@ def budget_fill_dict(data, xbudget_dict, namepath): for pos,c in ax.coords.items() if pos!="center" } - v_term = [v_term for k_term,v_term in v.items() if k_term!="var"][0] - if v_term not in ds: - warnings.warn(f"Variable {v_term} is missing from the dataset `ds`, so it is being skipped. To suppress this warning, remove {v_term} from the `xbudget_dict`.") - continue - candidate_axes = [axn for (axn,c) in staggered_axes.items() if c in ds[v_term].dims] - if len(candidate_axes) == 1: - axis = candidate_axes[0] - else: - raise ValueError("Flux difference inconsistent with finite volume discretization.") + v_term = [v_term for k_term,v_term in v.items() if k_term!="var"][0] + if v_term not in ds: + warnings.warn(f"Variable {v_term} is missing from the dataset `ds`, so it is being skipped. To suppress this warning, remove {v_term} from the `xbudget_dict`.") + continue + + candidate_axes = [axn for (axn,c) in staggered_axes.items() if c in ds[v_term].dims] + if len(candidate_axes) == 1: + axis = candidate_axes[0] + else: + raise ValueError("Finite difference inconsistent with finite volume discretization.") + + if allow_rechunk: + try: #extract original chunks when possible + #not using ds[v_term] since it may not have the non-staggered dimension chunks. + original_chunks = dict(ds.chunksizes) + except Exception: + warnings.warn("Dataset chunks are inconsistent; using unify_chunks()", UserWarning) + original_chunks = dict(ds.unify_chunks().chunksizes) + + # Find the staggered dimension for the given axis in the DataArray + axis_dim = [d for d in ds[v_term].dims if d in grid.axes[axis].coords.values()] + if len(axis_dim) != 1: + raise ValueError(f"Expected to find one dimension for axis '{axis}' in variable '{v_term}', but found {len(axis_dim)}: {axis_dim}") + axis_dim = axis_dim[0] + + # Temporarily rechunk to put the difference dim in a single chunk, all other chunks are auto. + temporary_chunks = {axis_dim: -1, **{d: "auto" for d in ds[v_term].dims if d != axis_dim}} + var = grid.diff(ds[v_term].chunk(temporary_chunks).fillna(0.0), axis=axis) + # Attempt original chunking for preserved dimensions + var = var.chunk({d: original_chunks.get(d, var.chunksizes[d]) for d in var.dims}) + else: var = grid.diff(ds[v_term].fillna(0.), axis) - var_name = f"{namepath}_difference" - var = var.rename(var_name) - var_provenance = v_term - var.attrs["provenance"] = var_provenance - ds[var_name] = var - if var_pref is None: - var_pref = var.copy() + + var_name = f"{namepath}_difference" + var = var.rename(var_name) + var_provenance = v_term + var.attrs["provenance"] = var_provenance + ds[var_name] = var + if var_pref is None: + var_pref = var.copy() else: raise ValueError("Input `ds` must be `xgcm.Grid` instance if using `difference` operations.") + + return var_pref def get_vars(xbudget_dict, terms): diff --git a/xbudget/tests/test_utilities.py b/xbudget/tests/test_utilities.py index 1ba5388..1d39902 100644 --- a/xbudget/tests/test_utilities.py +++ b/xbudget/tests/test_utilities.py @@ -2,6 +2,8 @@ import numpy as np import xarray as xr import copy +import xgcm +import dask.array as da from xbudget.collect import ( aggregate, disaggregate, @@ -437,4 +439,68 @@ def test_budget_fill_dict_numeric_values(self): result = budget_fill_dict(ds, xbudget_dict, "heat_rhs") assert result is not None - assert np.allclose(ds["heat_rhs_product"].values, 2.0) \ No newline at end of file + assert np.allclose(ds["heat_rhs_product"].values, 2.0) + + def test_budget_fill_dict_allow_rechunk(self): + """Test the allow_rechunk option for the difference operation.""" + # Create a dataset with non-uniform chunks on the staggered grid, + # which would cause issues for xgcm.grid.diff + flux_data = da.from_array(np.random.rand(5, 3), chunks=((2, 2, 1), 3)) + ds_chunked = xr.Dataset( + { + "var": xr.DataArray( + flux_data, + dims=("x_g", "y_c"), + ) + }, + coords={ + "x_g": np.arange(5), + "x_c": np.arange(4) + 0.5, + "y_c": np.arange(3), + }, + ) + + grid_params = { + "coords": {"X": {"center": "x_c", "left": "x_g"}}, + "periodic": False, + "autoparse_metadata": False, + } + + xbudget_dict = { + "var": None, + "difference": {"var_diff": "var", "var": None}, + } + + # 1. Test that allow_rechunk=False raises an error when passing a chunked + # dataset through budget_fill_dict + with pytest.raises(ValueError): + grid_fail = xgcm.Grid(ds_chunked.copy(deep=True), **grid_params) + budget_fill_dict( + grid_fail, + copy.deepcopy(xbudget_dict), + "tendency_rhs", + allow_rechunk=False, + ) + + # 2. Test that shows allow_rechunk=True works + grid_success = xgcm.Grid(ds_chunked.copy(deep=True), **grid_params) + budget_fill_dict( + grid_success, + copy.deepcopy(xbudget_dict), + "tendency_rhs", + allow_rechunk=True, + ) + tendency_rechunked = grid_success._ds["tendency_rhs_difference"] + + # 3. Compare with a correct result from an unchunked array + grid_unchunked = xgcm.Grid(ds_chunked.chunk(-1), **grid_params) + budget_fill_dict( + grid_unchunked, + copy.deepcopy(xbudget_dict), + "tendency_rhs", + allow_rechunk=False, + ) + tendency_correct = grid_unchunked._ds["tendency_rhs_difference"] + + # The numerical results should be identical + xr.testing.assert_allclose(tendency_rechunked, tendency_correct)