Can be between 0 and 1 if the coarse cell includes both wet and dry sub-cells.
[43200 values with dtype=float32]
areacello
(yh, xh)
float64
...
long_name :
Ocean Grid-Cell Area
units :
m2
cell_methods :
area:sum yh:sum xh:sum time: point
standard_name :
cell_area
note :
We ignore land cells in partially wet cells when coarsening, so that tracer content can be accurately reconstructed by multiplying coarsened area-averaged tendencies by it. Fully wet (`wet==1.0`) and fully dry (`wet==0.0`) cells should be unaffected, and will just represent the total cell area. For the partially wet cells, total cell area can be derived from the ocean area by dividing `areacello` by `wet`.
[43200 values with dtype=float64]
cell_measures :
volume: volcello area: areacello
time_avg_info :
average_T1,average_T2,average_DT
+ "[xarray HTML repr (truncated): Array 168.75 kiB / Chunk 39.06 kiB; Shape (180, 240) / (100, 100); Dask graph: 6 chunks in 2 graph layers; Data type: float32 numpy.ndarray]\n",
areacello
(yh, xh)
float64
dask.array<chunksize=(100, 100), meta=np.ndarray>
long_name :
Ocean Grid-Cell Area
units :
m2
cell_methods :
area:sum yh:sum xh:sum time: point
standard_name :
cell_area
note :
We ignore land cells in partially wet cells when coarsening, so that tracer content can be accurately reconstructed by multiplying coarsened area-averaged tendencies by it. Fully wet (`wet==1.0`) and fully dry (`wet==0.0`) cells should be unaffected, and will just represent the total cell area. For the partially wet cells, total cell area can be derived from the ocean area by dividing `areacello` by `wet`.
+ "[xarray HTML repr (truncated): Array 337.50 kiB / Chunk 78.12 kiB; Shape (180, 240) / (100, 100); Dask graph: 6 chunks in 2 graph layers; Data type: float64 numpy.ndarray]\n",
cell_measures :
volume: volcello area: areacello
time_avg_info :
average_T1,average_T2,average_DT
standard_name :
cell_area
note :
We ignore land cells in partially wet cells when coarsening, so that tracer content can be accurately reconstructed by multiplying coarsened area-averaged tendencies by it. Fully wet (`wet==1.0`) and fully dry (`wet==0.0`) cells should be unaffected, and will just represent the total cell area. For the partially wet cells, total cell area can be derived from the ocean area by dividing `areacello` by `wet`.
provenance :
heat_lhs_sum_advection_sum
"
],
"text/plain": [
" Size: 12MB\n",
- "array([[[[0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " ...,\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.]],\n",
- "\n",
- " [[0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " ...,\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.]],\n",
- "\n",
- " [[0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " ...,\n",
- "...\n",
- " ...,\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.]],\n",
- "\n",
- " [[0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " ...,\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.]],\n",
- "\n",
- " [[0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " ...,\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.],\n",
- " [0., 0., 0., ..., 0., 0., 0.]]]], shape=(1, 35, 180, 240))\n",
+ "dask.array<chunksize=(1, 35, 100, 100), meta=np.ndarray>\n",
"Coordinates:\n",
" * time (time) object 8B 2000-07-01 00:00:00\n",
" * z_l (z_l) float64 280B 2.5 10.0 20.0 32.5 ... 5.5e+03 6e+03 6.5e+03\n",
" * yh (yh) int64 1kB 0 1 2 3 4 5 6 7 ... 173 174 175 176 177 178 179\n",
" * xh (xh) int64 2kB 0 1 2 3 4 5 6 7 ... 233 234 235 236 237 238 239\n",
- " geolon (yh, xh) float64 346kB ...\n",
- " lon (yh, xh) float64 346kB ...\n",
- " geolat (yh, xh) float64 346kB ...\n",
- " lat (yh, xh) float64 346kB ...\n",
- " deptho (yh, xh) float32 173kB ...\n",
- " wet (yh, xh) float32 173kB ...\n",
- " areacello (yh, xh) float64 346kB ...\n",
+ " geolon (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+ " lon (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+ " geolat (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+ " lat (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+ " deptho (yh, xh) float32 173kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+ " wet (yh, xh) float32 173kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+ " areacello (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
"Attributes:\n",
" cell_measures: volume: volcello area: areacello\n",
" time_avg_info: average_T1,average_T2,average_DT\n",
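The `note` attribute above prescribes how to recover the full cell area from the ocean-only `areacello` and the wet fraction `wet`. A minimal sketch of that reconstruction, using a toy dataset with hypothetical values (only the `areacello`/`wet` relationship is taken from the note):

```python
import numpy as np
import xarray as xr

# Toy dataset (hypothetical values): one fully wet cell, one half-wet cell,
# one fully dry cell. `areacello` holds the ocean (wet) area only, in m2.
ds = xr.Dataset(
    {
        "areacello": ("xh", [1.0e6, 0.5e6, 0.0]),
        "wet": ("xh", [1.0, 0.5, 0.0]),
    }
)

# Per the note: total cell area = ocean area / wet fraction.
# Mask fully dry cells (wet == 0) to avoid division by zero.
total_area = ds["areacello"] / ds["wet"].where(ds["wet"] > 0)
```

The fully dry cell comes out as NaN, since the note only defines the reconstruction for cells that are at least partially wet.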
diff --git a/examples/load_example_model_grid.py b/examples/load_example_model_grid.py
index c5eb9ac..9fd917f 100644
--- a/examples/load_example_model_grid.py
+++ b/examples/load_example_model_grid.py
@@ -25,6 +25,7 @@ def load_MOM6_example_grid(file_name):
"z_l":xr.DataArray([3000], dims=("z_l",)),
"z_i":xr.DataArray([0,6000], dims=("z_i",))
})
+ ds = ds.chunk({"xh":100, "yh":100, "xq":100, "yq":100, "time":1}) # Chunk up the data to make it more like a user's typical dataset
return construct_grid(ds)
def load_MOM6_coarsened_diagnostics():
diff --git a/xbudget/collect.py b/xbudget/collect.py
index 5c8feb4..509c3ee 100644
--- a/xbudget/collect.py
+++ b/xbudget/collect.py
@@ -158,7 +158,7 @@ def _deep_search(b, new_b={}, k_last=None):
_deep_search(v, new_b=new_b, k_last=k)
return new_b
-def collect_budgets(ds, xbudget_dict):
+def collect_budgets(ds, xbudget_dict, allow_rechunk=True):
"""Fills xbudget dictionary with all tracer content tendencies
Parameters
@@ -179,13 +179,17 @@ def collect_budgets(ds, xbudget_dict):
}
}
}
+ allow_rechunk : bool (default: True)
+ Whether to temporarily rechunk when taking differences along a dimension,
+ e.g. to compute flux divergences on `center` from fluxes on `outer` or
+ tendencies on `center` from snapshots on `outer`.
"""
for eq, v in xbudget_dict.items():
for side in ["lhs", "rhs"]:
if side in v:
- budget_fill_dict(ds, v[side], f"{eq}_{side}")
+ budget_fill_dict(ds, v[side], f"{eq}_{side}", allow_rechunk=allow_rechunk)
-def budget_fill_dict(data, xbudget_dict, namepath):
+def budget_fill_dict(data, xbudget_dict, namepath, allow_rechunk=True):
"""Recursively fill xbudget dictionary
Parameters
@@ -193,6 +197,10 @@ def budget_fill_dict(data, xbudget_dict, namepath):
data : xgcm.grid or xr.Dataset
xbudget_dict : dictionary in xbudget-compatible format containing variable in namepath
namepath : name of variable in dataset (data._ds or data)
+ allow_rechunk : bool (default: True)
+ Whether to temporarily rechunk when taking differences along a dimension,
+ e.g. to compute flux divergences on `center` from fluxes on `outer` or
+ tendencies on `center` from snapshots on `outer`.
"""
if type(data)==xgcm.grid.Grid:
grid = data
@@ -216,7 +224,7 @@ def budget_fill_dict(data, xbudget_dict, namepath):
op_list = []
for k_term, v_term in v.items():
if isinstance(v_term, dict): # recursive call to get this variable
- v_term_recursive = budget_fill_dict(data, v_term, f"{namepath}_{k}_{k_term}")
+ v_term_recursive = budget_fill_dict(data, v_term, f"{namepath}_{k}_{k_term}", allow_rechunk=allow_rechunk)
if v_term_recursive is not None:
op_list.append(v_term_recursive)
elif v_term.get("var") is not None and v_term.get("var") not in ds:
@@ -287,6 +295,7 @@ def budget_fill_dict(data, xbudget_dict, namepath):
if var_pref is None:
var_pref = var.copy()
+
if k == "difference":
if grid is not None:
staggered_axes = {
@@ -294,26 +303,51 @@ def budget_fill_dict(data, xbudget_dict, namepath):
for pos,c in ax.coords.items()
if pos!="center"
}
- v_term = [v_term for k_term,v_term in v.items() if k_term!="var"][0]
- if v_term not in ds:
- warnings.warn(f"Variable {v_term} is missing from the dataset `ds`, so it is being skipped. To suppress this warning, remove {v_term} from the `xbudget_dict`.")
- continue
- candidate_axes = [axn for (axn,c) in staggered_axes.items() if c in ds[v_term].dims]
- if len(candidate_axes) == 1:
- axis = candidate_axes[0]
- else:
- raise ValueError("Flux difference inconsistent with finite volume discretization.")
+ v_term = [v_term for k_term,v_term in v.items() if k_term!="var"][0]
+ if v_term not in ds:
+ warnings.warn(f"Variable {v_term} is missing from the dataset `ds`, so it is being skipped. To suppress this warning, remove {v_term} from the `xbudget_dict`.")
+ continue
+
+ candidate_axes = [axn for (axn,c) in staggered_axes.items() if c in ds[v_term].dims]
+ if len(candidate_axes) == 1:
+ axis = candidate_axes[0]
+ else:
+ raise ValueError("Finite difference inconsistent with finite volume discretization.")
+
+ if allow_rechunk:
+ try: # extract original chunks when possible
+ # not using ds[v_term] since it may not have the non-staggered dimension chunks.
+ original_chunks = dict(ds.chunksizes)
+ except Exception:
+ warnings.warn("Dataset chunks are inconsistent; using unify_chunks()", UserWarning)
+ original_chunks = dict(ds.unify_chunks().chunksizes)
+
+ # Find the staggered dimension for the given axis in the DataArray
+ axis_dim = [d for d in ds[v_term].dims if d in grid.axes[axis].coords.values()]
+ if len(axis_dim) != 1:
+ raise ValueError(f"Expected to find one dimension for axis '{axis}' in variable '{v_term}', but found {len(axis_dim)}: {axis_dim}")
+ axis_dim = axis_dim[0]
+
+ # Temporarily rechunk so the difference dim sits in a single chunk; all other dims are chunked automatically.
+ temporary_chunks = {axis_dim: -1, **{d: "auto" for d in ds[v_term].dims if d != axis_dim}}
+ var = grid.diff(ds[v_term].chunk(temporary_chunks).fillna(0.0), axis=axis)
+ # Attempt original chunking for preserved dimensions
+ var = var.chunk({d: original_chunks.get(d, var.chunksizes[d]) for d in var.dims})
+ else:
var = grid.diff(ds[v_term].fillna(0.), axis)
- var_name = f"{namepath}_difference"
- var = var.rename(var_name)
- var_provenance = v_term
- var.attrs["provenance"] = var_provenance
- ds[var_name] = var
- if var_pref is None:
- var_pref = var.copy()
+
+ var_name = f"{namepath}_difference"
+ var = var.rename(var_name)
+ var_provenance = v_term
+ var.attrs["provenance"] = var_provenance
+ ds[var_name] = var
+ if var_pref is None:
+ var_pref = var.copy()
else:
raise ValueError("Input `ds` must be `xgcm.Grid` instance if using `difference` operations.")
+
+
return var_pref
def get_vars(xbudget_dict, terms):
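The rechunk/diff/restore pattern added to `budget_fill_dict` can be sketched in isolation with plain xarray and dask. The variable and dimension names below are illustrative, and `DataArray.diff` stands in for `grid.diff`:

```python
import numpy as np
import xarray as xr

# A staggered flux variable, chunked along the differencing dimension "x_g".
flux = xr.DataArray(
    np.arange(15.0).reshape(5, 3), dims=("x_g", "y_c")
).chunk({"x_g": 2, "y_c": 3})  # chunks along x_g: (2, 2, 1)

original_chunks = dict(flux.chunksizes)

# 1. Collapse the differencing dimension into a single chunk, so the
#    difference stencil never straddles a chunk boundary.
single = flux.chunk({"x_g": -1, "y_c": "auto"})

# 2. Take the difference (stand-in for grid.diff(..., axis=axis)).
diffed = single.diff("x_g")

# 3. Restore a chunk layout close to the original where dimensions survived.
diffed = diffed.chunk({"x_g": 2, "y_c": 3})
```

Note that the restore step uses integer chunk sizes rather than the exact original chunk tuples: `x_g` shrank from 5 to 4 points, so the original chunks `(2, 2, 1)` would no longer fit, which is why the patch falls back to `var.chunksizes[d]` for dimensions missing from `original_chunks`.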
diff --git a/xbudget/tests/test_utilities.py b/xbudget/tests/test_utilities.py
index 1ba5388..1d39902 100644
--- a/xbudget/tests/test_utilities.py
+++ b/xbudget/tests/test_utilities.py
@@ -2,6 +2,9 @@
import numpy as np
import xarray as xr
import copy
+import pytest
+import xgcm
+import dask.array as da
from xbudget.collect import (
aggregate,
disaggregate,
@@ -437,4 +439,68 @@ def test_budget_fill_dict_numeric_values(self):
result = budget_fill_dict(ds, xbudget_dict, "heat_rhs")
assert result is not None
- assert np.allclose(ds["heat_rhs_product"].values, 2.0)
\ No newline at end of file
+ assert np.allclose(ds["heat_rhs_product"].values, 2.0)
+
+ def test_budget_fill_dict_allow_rechunk(self):
+ """Test the allow_rechunk option for the difference operation."""
+ # Create a dataset with non-uniform chunks on the staggered grid,
+ # which would cause issues for xgcm.grid.diff
+ flux_data = da.from_array(np.random.rand(5, 3), chunks=((2, 2, 1), 3))
+ ds_chunked = xr.Dataset(
+ {
+ "var": xr.DataArray(
+ flux_data,
+ dims=("x_g", "y_c"),
+ )
+ },
+ coords={
+ "x_g": np.arange(5),
+ "x_c": np.arange(4) + 0.5,
+ "y_c": np.arange(3),
+ },
+ )
+
+ grid_params = {
+ "coords": {"X": {"center": "x_c", "left": "x_g"}},
+ "periodic": False,
+ "autoparse_metadata": False,
+ }
+
+ xbudget_dict = {
+ "var": None,
+ "difference": {"var_diff": "var", "var": None},
+ }
+
+ # 1. Test that allow_rechunk=False raises an error when passing a chunked
+ # dataset through budget_fill_dict
+ with pytest.raises(ValueError):
+ grid_fail = xgcm.Grid(ds_chunked.copy(deep=True), **grid_params)
+ budget_fill_dict(
+ grid_fail,
+ copy.deepcopy(xbudget_dict),
+ "tendency_rhs",
+ allow_rechunk=False,
+ )
+
+ # 2. Test that shows allow_rechunk=True works
+ grid_success = xgcm.Grid(ds_chunked.copy(deep=True), **grid_params)
+ budget_fill_dict(
+ grid_success,
+ copy.deepcopy(xbudget_dict),
+ "tendency_rhs",
+ allow_rechunk=True,
+ )
+ tendency_rechunked = grid_success._ds["tendency_rhs_difference"]
+
+ # 3. Compare with a correct result from an unchunked array
+ grid_unchunked = xgcm.Grid(ds_chunked.chunk(-1), **grid_params)
+ budget_fill_dict(
+ grid_unchunked,
+ copy.deepcopy(xbudget_dict),
+ "tendency_rhs",
+ allow_rechunk=False,
+ )
+ tendency_correct = grid_unchunked._ds["tendency_rhs_difference"]
+
+ # The numerical results should be identical
+ xr.testing.assert_allclose(tendency_rechunked, tendency_correct)