From f056f4cc33e1696e2983d0603b5fefd9b4ab7c7b Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 6 Mar 2026 11:15:59 -0700 Subject: [PATCH 01/26] change base loss name - not always used for physics loss start modifying obs model to optionally use "real" obs in sparse high res data --- sup3r/models/with_obs.py | 51 +++++- sup3r/utilities/loss_metrics.py | 78 +++++--- tests/training/test_train_target_obs.py | 232 ++++++++++++++++++++++++ 3 files changed, 330 insertions(+), 31 deletions(-) create mode 100644 tests/training/test_train_target_obs.py diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py index 25ea107e94..ab2d8f3862 100644 --- a/sup3r/models/with_obs.py +++ b/sup3r/models/with_obs.py @@ -12,6 +12,24 @@ logger = logging.getLogger(__name__) +# TODO: Refactor so the observations can be either "proxies" sampled from +# gridded high res or "real" observations drawn from the .obs attribute of +# the batches in the batch handler. +# - either construct mask like currently done or calculate it from the +# .obs data by checking for NaN values. The former is done when using +# sampled proxies for training and the latter is done when using real +# observations for training. +# - might be able to remove obs loss weighting and delegate this to loss +# functions +# - add flag in init for proxies vs real. +# - or maybe the obs data can just be supplied through the high_res attr +# and can allow for some NaNs. Then the mask can be calculated on the high +# res data by checking for NaNs and this would work for both proxies and +# real obs. This would be simpler and more flexible but would require +# changes to the batch handler to allow for NaNs in the high res data +# - might be some quirks with tracking output features + + class Sup3rGanWithObs(Sup3rGan): """Sup3r GAN model which includes mid network observation fusion. 
This model is useful for when production runs will be over a domain for which @@ -40,6 +58,7 @@ class Sup3rGanWithObs(Sup3rGan): def __init__( self, *args, + use_proxy_obs=True, onshore_obs_frac=None, offshore_obs_frac=None, loss_obs_weight=0.0, @@ -53,6 +72,15 @@ def __init__( ---------- args : list Positional args for ``Sup3rGan`` parent class. + use_proxy_obs : bool + Whether to use proxy observations sampled from the gridded high res + data during training. If False, the model will expect real + observation data in the .high_res attribute of the batches in the + batch handler and will calculate the observation mask based on + where there are NaN values in the high res data. If True, the model + will create synthetic observation data by masking the gridded high + res data during training and will calculate the observation mask + based on the specified onshore and offshore observation fractions. onshore_obs_frac : dict[List] | dict[float] Fraction of the batch that should be treated as onshore observations. Should include ``spatial`` key and optionally @@ -79,6 +107,7 @@ def __init__( Keyword arguments for the ``Sup3rGan`` parent class. """ super().__init__(*args, **kwargs) + self.use_proxy_obs = use_proxy_obs self.onshore_obs_frac = ( {} if onshore_obs_frac is None else onshore_obs_frac ) @@ -105,14 +134,20 @@ def _get_loss_obs_comparison(self, hi_res_true, hi_res_gen, obs_mask): @property def obs_training_inds(self): - """Get the indices of the observation features in the true high res - data. Obs features have an _obs suffix to avoid name conflict with - fully gridded features. During training these are matched with the - true high res data.""" - hr_feats = [f.replace('_obs', '') for f in self.hr_features] - obs_inds = [ - hr_feats.index(f.replace('_obs', '')) for f in self.obs_features - ] + """Get the observation feature indices in the true high res + data. True observation features are named with an '_obs' suffix. 
+ When training with proxy observations these indices select + the corresponding gridded features (no '_obs' suffix). Otherwise, + these indices select observation features with an '_obs' suffix.""" + + if self.use_proxy_obs: + hr_feats = [f.replace('_obs', '') for f in self.hr_features] + obs_inds = [ + hr_feats.index(f.replace('_obs', '')) + for f in self.obs_features + ] + else: + obs_inds = [self.hr_features.index(f) for f in self.obs_features] return obs_inds def _get_single_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py index 07600a91f4..0fd3166b03 100644 --- a/sup3r/utilities/loss_metrics.py +++ b/sup3r/utilities/loss_metrics.py @@ -9,26 +9,34 @@ from tensorflow.keras.losses import MeanAbsoluteError, MeanSquaredError -class PhysicsBasedLoss(tf.keras.losses.Loss): - """Base class for physics-based loss metrics. This is meant to be used as a +class Sup3rLoss(tf.keras.losses.Loss): + """Base class for custom sup3r loss metrics. This is meant to be used as a base class for loss metrics that require specific input features.""" - def __init__(self, input_features='all'): + def __init__(self, input_features='all', obs_features=None): """Initialize the loss with given input features Parameters ---------- input_features : list | str List of input features that the loss metric will be calculated on. - This is meant to be used for physics-based loss metrics that - require specific input features. If 'all', the loss will be - calculated on all features. Otherwise, the loss will be calculated - on the features specified in the list. The order of features in - the list will be checked to determine the order of features in the - input tensors. + If 'all', the loss will be calculated on all features. Otherwise, + the loss will be calculated on the features specified in the list. + The order of features in the list will be checked to determine the + order of features in the input tensors. 
+ obs_features : list | None + Optional list of observation features to use as targets for the + loss metric. This is typically used in a physics based loss + when the ground truth data is sparse (e.g. observation points). + In this case a physics constraint is applied where there are no + observations, and an additional content loss is calculated for + points where observations are available. The order of features + in the list will be checked to determine the order of features in + the input tensors. """ super().__init__() self.input_features = input_features + self.obs_features = obs_features def tf_derivative(x, axis=1): @@ -117,7 +125,7 @@ def gaussian_kernel(x1, x2, sigma=1.0): return result -class ExpLoss(tf.keras.losses.Loss): +class ExpLoss(Sup3rLoss): """Loss class for squared exponential difference""" def __call__(self, x1, x2): @@ -140,7 +148,7 @@ def __call__(self, x1, x2): return tf.reduce_mean(1 - tf.exp(-((x1 - x2) ** 2))) -class MmdLoss(tf.keras.losses.Loss): +class MmdLoss(Sup3rLoss): """Loss class for max mean discrepancy loss""" def __call__(self, x1, x2, sigma=1.0): @@ -169,7 +177,7 @@ def __call__(self, x1, x2, sigma=1.0): return mmd -class SpatialDerivativeLoss(tf.keras.losses.Loss): +class SpatialDerivativeLoss(Sup3rLoss): """Loss class to encourage accurary of spatial derivatives.""" LOSS_METRIC = MeanAbsoluteError() @@ -204,7 +212,7 @@ def __call__(self, x1, x2): return self.LOSS_METRIC(x1_div, x2_div) -class TemporalDerivativeLoss(tf.keras.losses.Loss): +class TemporalDerivativeLoss(Sup3rLoss): """Loss class to encourage accurary of temporal derivative.""" LOSS_METRIC = MeanAbsoluteError() @@ -238,7 +246,7 @@ def __call__(self, x1, x2): return self.LOSS_METRIC(x1_div, x2_div) -class CoarseMseLoss(tf.keras.losses.Loss): +class CoarseMseLoss(Sup3rLoss): """Loss class for coarse mse on spatial average of 5D tensor""" MSE_LOSS = MeanSquaredError() @@ -266,7 +274,7 @@ def __call__(self, x1, x2): return self.MSE_LOSS(x1_coarse, x2_coarse) 
-class SpatialExtremesLoss(tf.keras.losses.Loss): +class SpatialExtremesLoss(Sup3rLoss): """Loss class that encourages accuracy of the min/max values in the spatial domain. This does not include an additional MAE term""" @@ -301,7 +309,7 @@ def __call__(self, x1, x2): return (mae_min + mae_max) / 2 -class TemporalExtremesLoss(tf.keras.losses.Loss): +class TemporalExtremesLoss(Sup3rLoss): """Loss class that encourages accuracy of the min/max values in the timeseries. This does not include an additional mae term""" @@ -336,7 +344,7 @@ def __call__(self, x1, x2): return (mae_min + mae_max) / 2 -class SpatialFftLoss(tf.keras.losses.Loss): +class SpatialFftLoss(Sup3rLoss): """Loss class that encourages accuracy of the spatial frequency spectrum""" MAE_LOSS = MeanAbsoluteError() @@ -381,7 +389,7 @@ def __call__(self, x1, x2): return self.MAE_LOSS(x1_hat, x2_hat) -class SpatiotemporalFftLoss(tf.keras.losses.Loss): +class SpatiotemporalFftLoss(Sup3rLoss): """Loss class that encourages accuracy of the spatiotemporal frequency spectrum""" @@ -429,7 +437,7 @@ def __call__(self, x1, x2): return self.MAE_LOSS(x1_hat, x2_hat) -class LowResLoss(tf.keras.losses.Loss): +class LowResLoss(Sup3rLoss): """Content loss that is calculated by coarsening the synthetic and true high-resolution data pairs and then performing the pointwise content loss on the low-resolution fields""" @@ -582,7 +590,7 @@ def __call__(self, x1, x2): return self._tf_loss(x1, x2) + ex_loss -class PerceptualLoss(tf.keras.losses.Loss): +class PerceptualLoss(Sup3rLoss): """Perceptual loss that is calculated as MSE between feature maps of ground truth and synthetic data""" @@ -665,7 +673,7 @@ def __call__(self, x1, x2): return tf.reduce_mean(losses) -class SlicedWassersteinLoss(tf.keras.losses.Loss): +class SlicedWassersteinLoss(Sup3rLoss): """Loss class for sliced wasserstein distance loss""" def __init__(self, n_projections=1024): @@ -733,7 +741,7 @@ def __call__(self, x1, x2): return tf.reduce_mean((x1_sorted - 
x2_sorted) ** 2) -class MaterialDerivativeLoss(PhysicsBasedLoss): +class MaterialDerivativeLoss(Sup3rLoss): """Loss class for the material derivative. This is the left hand side of the Navier-Stokes equation and is equal to internal + external forces divided by density in general. Under certain simplifying assumptions, this @@ -836,7 +844,7 @@ def __call__(self, x1, x2): return self.LOSS_METRIC(x1_div, x2_div) -class GeothermalPhysicsLoss(PhysicsBasedLoss): +class GeothermalPhysicsLoss(Sup3rLoss): """Physics based loss for Geothermal applications TODO: Fill in call with appropriate physics equations. This is currently @@ -857,3 +865,27 @@ def __call__(self, x1, x2): assert check, msg return self.LOSS_METRIC(x1, x2) + + +class GeothermalPhysicsLossWithObs(Sup3rLoss): + """Physics based loss for Geothermal applications + + TODO: Fill in call with appropriate physics equations. This is currently + just a dummy equation for testing. + """ + + LOSS_METRIC = MeanAbsoluteError() + + def __call__(self, x1, x2): + """Geothermal physics loss""" + check = x1.shape[-1] == len(self.input_features) + check &= x2.shape[-1] == len(self.obs_features) + msg = ( + f'Number of features in `x1`: {x1.shape[-1]}, must match the ' + f'length of `input_features`: {len(self.input_features)}, and ' + f'number of features in `x2`: {x2.shape[-1]}, must match the ' + f'length of `obs_features`: {len(self.obs_features)}' + ) + assert check, msg + + return self.LOSS_METRIC(x1, x2) diff --git a/tests/training/test_train_target_obs.py b/tests/training/test_train_target_obs.py new file mode 100644 index 0000000000..2399b21ba0 --- /dev/null +++ b/tests/training/test_train_target_obs.py @@ -0,0 +1,232 @@ +"""Test the training of GANs with dual data handler""" + +import itertools +import os +import tempfile + +import numpy as np +import pytest + +from sup3r.models import Sup3rGanWithObs +from sup3r.preprocessing import ( + Container, + DataHandler, + DualBatchHandler, + DualRasterizer, +) +from 
sup3r.preprocessing.samplers import DualSampler +from sup3r.utilities.pytest.helpers import BatchHandlerTesterFactory + +TARGET_COORD = (39.01, -105.15) +FEATURES = ['u_100m', 'v_100m'] + + +DualBatchHandlerWithObsTester = BatchHandlerTesterFactory( + DualBatchHandler, DualSampler +) + + +@pytest.mark.parametrize( + [ + 'fp_gen', + 'fp_disc', + 's_enhance', + 't_enhance', + 'sample_shape', + 'mode', + ], + [ + (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'lazy'), + (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'eager'), + (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'lazy'), + (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'eager'), + ], +) +def test_train_h5_nc( + fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, mode, n_epoch=2 +): + """Test model training with a dual data handler / batch handler with h5 and + era as hr / lr datasets. Tests both spatiotemporal and spatial models.""" + + lr = 1e-5 + kwargs = { + 'features': FEATURES, + 'target': TARGET_COORD, + 'shape': (20, 20), + } + hr_handler = DataHandler( + pytest.FP_WTK, + **kwargs, + time_slice=slice(None, None, 1), + ) + + lr_handler = DataHandler( + pytest.FP_ERA, + features=FEATURES, + time_slice=slice(None, None, t_enhance), + ) + + dual_rasterizer = DualRasterizer( + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, + s_enhance=s_enhance, + t_enhance=t_enhance, + ) + obs_data = dual_rasterizer.high_res.copy() + for feat in FEATURES: + tmp = np.full(obs_data[feat].shape, np.nan) + lat_ids = list(range(0, 20, 4)) + lon_ids = list(range(0, 20, 4)) + for ilat, ilon in itertools.product(lat_ids, lon_ids): + tmp[ilat, ilon, :] = obs_data[feat][ilat, ilon] + obs_data[feat] = (obs_data[feat].dims, tmp) + + dual_with_obs = Container( + data={ + 'low_res': dual_rasterizer.low_res, + 'high_res': dual_rasterizer.high_res, + 'obs': obs_data, + } + ) + + batch_handler = DualBatchHandlerWithObsTester( + train_containers=[dual_with_obs], + val_containers=[], 
+ sample_shape=sample_shape, + batch_size=3, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=3, + mode=mode, + ) + + for batch in batch_handler: + assert hasattr(batch, 'obs') + assert not np.isnan(batch.obs).all() + assert np.isnan(batch.obs).any() + + Sup3rGanWithObs.seed() + model = Sup3rGanWithObs( + fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' + ) + + with tempfile.TemporaryDirectory() as td: + model_kwargs = { + 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, + 'n_epoch': n_epoch, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': 1, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + model.train(batch_handler, **model_kwargs) + + tlossg = model.history['train_loss_gen'].values + tlosso = model.history['train_loss_obs'].values + assert np.sum(np.diff(tlossg)) < 0 + assert np.sum(np.diff(tlosso)) < 0 + + +@pytest.mark.parametrize( + [ + 'fp_gen', + 'fp_disc', + 's_enhance', + 't_enhance', + 'sample_shape', + 'mode', + ], + [ + (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'lazy'), + (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'eager'), + (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'lazy'), + (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'eager'), + ], +) +def test_train_coarse_h5( + fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, mode, n_epoch=2 +): + """Test model training with a dual data handler / batch handler with + additional sparse observation data used in extra content loss term. 
Tests + both spatiotemporal and spatial models.""" + + lr = 1e-5 + kwargs = { + 'features': FEATURES, + 'target': TARGET_COORD, + 'shape': (20, 20), + } + hr_handler = DataHandler( + pytest.FP_WTK, + **kwargs, + time_slice=slice(None, None, 1), + ) + + lr_handler = DataHandler( + pytest.FP_WTK, + **kwargs, + hr_spatial_coarsen=s_enhance, + time_slice=slice(None, None, t_enhance), + ) + + dual_rasterizer = DualRasterizer( + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, + s_enhance=s_enhance, + t_enhance=t_enhance, + ) + obs_data = dual_rasterizer.high_res.copy() + for feat in FEATURES: + tmp = np.full(obs_data[feat].shape, np.nan) + lat_ids = list(range(0, 20, 4)) + lon_ids = list(range(0, 20, 4)) + for ilat, ilon in itertools.product(lat_ids, lon_ids): + tmp[ilat, ilon, :] = obs_data[feat][ilat, ilon] + obs_data[feat] = (obs_data[feat].dims, tmp) + + dual_with_obs = Container( + data={ + 'low_res': dual_rasterizer.low_res, + 'high_res': dual_rasterizer.high_res, + 'obs': obs_data, + } + ) + + batch_handler = DualBatchHandlerWithObsTester( + train_containers=[dual_with_obs], + val_containers=[], + sample_shape=sample_shape, + batch_size=3, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=3, + mode=mode, + ) + + for batch in batch_handler: + assert hasattr(batch, 'obs') + assert not np.isnan(batch.obs).all() + assert np.isnan(batch.obs).any() + + Sup3rGanWithObs.seed() + model = Sup3rGanWithObs( + fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' + ) + + with tempfile.TemporaryDirectory() as td: + model_kwargs = { + 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, + 'n_epoch': n_epoch, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': 1, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + model.train(batch_handler, **model_kwargs) + + tlossg = model.history['train_loss_gen'].values + tlosso = model.history['train_loss_obs'].values + assert np.sum(np.diff(tlossg)) < 0 + 
assert np.sum(np.diff(tlosso)) < 0 From cf34689f26777df54c371f6ed577cd760bf0d20e Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 9 Mar 2026 15:13:09 -0600 Subject: [PATCH 02/26] Extend obs model to enable training on just sparse high res data --- pixi.lock | 5 +- pyproject.toml | 2 +- sup3r/models/abstract.py | 32 ++- sup3r/models/base.py | 4 +- sup3r/models/interface.py | 51 ++-- sup3r/models/utilities.py | 6 +- sup3r/models/with_obs.py | 30 ++- sup3r/utilities/loss_metrics.py | 5 +- tests/conftest.py | 91 +++++++- tests/training/test_train_conditioned_obs.py | 102 -------- tests/training/test_train_target_obs.py | 232 ------------------- tests/training/test_train_with_obs.py | 224 ++++++++++++++++++ 12 files changed, 390 insertions(+), 394 deletions(-) delete mode 100644 tests/training/test_train_conditioned_obs.py delete mode 100644 tests/training/test_train_target_obs.py create mode 100644 tests/training/test_train_with_obs.py diff --git a/pixi.lock b/pixi.lock index 770288f6f1..5edb78d6ed 100644 --- a/pixi.lock +++ b/pixi.lock @@ -10570,8 +10570,8 @@ packages: requires_python: '>=3.9' - pypi: ./ name: nrel-sup3r - version: 0.2.6.dev40+gcbf2eb538.d20260212 - sha256: 34fdf3835de987425a5895c3bde82593b60d70febbc04b705b8e35b103557138 + version: 0.2.6.dev74+g17ecd9ae1.d20260308 + sha256: 2308c5dfecdbc73a9093af87f6197421a11a154908eaaca2e6e9eb3ab1bc4900 requires_dist: - nrel-rex>=0.2.91 - nrel-phygnn>=0.0.32 @@ -10603,7 +10603,6 @@ packages: - pkginfo>=1.10.0,<2 ; extra == 'build' - twine>=5.0 ; extra == 'build' requires_python: '>=3.9' - editable: true - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.4-py311h64a7726_0.conda sha256: 3f4365e11b28e244c95ba8579942b0802761ba7bb31c026f50d1a9ea9c728149 md5: a502d7aad449a1206efb366d6a12c52d diff --git a/pyproject.toml b/pyproject.toml index 47ecae58e4..1aabe5c005 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -301,7 +301,7 @@ xarray = ">=2023.0" [tool.pixi.pypi-dependencies] NREL-sup3r = { path = ".", 
editable = true } -NREL-rex = { version = ">=0.2.87" } +NREL-rex = { version = ">=0.2.91" } NREL-phygnn = { version = ">=0.0.23" } NREL-gaps = { version = ">=0.6.13" } NREL-farms = { version = ">=1.0.4" } diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index eb819ec9aa..e25b3d8a61 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -23,7 +23,7 @@ from sup3r.utilities import VERSION_RECORD from sup3r.utilities.utilities import Timer, camel_to_underscore, safe_cast -from .utilities import SUP3R_LAYERS, SUP3R_OBS_LAYERS, TensorboardMixIn +from .utilities import SUP3R_LAYERS, TensorboardMixIn logger = logging.getLogger(__name__) @@ -1038,19 +1038,16 @@ def run_exo_layer(self, layer, input_array, exogenous_data, norm_in=True): extras = [] features = getattr(layer, 'features', [layer.name]) exo_features = getattr(layer, 'exo_features', []) - is_obs_layer = isinstance(layer, SUP3R_OBS_LAYERS) for feat in features + exo_features: - missing_obs = feat in features and feat not in exogenous_data - if is_obs_layer and missing_obs: + missing_feat = feat in features and feat not in exogenous_data + if missing_feat: msg = ( f'{feat} does not match any features in exogenous_data ' - f'({list(exogenous_data)}). Will run without this ' - 'observation feature.' + f'({list(exogenous_data)}). Will try to run without this ' + 'feature.' 
) logger.warning(msg) continue - msg = f'exogenous_data is missing required feature "{feat}"' - assert feat in exogenous_data, msg exo = exogenous_data.get_combine_type_data(feat, 'layer') exo = self._reshape_norm_exo( input_array, @@ -1167,7 +1164,7 @@ def _run_exo_layer(cls, layer, input_array, hi_res_exo): return layer(input_array, hr_exo, extras) return layer(input_array, hr_exo) - @tf.function + # @tf.function def _tf_generate(self, low_res, hi_res_exo=None): """Use the generator model to generate high res data from low res input @@ -1195,19 +1192,20 @@ def _tf_generate(self, low_res, hi_res_exo=None): """ hi_res = self.generator.layers[0](low_res) layer_num = 1 - try: - for i, layer in enumerate(self.generator.layers[1:]): + for i, layer in enumerate(self.generator.layers[1:]): + try: layer_num = i + 1 if isinstance(layer, SUP3R_LAYERS): hi_res = self._run_exo_layer(layer, hi_res, hi_res_exo) else: hi_res = layer(hi_res) - except Exception as e: - msg = 'Could not run layer #{} "{}" on tensor of shape {}'.format( - layer_num, layer, hi_res.shape - ) - logger.error(msg) - raise RuntimeError(msg) from e + except Exception as e: + msg = ( + f'Could not run layer #{layer_num} "{layer}" on tensor ' + f'of shape {hi_res.shape}' + ) + logger.error(msg) + raise RuntimeError(msg) from e return hi_res diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 0fa852d288..2ba65d2522 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -899,7 +899,9 @@ def calc_loss( loss_gen_advers = self.calc_loss_disc( disc_out_true=disc_out_gen, disc_out_gen=disc_out_true ) - loss = loss_gen_content + weight_gen_advers * loss_gen_advers + loss = loss_gen_content + if weight_gen_advers > 0: + loss += weight_gen_advers * loss_gen_advers loss_details['loss_gen'] = loss loss_details['loss_gen_content'] = loss_gen_content loss_details['loss_gen_advers'] = loss_gen_advers diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index 1387296ecb..4e26cad8ea 100644 
--- a/sup3r/models/interface.py +++ b/sup3r/models/interface.py @@ -15,7 +15,7 @@ from sup3r.utilities import VERSION_RECORD from sup3r.utilities.utilities import safe_cast -from .utilities import SUP3R_EXO_LAYERS, SUP3R_OBS_LAYERS +from .utilities import SUP3R_LAYERS logger = logging.getLogger(__name__) @@ -79,7 +79,6 @@ def input_dims(self): ------- int """ - # pylint: disable=E1101 if hasattr(self, '_gen'): return self._gen.layers[0].rank if hasattr(self, 'models'): @@ -96,7 +95,6 @@ def is_4d(self): """Check if model expects spatial only input""" return self.input_dims == 4 - # pylint: disable=E1101 def get_s_enhance_from_layers(self): """Compute factor by which model will enhance spatial resolution from layer attributes. Used in model training during high res coarsening""" @@ -109,7 +107,6 @@ def get_s_enhance_from_layers(self): s_enhance = int(np.prod(s_enhancements)) return s_enhance - # pylint: disable=E1101 def get_t_enhance_from_layers(self): """Compute factor by which model will enhance temporal resolution from layer attributes. Used in model training during high res coarsening""" @@ -241,6 +238,20 @@ def _ensure_valid_enhancement_factors(self): logger.error(msg) raise RuntimeError(msg) + def _get_layer_features(self): + """Get the list of features used in the model based on layer + attributes. This is used to check that the features provided in + exogenous_data match the features expected by the model + architecture.""" + features = [] + if hasattr(self, '_gen'): + for layer in self._gen.layers: + if isinstance(layer, SUP3R_LAYERS): + layer_feats = getattr(layer, 'features', [layer.name]) + layer_feats = [f for f in layer_feats if f not in features] + features.extend(layer_feats) + return features + @property def output_resolution(self): """Resolution of output data. Given as a dictionary @@ -378,15 +389,8 @@ def hr_out_features(self): def obs_features(self): """Get list of exogenous observation feature names the model uses. 
These come from the names of the ``Sup3rObs..`` layers.""" - # pylint: disable=E1101 - features = [] - if hasattr(self, '_gen'): - for layer in self._gen.layers: - if isinstance(layer, SUP3R_OBS_LAYERS): - obs_feats = getattr(layer, 'features', [layer.name]) - obs_feats = [f for f in obs_feats if f not in features] - features.extend(obs_feats) - return features + features = self._get_layer_features() + return [f for f in features if '_obs' in f] @property def hr_exo_features(self): @@ -397,17 +401,14 @@ def hr_exo_features(self): [..., topo, sza], and the model has 2 concat or add layers, exo features will be [topo, sza]. Topo will then be used in the first concat layer and sza will be used in the second""" - # pylint: disable=E1101 - features = [] - if hasattr(self, '_gen'): - features = [ - layer.name - for layer in self._gen.layers - if isinstance(layer, SUP3R_EXO_LAYERS) - ] - obs_feats = [feat.replace('_obs', '') for feat in self.obs_features] + features = self._get_layer_features() + features = [f for f in features if '_obs' not in f] + obs_feats = [ + f.replace('_obs', '') + for f in self.obs_features + if f not in self.hr_out_features + ] features += [f for f in obs_feats if f not in self.hr_out_features] - return features @property def hr_features(self): @@ -458,7 +459,8 @@ def set_model_params(self, **kwargs): kwargs : dict Keyword arguments including 'input_resolution', 'lr_features', 'hr_exo_features', 'hr_out_features', - 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' + 'obs_features', 'smoothed_features', 's_enhance', 't_enhance', + 'smoothing' """ keys = ( @@ -466,6 +468,7 @@ def set_model_params(self, **kwargs): 'lr_features', 'hr_exo_features', 'hr_out_features', + 'obs_features', 'smoothed_features', 's_enhance', 't_enhance', diff --git a/sup3r/models/utilities.py b/sup3r/models/utilities.py index f506de2834..e08464382c 100644 --- a/sup3r/models/utilities.py +++ b/sup3r/models/utilities.py @@ -20,11 +20,7 @@ logger = 
logging.getLogger(__name__) -SUP3R_OBS_LAYERS = Sup3rObsModel, Sup3rConcatObs - -SUP3R_EXO_LAYERS = Sup3rAdder, Sup3rConcat - -SUP3R_LAYERS = (*SUP3R_EXO_LAYERS, *SUP3R_OBS_LAYERS) +SUP3R_LAYERS = (Sup3rObsModel, Sup3rConcatObs, Sup3rAdder, Sup3rConcat) class TrainingSession: diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py index ab2d8f3862..c4bf61dbf5 100644 --- a/sup3r/models/with_obs.py +++ b/sup3r/models/with_obs.py @@ -240,13 +240,13 @@ def _get_full_obs_mask(self, hi_res): ``_get_obs_mask`` by defining a composite mask based on separate onshore and offshore masks. This is because there is often more observation data available onshore than offshore.""" - on_sf = self.onshore_obs_frac['spatial'] + on_sf = self.onshore_obs_frac.get('spatial', 0.0) on_tf = self.onshore_obs_frac.get('time', 1.0) obs_mask = self._get_obs_mask(hi_res, on_sf, on_tf) if 'topography' in self.hr_features and self.offshore_obs_frac: topo_idx = self.hr_features.index('topography') topo = hi_res[..., topo_idx] - off_sf = self.offshore_obs_frac['spatial'] + off_sf = self.offshore_obs_frac.get('spatial', 0.0) off_tf = self.offshore_obs_frac.get('time', 1.0) offshore_mask = self._get_obs_mask(hi_res, off_sf, off_tf) obs_mask = tf.where(topo[..., None] > 0, obs_mask, offshore_mask) @@ -275,16 +275,34 @@ def get_hr_exo_input(self, hi_res_true): exo_data = super().get_hr_exo_input(hi_res_true) if len(self.obs_features) == 0: return exo_data - obs_mask = self._get_full_obs_mask(hi_res_true) - nan_const = tf.constant(float('nan'), dtype=hi_res_true.dtype) - obs = tf.gather(hi_res_true, self.obs_training_inds, axis=-1) - obs = tf.where(obs_mask[..., : obs.shape[-1]], nan_const, obs) + + if self.use_proxy_obs: + obs, obs_mask = self._get_proxy_obs(hi_res_true) + else: + obs, obs_mask = self._get_real_obs(hi_res_true) + obs = tf.expand_dims(obs, axis=-2) exo_obs = dict(zip(self.obs_features, tf.unstack(obs, axis=-1))) exo_data.update(exo_obs) exo_data['mask'] = obs_mask return 
exo_data + def _get_real_obs(self, hi_res_true): + """Get real observation data and the corresponding mask from the + .high_res attribute of the batches in the batch handler. This is used + when not training with proxy observations.""" + obs = tf.gather(hi_res_true, self.obs_training_inds, axis=-1) + obs_mask = tf.math.is_nan(obs) + return obs, obs_mask + + def _get_proxy_obs(self, hi_res_true): + """Get proxy observation data by masking the true high res data.""" + obs_mask = self._get_full_obs_mask(hi_res_true) + nan_const = tf.constant(float('nan'), dtype=hi_res_true.dtype) + obs = tf.gather(hi_res_true, self.obs_training_inds, axis=-1) + obs = tf.where(obs_mask[..., : obs.shape[-1]], nan_const, obs) + return obs, obs_mask + def _get_hr_exo_and_loss( self, low_res, diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py index 0fd3166b03..437045725d 100644 --- a/sup3r/utilities/loss_metrics.py +++ b/sup3r/utilities/loss_metrics.py @@ -888,4 +888,7 @@ def __call__(self, x1, x2): ) assert check, msg - return self.LOSS_METRIC(x1, x2) + mask = tf.math.is_nan(x2) + return self.LOSS_METRIC( + x1[tf.math.logical_not(mask)], x2[tf.math.logical_not(mask)] + ) diff --git a/tests/conftest.py b/tests/conftest.py index 50863f48e3..71925443bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -70,8 +70,8 @@ def train_on_cpu(): @pytest.fixture(scope='package') -def gen_config_with_concat_masked(): - """Get generator config with custom concat masked layer.""" +def gen_config_with_obs_2d(): + """Get generator config with observation layers.""" def func(): return [ @@ -149,6 +149,93 @@ def func(): return func +@pytest.fixture(scope='package') +def gen_config_with_obs_3d(): + """Get generator config with observation layers.""" + + def func(): + return [ + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 2, + 'kernel_size': 3, + 'strides': 1, + 
'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + { + 'class': 'SpatioTemporalExpansion', + 'temporal_mult': 2 + }, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 64, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 64, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + { + 'class': 'SpatioTemporalExpansion', + 'spatial_mult': 2 + }, + {'class': 'Activation', 'activation': 'relu'}, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 2, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + {'class': 'Sup3rConcatObs', 'name': 'u_10m_obs'}, + {'class': 'Sup3rConcatObs', 'name': 'v_10m_obs'}, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 2, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + ] + + return func + + @pytest.fixture(scope='package') def gen_config_with_topo(): """Get generator config with custom topo layer.""" diff --git a/tests/training/test_train_conditioned_obs.py b/tests/training/test_train_conditioned_obs.py deleted file mode 100644 index 9cd8d35720..0000000000 --- a/tests/training/test_train_conditioned_obs.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Test the training of super resolution GANs with exogenous observation -data.""" - -import os -import tempfile - -import numpy as np -import pytest - -from sup3r.models import Sup3rGanWithObs -from 
sup3r.preprocessing import ( - BatchHandler, - DataHandler, -) -from sup3r.utilities.utilities import RANDOM_GENERATOR - -SHAPE = (20, 20) -FEATURES_W = ['u_10m', 'v_10m'] -TARGET_W = (39.01, -105.15) - - -@pytest.mark.parametrize('gen_config', ['gen_config_with_concat_masked']) -def test_fixed_wind_obs(gen_config, request): - """Test a special model which fixes observations mid network with - ``Sup3rConcatObs`` layer.""" - - gen_config = request.getfixturevalue(gen_config)() - kwargs = { - 'file_paths': pytest.FP_WTK, - 'features': FEATURES_W, - 'target': TARGET_W, - 'shape': SHAPE, - } - - train_handler = DataHandler(**kwargs, time_slice=slice(None, 3000, 10)) - - val_handler = DataHandler(**kwargs, time_slice=slice(3000, None, 10)) - batcher = BatchHandler( - [train_handler], - [val_handler], - batch_size=2, - n_batches=1, - s_enhance=2, - t_enhance=1, - sample_shape=(20, 20, 1), - ) - - Sup3rGanWithObs.seed() - - model = Sup3rGanWithObs( - gen_config, - pytest.S_FP_DISC, - onshore_obs_frac={'spatial': 0.1}, - loss_obs_weight=0.1, - learning_rate=1e-4, - ) - model.meta['hr_out_features'] = ['u_10m', 'v_10m'] - test_mask = model._get_full_obs_mask(np.zeros((1, 20, 20, 1, 1))).numpy() - frac = 1 - test_mask.sum() / test_mask.size - assert np.abs(0.1 - frac) < test_mask.size / (2 * np.sqrt(test_mask.size)) - assert model.obs_features == ['u_10m_obs', 'v_10m_obs'] - with tempfile.TemporaryDirectory() as td: - model_kwargs = { - 'input_resolution': {'spatial': '16km', 'temporal': '3600min'}, - 'n_epoch': 3, - 'weight_gen_advers': 0.0, - 'train_gen': True, - 'train_disc': False, - 'checkpoint_int': None, - 'out_dir': os.path.join(td, 'test_{epoch}'), - } - - model.train(batcher, **model_kwargs) - - loaded = model.load(os.path.join(td, 'test_2')) - loaded.train(batcher, **model_kwargs) - - x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, len(FEATURES_W))) - u10m_obs = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 1)) - v10m_obs = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 
1)) - mask = RANDOM_GENERATOR.choice([True, False], (60, 60, 1), p=[0.9, 0.1]) - u10m_obs[:, mask] = np.nan - v10m_obs[:, mask] = np.nan - - with pytest.raises(RuntimeError): - y = model.generate(x, exogenous_data=None) - - exo_tmp = { - 'u_10m_obs': { - 'steps': [{'model': 0, 'combine_type': 'layer', 'data': u10m_obs}] - }, - 'v_10m_obs': { - 'steps': [{'model': 0, 'combine_type': 'layer', 'data': v10m_obs}] - }, - } - y = model.generate(x, exogenous_data=exo_tmp) - - assert y.dtype == np.float32 - assert y.shape[0] == x.shape[0] - assert y.shape[1] == x.shape[1] * 2 - assert y.shape[2] == x.shape[2] * 2 - assert y.shape[3] == len(FEATURES_W) diff --git a/tests/training/test_train_target_obs.py b/tests/training/test_train_target_obs.py deleted file mode 100644 index 2399b21ba0..0000000000 --- a/tests/training/test_train_target_obs.py +++ /dev/null @@ -1,232 +0,0 @@ -"""Test the training of GANs with dual data handler""" - -import itertools -import os -import tempfile - -import numpy as np -import pytest - -from sup3r.models import Sup3rGanWithObs -from sup3r.preprocessing import ( - Container, - DataHandler, - DualBatchHandler, - DualRasterizer, -) -from sup3r.preprocessing.samplers import DualSampler -from sup3r.utilities.pytest.helpers import BatchHandlerTesterFactory - -TARGET_COORD = (39.01, -105.15) -FEATURES = ['u_100m', 'v_100m'] - - -DualBatchHandlerWithObsTester = BatchHandlerTesterFactory( - DualBatchHandler, DualSampler -) - - -@pytest.mark.parametrize( - [ - 'fp_gen', - 'fp_disc', - 's_enhance', - 't_enhance', - 'sample_shape', - 'mode', - ], - [ - (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'lazy'), - (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'eager'), - (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'lazy'), - (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'eager'), - ], -) -def test_train_h5_nc( - fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, mode, n_epoch=2 -): - """Test model training with a 
dual data handler / batch handler with h5 and - era as hr / lr datasets. Tests both spatiotemporal and spatial models.""" - - lr = 1e-5 - kwargs = { - 'features': FEATURES, - 'target': TARGET_COORD, - 'shape': (20, 20), - } - hr_handler = DataHandler( - pytest.FP_WTK, - **kwargs, - time_slice=slice(None, None, 1), - ) - - lr_handler = DataHandler( - pytest.FP_ERA, - features=FEATURES, - time_slice=slice(None, None, t_enhance), - ) - - dual_rasterizer = DualRasterizer( - data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, - s_enhance=s_enhance, - t_enhance=t_enhance, - ) - obs_data = dual_rasterizer.high_res.copy() - for feat in FEATURES: - tmp = np.full(obs_data[feat].shape, np.nan) - lat_ids = list(range(0, 20, 4)) - lon_ids = list(range(0, 20, 4)) - for ilat, ilon in itertools.product(lat_ids, lon_ids): - tmp[ilat, ilon, :] = obs_data[feat][ilat, ilon] - obs_data[feat] = (obs_data[feat].dims, tmp) - - dual_with_obs = Container( - data={ - 'low_res': dual_rasterizer.low_res, - 'high_res': dual_rasterizer.high_res, - 'obs': obs_data, - } - ) - - batch_handler = DualBatchHandlerWithObsTester( - train_containers=[dual_with_obs], - val_containers=[], - sample_shape=sample_shape, - batch_size=3, - s_enhance=s_enhance, - t_enhance=t_enhance, - n_batches=3, - mode=mode, - ) - - for batch in batch_handler: - assert hasattr(batch, 'obs') - assert not np.isnan(batch.obs).all() - assert np.isnan(batch.obs).any() - - Sup3rGanWithObs.seed() - model = Sup3rGanWithObs( - fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' - ) - - with tempfile.TemporaryDirectory() as td: - model_kwargs = { - 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, - 'n_epoch': n_epoch, - 'weight_gen_advers': 0.0, - 'train_gen': True, - 'train_disc': False, - 'checkpoint_int': 1, - 'out_dir': os.path.join(td, 'test_{epoch}'), - } - - model.train(batch_handler, **model_kwargs) - - tlossg = model.history['train_loss_gen'].values - tlosso = 
model.history['train_loss_obs'].values - assert np.sum(np.diff(tlossg)) < 0 - assert np.sum(np.diff(tlosso)) < 0 - - -@pytest.mark.parametrize( - [ - 'fp_gen', - 'fp_disc', - 's_enhance', - 't_enhance', - 'sample_shape', - 'mode', - ], - [ - (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'lazy'), - (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16), 'eager'), - (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'lazy'), - (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (20, 20, 1), 'eager'), - ], -) -def test_train_coarse_h5( - fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, mode, n_epoch=2 -): - """Test model training with a dual data handler / batch handler with - additional sparse observation data used in extra content loss term. Tests - both spatiotemporal and spatial models.""" - - lr = 1e-5 - kwargs = { - 'features': FEATURES, - 'target': TARGET_COORD, - 'shape': (20, 20), - } - hr_handler = DataHandler( - pytest.FP_WTK, - **kwargs, - time_slice=slice(None, None, 1), - ) - - lr_handler = DataHandler( - pytest.FP_WTK, - **kwargs, - hr_spatial_coarsen=s_enhance, - time_slice=slice(None, None, t_enhance), - ) - - dual_rasterizer = DualRasterizer( - data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, - s_enhance=s_enhance, - t_enhance=t_enhance, - ) - obs_data = dual_rasterizer.high_res.copy() - for feat in FEATURES: - tmp = np.full(obs_data[feat].shape, np.nan) - lat_ids = list(range(0, 20, 4)) - lon_ids = list(range(0, 20, 4)) - for ilat, ilon in itertools.product(lat_ids, lon_ids): - tmp[ilat, ilon, :] = obs_data[feat][ilat, ilon] - obs_data[feat] = (obs_data[feat].dims, tmp) - - dual_with_obs = Container( - data={ - 'low_res': dual_rasterizer.low_res, - 'high_res': dual_rasterizer.high_res, - 'obs': obs_data, - } - ) - - batch_handler = DualBatchHandlerWithObsTester( - train_containers=[dual_with_obs], - val_containers=[], - sample_shape=sample_shape, - batch_size=3, - s_enhance=s_enhance, - t_enhance=t_enhance, - n_batches=3, 
- mode=mode, - ) - - for batch in batch_handler: - assert hasattr(batch, 'obs') - assert not np.isnan(batch.obs).all() - assert np.isnan(batch.obs).any() - - Sup3rGanWithObs.seed() - model = Sup3rGanWithObs( - fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' - ) - - with tempfile.TemporaryDirectory() as td: - model_kwargs = { - 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, - 'n_epoch': n_epoch, - 'weight_gen_advers': 0.0, - 'train_gen': True, - 'train_disc': False, - 'checkpoint_int': 1, - 'out_dir': os.path.join(td, 'test_{epoch}'), - } - - model.train(batch_handler, **model_kwargs) - - tlossg = model.history['train_loss_gen'].values - tlosso = model.history['train_loss_obs'].values - assert np.sum(np.diff(tlossg)) < 0 - assert np.sum(np.diff(tlosso)) < 0 diff --git a/tests/training/test_train_with_obs.py b/tests/training/test_train_with_obs.py new file mode 100644 index 0000000000..f4d0628552 --- /dev/null +++ b/tests/training/test_train_with_obs.py @@ -0,0 +1,224 @@ +"""Test the training of super resolution GANs with exogenous observation +data.""" + +import itertools +import os +import tempfile + +import numpy as np +import pytest + +from sup3r.models import Sup3rGanWithObs +from sup3r.preprocessing import ( + BatchHandler, + Container, + DataHandler, + DualBatchHandler, + DualRasterizer, +) +from sup3r.preprocessing.samplers import DualSampler +from sup3r.utilities.pytest.helpers import BatchHandlerTesterFactory +from sup3r.utilities.utilities import RANDOM_GENERATOR + +DualBatchHandlerWithObsTester = BatchHandlerTesterFactory( + DualBatchHandler, DualSampler +) + +SHAPE = (20, 20) +FEATURES_W = ['u_10m', 'v_10m'] +TARGET_W = (39.01, -105.15) + + +@pytest.mark.parametrize( + 'gen_config, sample_shape, t_enhance, fp_disc', + [ + ('gen_config_with_obs_2d', (20, 20, 1), 1, pytest.S_FP_DISC), + ('gen_config_with_obs_3d', (20, 20, 10), 2, pytest.ST_FP_DISC), + ], +) +def test_train_cond_obs(gen_config, sample_shape, t_enhance, fp_disc, 
request): + """Test a special model which conditions model output on observations + with a ``Sup3rConcatObs`` layer.""" + + gen_config = request.getfixturevalue(gen_config)() + kwargs = { + 'file_paths': pytest.FP_WTK, + 'features': FEATURES_W, + 'target': TARGET_W, + 'shape': SHAPE, + } + + train_handler = DataHandler(**kwargs, time_slice=slice(None, 3000, 10)) + + val_handler = DataHandler(**kwargs, time_slice=slice(3000, None, 10)) + batcher = BatchHandler( + [train_handler], + [val_handler], + batch_size=2, + n_batches=1, + s_enhance=2, + t_enhance=t_enhance, + sample_shape=sample_shape, + ) + + Sup3rGanWithObs.seed() + + model = Sup3rGanWithObs( + gen_config, + fp_disc, + onshore_obs_frac={'spatial': 0.1}, + loss_obs_weight=0.1, + learning_rate=1e-4, + ) + model.meta['hr_out_features'] = ['u_10m', 'v_10m'] + test_mask = model._get_full_obs_mask(np.zeros((1, 20, 20, 1, 1))).numpy() + frac = 1 - test_mask.sum() / test_mask.size + assert np.abs(0.1 - frac) < test_mask.size / (2 * np.sqrt(test_mask.size)) + assert model.obs_features == ['u_10m_obs', 'v_10m_obs'] + with tempfile.TemporaryDirectory() as td: + model_kwargs = { + 'input_resolution': {'spatial': '16km', 'temporal': '3600min'}, + 'n_epoch': 3, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': None, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + model.train(batcher, **model_kwargs) + + loaded = model.load(os.path.join(td, 'test_2')) + loaded.train(batcher, **model_kwargs) + + if t_enhance == 1: + x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, len(FEATURES_W))) + u10m_obs = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 1)) + v10m_obs = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 1)) + else: + x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, 10, len(FEATURES_W))) + u10m_obs = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 20, 1)) + v10m_obs = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 20, 1)) + mask = RANDOM_GENERATOR.choice( + [True, False], u10m_obs.shape[1:], 
p=[0.9, 0.1] + ) + u10m_obs[:, mask] = np.nan + v10m_obs[:, mask] = np.nan + + with pytest.raises(RuntimeError): + y = model.generate(x, exogenous_data=None) + + exo_tmp = { + 'u_10m_obs': { + 'steps': [{'model': 0, 'combine_type': 'layer', 'data': u10m_obs}] + }, + 'v_10m_obs': { + 'steps': [{'model': 0, 'combine_type': 'layer', 'data': v10m_obs}] + }, + } + y = model.generate(x, exogenous_data=exo_tmp) + + assert y.dtype == np.float32 + assert y.shape[0] == x.shape[0] + assert y.shape[1] == x.shape[1] * 2 + assert y.shape[2] == x.shape[2] * 2 + assert y.shape[-1] == len(FEATURES_W) + if y.ndim == 5: + assert y.shape[3] == x.shape[3] * t_enhance + + +@pytest.mark.parametrize( + 'gen_config, sample_shape, t_enhance, fp_disc', + [ + ('gen_config_with_obs_2d', (20, 20, 1), 1, pytest.S_FP_DISC), + ('gen_config_with_obs_3d', (20, 20, 10), 2, pytest.ST_FP_DISC), + ], +) +def test_train_just_obs(gen_config, sample_shape, t_enhance, fp_disc, request): + """Test model training with sparse high resolution ground truth data.""" + + gen_config = request.getfixturevalue(gen_config)() + kwargs = { + 'features': FEATURES_W, + 'target': TARGET_W, + 'shape': (20, 20), + } + hr_handler = DataHandler( + pytest.FP_WTK, + **kwargs, + time_slice=slice(None, None, 1), + ) + + lr_handler = DataHandler( + pytest.FP_ERA, + features=FEATURES_W, + time_slice=slice(None, None, t_enhance), + ) + + dual_rasterizer = DualRasterizer( + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, + s_enhance=2, + t_enhance=t_enhance, + run_qa=False, + ) + obs_data = dual_rasterizer.high_res.copy() + for feat in FEATURES_W: + tmp = np.full(obs_data[feat].shape, np.nan) + lat_ids = list(range(0, 20, 4)) + lon_ids = list(range(0, 20, 4)) + for ilat, ilon in itertools.product(lat_ids, lon_ids): + tmp[ilat, ilon, :] = obs_data[feat][ilat, ilon] + obs_data[f'{feat}_obs'] = (obs_data[feat].dims, tmp) + + dual_with_obs = Container( + data={ + 'low_res': dual_rasterizer.low_res, + 'high_res': obs_data, 
+ } + ) + + batch_handler = DualBatchHandlerWithObsTester( + train_containers=[dual_with_obs], + val_containers=[], + sample_shape=sample_shape, + batch_size=3, + s_enhance=2, + t_enhance=t_enhance, + n_batches=2, + feature_sets={'lr_only_features': FEATURES_W}, + mode='lazy', + ) + + for batch in batch_handler: + assert not np.isnan(batch.high_res).all() + assert np.isnan(batch.high_res).any() + + Sup3rGanWithObs.seed() + model = Sup3rGanWithObs( + gen_config, + fp_disc, + use_proxy_obs=False, + learning_rate=1e-4, + loss={ + 'GeothermalPhysicsLossWithObs': { + 'input_features': [f'{feat}_obs' for feat in FEATURES_W], + 'obs_features': [f'{feat}_obs' for feat in FEATURES_W], + } + }, + ) + + with tempfile.TemporaryDirectory() as td: + model_kwargs = { + 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, + 'n_epoch': 5, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': 1, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + model.train(batch_handler, **model_kwargs) + + tloss = model.history['train_geothermal_physics_loss_with_obs'].values + assert np.sum(np.diff(tloss)) < 0 From a80d2eb463360f46cd3c8d325105d8a4641a4557 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 9 Mar 2026 15:13:44 -0600 Subject: [PATCH 03/26] Remove TODO comments regarding observation handling in Sup3rGanWithObs --- sup3r/models/with_obs.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py index c4bf61dbf5..e95b12340c 100644 --- a/sup3r/models/with_obs.py +++ b/sup3r/models/with_obs.py @@ -12,24 +12,6 @@ logger = logging.getLogger(__name__) -# TODO: Refactor so the observations can be either "proxies" sampled from -# gridded high res or "real" observations drawn from the .obs attribute of -# the batches in the batch handler. -# - either construct mask like currently done or calculate it from the -# .obs data by checking for NaN values. 
The former is done when using -# sampled proxies for training and the latter is done when using real -# observations for training. -# - might be able to remove obs loss weighting and delegate this to loss -# functions -# - add flag in init for proxies vs real. -# - or maybe the obs data can just be supplied through the high_res attr -# and can allow for some NaNs. Then the mask can be calculated on the high -# res data by checking for NaNs and this would work for both proxies and -# real obs. This would be simpler and more flexible but would require -# changes to the batch handler to allow for NaNs in the high res data -# - might be some quirks with tracking output features - - class Sup3rGanWithObs(Sup3rGan): """Sup3r GAN model which includes mid network observation fusion. This model is useful for when production runs will be over a domain for which From b264af1c5fc95c76b5ab56d4c7536f988091dc89 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 12 Mar 2026 15:40:15 -0600 Subject: [PATCH 04/26] refact: move proxy obs sampling logic to sampler objects. removed Sup3rGanWithObs model, since all models can now train with obs with the appropriate batch handler args. refact: `feature_sets` now uses `lr_features`, `hr_exo_features`, `hr_out_features` instead of `lr_only_features`. I think this is simpler to understand how features split. feat: enabled training on just sparse targets - this is currently setup to discriminate only on non sparse features - this could be changed if the discriminator model has a way to handle the sparse data. 
--- sup3r/models/__init__.py | 1 - sup3r/models/abstract.py | 106 ++-- sup3r/models/base.py | 27 +- sup3r/models/interface.py | 60 +-- sup3r/models/with_obs.py | 334 ------------- sup3r/preprocessing/batch_handlers/factory.py | 34 +- sup3r/preprocessing/batch_queues/base.py | 5 +- sup3r/preprocessing/batch_queues/dual.py | 9 +- sup3r/preprocessing/collections/stats.py | 11 +- sup3r/preprocessing/rasterizers/dual.py | 60 +-- sup3r/preprocessing/samplers/base.py | 453 ++++++++++++++---- sup3r/preprocessing/samplers/cc.py | 45 +- sup3r/preprocessing/samplers/dc.py | 26 +- sup3r/preprocessing/samplers/dual.py | 127 +++-- sup3r/utilities/loss_metrics.py | 76 +-- tests/batch_handlers/test_bh_h5_cc.py | 20 +- tests/batch_queues/test_bq_general.py | 5 +- tests/forward_pass/test_forward_pass_obs.py | 26 +- tests/samplers/test_feature_sets.py | 197 +++++++- tests/training/test_train_exo.py | 31 +- tests/training/test_train_exo_cc.py | 27 +- tests/training/test_train_gan.py | 2 +- tests/training/test_train_solar.py | 15 +- tests/training/test_train_with_obs.py | 44 +- tests/utilities/test_loss_metrics.py | 6 +- 25 files changed, 995 insertions(+), 752 deletions(-) delete mode 100644 sup3r/models/with_obs.py diff --git a/sup3r/models/__init__.py b/sup3r/models/__init__.py index 0cb5f1bf51..f079fd9310 100644 --- a/sup3r/models/__init__.py +++ b/sup3r/models/__init__.py @@ -7,6 +7,5 @@ from .multi_step import MultiStepGan, MultiStepSurfaceMetGan, SolarMultiStepGan from .solar_cc import SolarCC from .surface import SurfaceSpatialMetModel -from .with_obs import Sup3rGanWithObs SPATIAL_FIRST_MODELS = (MultiStepSurfaceMetGan, SolarMultiStepGan) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index e25b3d8a61..66d145c19d 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -437,6 +437,37 @@ def get_hr_exo_input(self, hi_res): exo = dict(zip(self.hr_exo_features, tf.unstack(exo, axis=-1))) return exo + def _extract_obs(self, hi_res_true): + 
"""Extract observation features from the end of hi_res_true. + Observation features are appended after hr_out + hr_exo features + by the Sampler when ``use_proxy_obs=True``, or included directly + when real obs are in the data. + + Parameters + ---------- + hi_res_true : tf.Tensor + Ground truth high resolution data, possibly with obs features + appended at the end. + + Returns + ------- + obs_data : dict + Dictionary of observation feature data. Keys are obs feature + names, values are tensors. Also includes ``'mask'`` key with + boolean mask (True where observations are NaN / missing). + Empty dict if no obs features are present. + """ + if len(self.obs_features) == 0: + return {} + obs = hi_res_true[..., -len(self.obs_features) :] + obs_mask = tf.math.is_nan(obs) + obs_expanded = tf.expand_dims(obs, axis=-2) + obs_data = dict( + zip(self.obs_features, tf.unstack(obs_expanded, axis=-1)) + ) + obs_data['mask'] = obs_mask + return obs_data + def _combine_loss_input(self, hi_res_true, hi_res_gen): """Combine exogenous feature data from hi_res_true with hi_res_gen for loss calculation @@ -460,33 +491,46 @@ def _combine_loss_input(self, hi_res_true, hi_res_gen): hi_res_gen = tf.concat((hi_res_gen, *exo_data), axis=-1) return hi_res_gen - def _get_loss_inputs(self, hi_res_true, hi_res_gen, loss_func): + def _get_loss_inputs(self, hi_res_gen, hi_res_true, loss_func): """Get inputs for the given loss function according to the required input features""" - msg = ( - f'{loss_func} requires input features: ' - f'{loss_func.input_features}, but these are not found ' - f'in the model output features: {self.hr_out_features}' - ) - if not all( - f in self.hr_out_features for f in loss_func.input_features + gen_feats = getattr(loss_func, 'gen_features', 'all') + true_feats = getattr(loss_func, 'true_features', 'all') + + if gen_feats != 'all' and not all( + f in self.hr_features for f in gen_feats + ): + msg = ( + f'{loss_func} requires gen_features: ' + 
f'{loss_func.gen_features}, but these are not found ' + f'in the high-resolution features: {self.hr_features}' + ) + logger.error(msg) + raise ValueError(msg) + + if true_feats != 'all' and not all( + f in self.hr_features for f in true_feats ): + msg = ( + f'{loss_func} requires true_features: ' + f'{loss_func.true_features}, but these are not found ' + f'in the high-resolution output features: {self.hr_features}' + ) logger.error(msg) raise ValueError(msg) - input_inds = [ - self.hr_out_features.index(f) for f in loss_func.input_features - ] - hr_true = tf.stack( - [hi_res_true[..., idx] for idx in input_inds], - axis=-1, - ) - hr_gen = tf.stack( - [hi_res_gen[..., idx] for idx in input_inds], - axis=-1, - ) - return hr_true, hr_gen + if gen_feats == 'all': + gen_feats = self.hr_features + if true_feats == 'all': + true_feats = self.hr_features + + gen_inds = [self.hr_features.index(f) for f in gen_feats] + true_inds = [self.hr_features.index(f) for f in true_feats] + + hr_true = tf.gather(hi_res_true, true_inds, axis=-1) + hr_gen = tf.gather(hi_res_gen, gen_inds, axis=-1) + return hr_gen, hr_true def get_loss_fun(self, loss): """Get full, possibly multi-term, loss function from the provided str @@ -519,20 +563,14 @@ def get_loss_fun(self, loss): loss_funcs = [self._get_loss_fun({ln: loss[ln]}) for ln in lns] weights = copy.deepcopy(loss).pop('term_weights', [1.0] * len(lns)) - def loss_fun(hi_res_true, hi_res_gen): + def loss_fun(hi_res_gen, hi_res_true): loss_details = {} loss = 0 for i, (ln, loss_func) in enumerate(zip(lns, loss_funcs)): - if ( - not hasattr(loss_func, 'input_features') - or loss_func.input_features == 'all' - ): - val = loss_func(hi_res_true, hi_res_gen) - else: - hr_true, hr_gen = self._get_loss_inputs( - hi_res_true, hi_res_gen, loss_func - ) - val = loss_func(hr_true, hr_gen) + hr_gen, hr_true = self._get_loss_inputs( + hi_res_gen, hi_res_true, loss_func + ) + val = loss_func(hr_gen, hr_true) loss_details[camel_to_underscore(ln)] = val 
loss += weights[i] * val return loss, loss_details @@ -1164,7 +1202,7 @@ def _run_exo_layer(cls, layer, input_array, hi_res_exo): return layer(input_array, hr_exo, extras) return layer(input_array, hr_exo) - # @tf.function + @tf.function def _tf_generate(self, low_res, hi_res_exo=None): """Use the generator model to generate high res data from low res input @@ -1216,7 +1254,9 @@ def _get_hr_exo_and_loss( **calc_loss_kwargs, ): """Get high-resolution exogenous data, generate synthetic output, and - compute loss.""" + compute loss. Obs features (if present at the end of hi_res_true) are + extracted and added to exo_data, and trimmed from hi_res_true before + loss calculation.""" hi_res_exo = self.get_hr_exo_input(hi_res_true) hi_res_gen = self._tf_generate(low_res, hi_res_exo) loss, loss_details = self.calc_loss( diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 2ba65d2522..108c8e49bf 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -298,7 +298,19 @@ def _tf_discriminate(self, hi_res): out : np.ndarray Discriminator output logits """ - out = self.discriminator.layers[0](hi_res) + + # TODO: We currently assume the discriminator is convolutional so we + # remove the sparse obs data. 
Change this once we support + # non-convolutional discriminators + hr = ( + hi_res + if len(self.obs_features) == 0 + else hi_res[..., : -len(self.obs_features)] + ) + if hr.shape[-1] == 0: + return tf.constant([], dtype=tf.float32) + + out = self.discriminator.layers[0](hr) layer_num = 1 try: for i, layer in enumerate(self.discriminator.layers[1:]): @@ -426,7 +438,7 @@ def init_weights(self, lr_shape, hr_shape, device=None): with tf.device(device): hr_exo_data = {} - for feature in self.hr_exo_features + self.obs_features: + for feature in self.hr_exo_features: hr_exo_data[feature] = hr_exo out = self._tf_generate(low_res, hr_exo_data) msg = ( @@ -494,14 +506,7 @@ def calc_loss_gen_content(self, hi_res_true, hi_res_gen): 0D tensor generator model loss for the content loss comparing the hi res ground truth to the hi res synthetically generated output. """ - slc = ( - slice(0, None) - if len(self.hr_exo_features) == 0 - else slice(0, -len(self.hr_exo_features)) - ) - # gen is first since loss can included regularizers which just - # apply to generator output - return self.loss_fun(hi_res_gen[..., slc], hi_res_true[..., slc]) + return self.loss_fun(hi_res_gen, hi_res_true) @staticmethod @tf.function @@ -617,6 +622,8 @@ def check_batch_handler_attrs(batch_handler): 'lr_features', 'hr_exo_features', 'hr_out_features', + 'hr_features', + 'obs_features', 'smoothed_features', ] if hasattr(batch_handler, k) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index 4e26cad8ea..1aae2b0995 100644 --- a/sup3r/models/interface.py +++ b/sup3r/models/interface.py @@ -238,20 +238,6 @@ def _ensure_valid_enhancement_factors(self): logger.error(msg) raise RuntimeError(msg) - def _get_layer_features(self): - """Get the list of features used in the model based on layer - attributes. 
This is used to check that the features provided in - exogenous_data match the features expected by the model - architecture.""" - features = [] - if hasattr(self, '_gen'): - for layer in self._gen.layers: - if isinstance(layer, SUP3R_LAYERS): - layer_feats = getattr(layer, 'features', [layer.name]) - layer_feats = [f for f in layer_feats if f not in features] - features.extend(layer_feats) - return features - @property def output_resolution(self): """Resolution of output data. Given as a dictionary @@ -366,6 +352,18 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None): hi_res = np.concatenate((hi_res, exo_output), axis=-1) return hi_res + def get_layer_features(self): + """Get the features that are input to mid-network layers. These are + typically gapless high-resolution features like topography or sparse + observations""" + features = [] + if hasattr(self, '_gen'): + for layer in self._gen.layers: + if isinstance(layer, SUP3R_LAYERS): + feats = getattr(layer, 'features', [layer.name]) + features.extend(feats) + return features + @property @abstractmethod def meta(self): @@ -389,26 +387,13 @@ def hr_out_features(self): def obs_features(self): """Get list of exogenous observation feature names the model uses. These come from the names of the ``Sup3rObs..`` layers.""" - features = self._get_layer_features() - return [f for f in features if '_obs' in f] + return self.meta.get('obs_features', []) @property def hr_exo_features(self): - """Get list of high-resolution exogenous filter names the model uses. - If the model has N concat or add layers this list will be the last N - features in the training features list. The ordering is assumed to be - the same as the order of concat or add layers. If training features is - [..., topo, sza], and the model has 2 concat or add layers, exo - features will be [topo, sza]. 
Topo will then be used in the first - concat layer and sza will be used in the second""" - features = self._get_layer_features() - features = [f for f in features if '_obs' not in f] - obs_feats = [ - f.replace('_obs', '') - for f in self.obs_features - if f not in self.hr_out_features - ] - features += [f for f in obs_feats if f not in self.hr_out_features] + """Get list of gapless exogenous high-resolution feature names the + model uses, like topography.""" + return self.meta.get('hr_exo_features', []) @property def hr_features(self): @@ -416,7 +401,7 @@ def hr_features(self): the high-resolution data during training. This includes both output and exogenous features. """ - return self.hr_out_features + self.hr_exo_features + return self.meta.get('hr_features', []) @property def smoothing(self): @@ -458,7 +443,7 @@ def set_model_params(self, **kwargs): ---------- kwargs : dict Keyword arguments including 'input_resolution', - 'lr_features', 'hr_exo_features', 'hr_out_features', + 'lr_features', 'hr_exo_features', 'hr_out_features', 'hr_features', 'obs_features', 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' """ @@ -468,6 +453,7 @@ def set_model_params(self, **kwargs): 'lr_features', 'hr_exo_features', 'hr_out_features', + 'hr_features', 'obs_features', 'smoothed_features', 's_enhance', @@ -478,14 +464,6 @@ def set_model_params(self, **kwargs): if 'hr_out_features' in kwargs: self.meta['hr_out_features'] = kwargs['hr_out_features'] - hr_exo_feat = kwargs.get('hr_exo_features', []) - msg = ( - f'Expected high-res exo features {self.hr_exo_features} ' - f'based on model architecture but received "hr_exo_features" ' - f'from data handler: {hr_exo_feat}' - ) - assert list(self.hr_exo_features) == list(hr_exo_feat), msg - for var in keys: val = self.meta.get(var, None) if val is None: diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py deleted file mode 100644 index e95b12340c..0000000000 --- a/sup3r/models/with_obs.py +++ /dev/null @@ -1,334 
+0,0 @@ -"""Sup3r model with training on observation data.""" - -import logging - -import numpy as np -import tensorflow as tf - -from sup3r.utilities.utilities import RANDOM_GENERATOR - -from .base import Sup3rGan - -logger = logging.getLogger(__name__) - - -class Sup3rGanWithObs(Sup3rGan): - """Sup3r GAN model which includes mid network observation fusion. This - model is useful for when production runs will be over a domain for which - observation data is available. - - Note - ---- - During training this model uses sparse sampling of ground truth data to - simulate observation data. This is done by creating masks of ground truth - data and then selecting unmasked data. The features used for sampling are - defined in the model configuration, in observation specific model layers - such as ``Sup3rObsModel`` or ``Sup3rConcatObs``. The observation features - in the model configuration should be named with an `_obs` suffix, and - during training these are matched with the corresponding true high res - features without the `_obs` suffix. e.g. If the goal is to condition a - model on sparse temperature_10m data then the observation feature should be - named temperature_10m_obs, and during training the gridded high res - temperature_10m data will be masked to create synthetic observation data - for this feature. The observation masks are only used during training. - During inference "real" observation data, with the full name defined in the - model configuration, is passed in as exogenous data with NaN values for - where the observations are not available. These NaN values are then handled - by observation specific model layers. - """ - - def __init__( - self, - *args, - use_proxy_obs=True, - onshore_obs_frac=None, - offshore_obs_frac=None, - loss_obs_weight=0.0, - loss_obs=None, - **kwargs, - ): - """ - Initialize the Sup3rGanWithObs model. - - Parameters - ---------- - args : list - Positional args for ``Sup3rGan`` parent class. 
- use_proxy_obs : bool - Whether to use proxy observations sampled from the gridded high res - data during training. If False, the model will expect real - observation data in the .high_res attribute of the batches in the - batch handler and will calculate the observation mask based on - where there are NaN values in the high res data. If True, the model - will create synthetic observation data by masking the gridded high - res data during training and will calculate the observation mask - based on the specified onshore and offshore observation fractions. - onshore_obs_frac : dict[List] | dict[float] - Fraction of the batch that should be treated as onshore - observations. Should include ``spatial`` key and optionally - ``time`` key if this is a spatiotemporal model. The values should - correspond roughly to the fraction of the production domain for - which onshore observations are available (spatial) and the fraction - of the full time period that these cover. The values can be either - a list (for a lower and upper bound, respectively) or a single - float. For each batch a spatial frac will be selected by either - sampling uniformly between this lower and upper bound or just - using a single float. - offshore_obs_frac : dict - Same as ``onshore_obs_frac`` but for offshore observations. - Offshore observations are frequently sparser than onshore - observations. - loss_obs : str | dict - Loss function to use for the additional observation loss term. This - defaults to the content loss function specified with ``loss``. - loss_obs_weight : float - Value used to weight observation locations in extra content loss - term. e.g. The new content loss will include ``obs_loss_weight * - MAE(hi_res_gen[~obs_mask], hi_res_true[~obs_mask])`` - kwargs : dict - Keyword arguments for the ``Sup3rGan`` parent class. 
- """ - super().__init__(*args, **kwargs) - self.use_proxy_obs = use_proxy_obs - self.onshore_obs_frac = ( - {} if onshore_obs_frac is None else onshore_obs_frac - ) - self.offshore_obs_frac = ( - {} if offshore_obs_frac is None else offshore_obs_frac - ) - loss_obs = self.loss_name if loss_obs is None else loss_obs - self.loss_obs_name = loss_obs - self.loss_obs_fun = self.get_loss_fun(loss_obs) - self.loss_obs_weight = loss_obs_weight - - @tf.function - def _get_loss_obs_comparison(self, hi_res_true, hi_res_gen, obs_mask): - """Get loss for observation locations and for non observation - locations.""" - hr_true = hi_res_true[..., : len(self.hr_out_features)] - loss_obs, _ = self.loss_obs_fun( - hi_res_gen[~obs_mask], hr_true[~obs_mask] - ) - loss_non_obs, _ = self.loss_obs_fun( - hi_res_gen[obs_mask], hr_true[obs_mask] - ) - return loss_obs, loss_non_obs - - @property - def obs_training_inds(self): - """Get the observation feature indices in the true high res - data. True observation features are named with an '_obs' suffix. - When training with proxy observations these indices select - the corresponding gridded features (no '_obs' suffix). Otherwise, - these indices select observation features with an '_obs' suffix.""" - - if self.use_proxy_obs: - hr_feats = [f.replace('_obs', '') for f in self.hr_features] - obs_inds = [ - hr_feats.index(f.replace('_obs', '')) - for f in self.obs_features - ] - else: - obs_inds = [self.hr_features.index(f) for f in self.obs_features] - return obs_inds - - def _get_single_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): - """Get observation mask for a given spatial and temporal obs - fraction for a single batch entry. - - Parameters - ---------- - hi_res : np.ndarray - True high resolution data for a single batch entry. - spatial_frac : float - Fraction of the spatial domain that should be treated as - observations. This is a value between 0 and 1. 
- time_frac : float, optional - Fraction of the temporal domain that should be treated as - observations. This is a value between 0 and 1. Default is 1.0 - - Returns - ------- - np.ndarray - Mask which is True for locations that are not observed and False - for locations that are observed. - (spatial_1, spatial_2, n_features) - (spatial_1, spatial_2, n_temporal, n_features) - """ - mask_shape = [*hi_res.shape[:3], 1, len(self.hr_out_features)] - mask_shape[3] = hi_res.shape[3] if self.is_5d else 1 - s_mask = RANDOM_GENERATOR.uniform(size=mask_shape[1:3]) <= spatial_frac - s_mask = s_mask[..., None, None] - t_mask = RANDOM_GENERATOR.uniform(size=mask_shape[-2]) <= time_frac - t_mask = t_mask[None, None, ..., None] - mask = ~(s_mask & t_mask) - mask = np.repeat(mask, mask_shape[-1], axis=-1) - return mask if self.is_5d else np.squeeze(mask, axis=-2) - - def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): - """Get observation mask for a given spatial and temporal obs - fraction for an entire batch. This is divided between spatial and - temporal fractions because often the spatial fraction is significantly - lower than the temporal fraction in practice, e.g. for a given spatial - location there might be observations for most of the time period but - only a small fraction of the spatial domain is observed. - - Parameters - ---------- - hi_res : np.ndarray - True high resolution data for the entire batch. - spatial_frac : float | list - Fraction of the spatial domain that should be treated as - observations. This is a value between 0 and 1 or a list with - lower and upper bounds for the spatial fraction. - time_frac : float | list, optional - Fraction of the temporal domain that should be treated as - observations. This is a value between 0 and 1 or a list with - lower and upper bounds for the temporal fraction. Default is 1.0 - - Returns - ------- - np.ndarray - Mask which is True for locations that are not observed and False - for locations that are observed. 
- (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - s_range = ( - spatial_frac - if isinstance(spatial_frac, (list, tuple)) - else [spatial_frac, spatial_frac] - ) - t_range = ( - time_frac - if isinstance(time_frac, (list, tuple)) - else [time_frac, time_frac] - ) - s_fracs = RANDOM_GENERATOR.uniform(*s_range, size=hi_res.shape[0]) - t_fracs = RANDOM_GENERATOR.uniform(*t_range, size=hi_res.shape[0]) - s_fracs = np.clip(s_fracs, 0, 1) - t_fracs = np.clip(t_fracs, 0, 1) - mask = tf.stack( - [ - self._get_single_obs_mask(hi_res, s, t) - for s, t in zip(s_fracs, t_fracs) - ], - axis=0, - ) - return mask - - def _get_full_obs_mask(self, hi_res): - """Define observation mask for the current batch. This differs from - ``_get_obs_mask`` by defining a composite mask based on separate - onshore and offshore masks. This is because there is often more - observation data available onshore than offshore.""" - on_sf = self.onshore_obs_frac.get('spatial', 0.0) - on_tf = self.onshore_obs_frac.get('time', 1.0) - obs_mask = self._get_obs_mask(hi_res, on_sf, on_tf) - if 'topography' in self.hr_features and self.offshore_obs_frac: - topo_idx = self.hr_features.index('topography') - topo = hi_res[..., topo_idx] - off_sf = self.offshore_obs_frac.get('spatial', 0.0) - off_tf = self.offshore_obs_frac.get('time', 1.0) - offshore_mask = self._get_obs_mask(hi_res, off_sf, off_tf) - obs_mask = tf.where(topo[..., None] > 0, obs_mask, offshore_mask) - return obs_mask - - @property - def model_params(self): - """ - Model parameters, used to save model to disc - - Returns - ------- - dict - """ - params = super().model_params - params['onshore_obs_frac'] = self.onshore_obs_frac - params['offshore_obs_frac'] = self.offshore_obs_frac - params['loss_obs_weight'] = self.loss_obs_weight - params['loss_obs'] = self.loss_obs_name - return params - - @tf.function - def get_hr_exo_input(self, hi_res_true): - """Mask high res data to act as sparse 
observation data. Add this to - the standard high res exo input""" - exo_data = super().get_hr_exo_input(hi_res_true) - if len(self.obs_features) == 0: - return exo_data - - if self.use_proxy_obs: - obs, obs_mask = self._get_proxy_obs(hi_res_true) - else: - obs, obs_mask = self._get_real_obs(hi_res_true) - - obs = tf.expand_dims(obs, axis=-2) - exo_obs = dict(zip(self.obs_features, tf.unstack(obs, axis=-1))) - exo_data.update(exo_obs) - exo_data['mask'] = obs_mask - return exo_data - - def _get_real_obs(self, hi_res_true): - """Get real observation data and the corresponding mask from the - .high_res attribute of the batches in the batch handler. This is used - when not training with proxy observations.""" - obs = tf.gather(hi_res_true, self.obs_training_inds, axis=-1) - obs_mask = tf.math.is_nan(obs) - return obs, obs_mask - - def _get_proxy_obs(self, hi_res_true): - """Get proxy observation data by masking the true high res data.""" - obs_mask = self._get_full_obs_mask(hi_res_true) - nan_const = tf.constant(float('nan'), dtype=hi_res_true.dtype) - obs = tf.gather(hi_res_true, self.obs_training_inds, axis=-1) - obs = tf.where(obs_mask[..., : obs.shape[-1]], nan_const, obs) - return obs, obs_mask - - def _get_hr_exo_and_loss( - self, - low_res, - hi_res_true, - **calc_loss_kwargs, - ): - """Get high-resolution exogenous data, generate synthetic output, and - compute loss. 
Includes artificially masking hi res data to act as - sparse observation data.""" - out = super()._get_hr_exo_and_loss( - low_res, hi_res_true, **calc_loss_kwargs - ) - loss, loss_details, hi_res_gen, hi_res_exo = out - - if calc_loss_kwargs.get('train_gen', True): - loss_obs, loss_non_obs = self._get_loss_obs_comparison( - hi_res_true, - hi_res_gen, - hi_res_exo['mask'], - ) - n_obs = tf.reduce_sum(tf.cast(~hi_res_exo['mask'], tf.float32)) - n_total = tf.cast(tf.size(hi_res_exo['mask']), tf.float32) - obs_frac = n_obs / n_total - loss_update = { - 'loss_obs': loss_obs, - 'loss_non_obs': loss_non_obs, - 'obs_frac': obs_frac, - } - if self.loss_obs_weight and obs_frac > 0: - loss_obs *= self.loss_obs_weight - loss += loss_obs - loss_details['loss_gen'] += loss_obs - loss_details['loss_gen_content'] += loss_obs - loss_details.update(loss_update) - return loss, loss_details, hi_res_gen, hi_res_exo - - def _post_batch(self, ib, b_loss_details, n_batches, previous_means): - """Update loss details after the current batch and write to log.""" - if 'obs_frac' in b_loss_details: - logger.debug( - 'Batch {} out of {} has obs_frac: {:.4e}'.format( - ib + 1, n_batches, b_loss_details['obs_frac'] - ) - ) - return super()._post_batch( - ib, b_loss_details, n_batches, previous_means - ) diff --git a/sup3r/preprocessing/batch_handlers/factory.py b/sup3r/preprocessing/batch_handlers/factory.py index c22e19992d..d19cc8e337 100644 --- a/sup3r/preprocessing/batch_handlers/factory.py +++ b/sup3r/preprocessing/batch_handlers/factory.py @@ -151,21 +151,31 @@ def __init__( memory right away. feature_sets : Optional[dict] Optional dictionary describing how the full set of features is - split between `lr_only_features` and `hr_exo_features`. + split between ``lr_features``, ``hr_features``, and + ``hr_out_features``. - features : list | tuple - List of full set of features to use for sampling. 
If no - entry is provided then all data_vars from container data + lr_features : list | tuple + List of feature names or patt*erns to use as low-resolution + model inputs. If no entry is provided then all available + features from the data will be used. + hr_out_features : list | tuple + List of feature names or patt*erns that should be output by + the generative model and available as ground truth targets. + If no entry is provided then all features in lr_features will be used. - lr_only_features : list | tuple - List of feature names or patt*erns that should only be - included in the low-res training set and not the high-res - observations. This hr_exo_features : list | tuple - List of feature names or patt*erns that should be included - in the high-resolution observation but not expected to be - output from the generative model. An example is high-res - topography that is to be injected mid-network. + List of feature names or patt*erns that should be available + as high-resolution model inputs (like topography or + observations). These are injected into the model + mid-network to condition output on high-resolution + information. The model configuration should have the + appropriate layers to use these features. e.g. + ``Sup3rConcat`` for topography injection, ``Sup3rObsModel`` + or ``Sup3rCrossAttention`` for obs injection. If no entry + is provided then hr_exo_features will be empty. + + *To include sparse features as inputs or targets the features + must have an "_obs" suffix. kwargs : dict Additional keyword arguments for BatchQueue and / or Samplers. 
This can vary depending on the type of BatchQueue / Sampler diff --git a/sup3r/preprocessing/batch_queues/base.py b/sup3r/preprocessing/batch_queues/base.py index ccba5deb09..ee22439666 100644 --- a/sup3r/preprocessing/batch_queues/base.py +++ b/sup3r/preprocessing/batch_queues/base.py @@ -69,7 +69,8 @@ def transform( (batch_size, spatial_1, spatial_2, features) (batch_size, spatial_1, spatial_2, temporal, features) """ - low_res = spatial_coarsening(samples, self.s_enhance) + lr_samples = numpy_if_tensor(samples)[..., self.lr_features_ind] + low_res = spatial_coarsening(lr_samples, self.s_enhance) low_res = ( low_res if self.t_enhance == 1 @@ -81,7 +82,7 @@ def transform( smoothing_ignore if smoothing_ignore is not None else [] ) low_res = smooth_data( - low_res, self.features, smoothing_ignore, smoothing + low_res, self.lr_features, smoothing_ignore, smoothing ) high_res = numpy_if_tensor(samples)[..., self.hr_features_ind] return low_res, high_res diff --git a/sup3r/preprocessing/batch_queues/dual.py b/sup3r/preprocessing/batch_queues/dual.py index 56b2b08d42..865e6ff745 100644 --- a/sup3r/preprocessing/batch_queues/dual.py +++ b/sup3r/preprocessing/batch_queues/dual.py @@ -20,7 +20,6 @@ def __init__(self, samplers, **kwargs): -------- :class:`~sup3r.preprocessing.batch_queues.abstract.AbstractBatchQueue` """ - self.BATCH_MEMBERS = samplers[0].dset_names super().__init__(samplers, **kwargs) self.check_enhancement_factors() @@ -31,16 +30,10 @@ def queue_shape(self): """Shape of objects stored in the queue. 
Optionally includes shape of observation data which would be included in an extra content loss term""" - obs_shape = ( - *self.hr_shape[:-1], - len(self.containers[0].hr_out_features), - ) - queue_shapes = [ + return [ (self.batch_size, *self.lr_shape), (self.batch_size, *self.hr_shape), - (self.batch_size, *obs_shape), ] - return queue_shapes[: len(self.BATCH_MEMBERS)] def check_enhancement_factors(self): """Make sure each DualSampler has the same enhancment factors and they diff --git a/sup3r/preprocessing/collections/stats.py b/sup3r/preprocessing/collections/stats.py index 665994547c..1f49fdad81 100644 --- a/sup3r/preprocessing/collections/stats.py +++ b/sup3r/preprocessing/collections/stats.py @@ -62,7 +62,16 @@ def _get_stat(self, stat_type, needed_features='all'): getattr(c.high_res[hr_feats], stat_type)(skipna=True) for c in self.containers ] - if any(lr_feats): + lr_check = any(lr_feats) + container_check = all(hasattr(c, 'low_res') for c in self.containers) + if lr_check and not container_check: + msg = ( + f'Found low-res features {lr_feats} but not all containers ' + 'have low-res data. ' + ) + logger.error(msg) + raise ValueError(msg) + elif lr_check and container_check: cstats_lr = [ getattr(c.low_res[lr_feats], stat_type)(skipna=True) for c in self.containers diff --git a/sup3r/preprocessing/rasterizers/dual.py b/sup3r/preprocessing/rasterizers/dual.py index 981cfec09f..66e1a9f6f3 100644 --- a/sup3r/preprocessing/rasterizers/dual.py +++ b/sup3r/preprocessing/rasterizers/dual.py @@ -92,11 +92,11 @@ def __init__( f'Sup3rDataset instance. Received {type(data)}.' 
) assert isinstance(data, Sup3rDataset), msg - self.lr_data, self.hr_data = data.low_res, data.high_res + self.data = data self.regrid_workers = regrid_workers - lr_step = self.lr_data.time_step - hr_step = self.hr_data.time_step + lr_step = self.data.low_res.time_step + hr_step = self.data.high_res.time_step msg = ( f'Time steps of high-res data ({hr_step} seconds) and low-res ' f'data ({lr_step} seconds) are inconsistent with t_enhance = ' @@ -105,9 +105,9 @@ def __init__( assert np.allclose(lr_step, hr_step * self.t_enhance), msg self.lr_required_shape = ( - self.hr_data.shape[0] // self.s_enhance, - self.hr_data.shape[1] // self.s_enhance, - self.hr_data.shape[2] // self.t_enhance, + self.data.high_res.shape[0] // self.s_enhance, + self.data.high_res.shape[1] // self.s_enhance, + self.data.high_res.shape[2] // self.t_enhance, ) self.hr_required_shape = ( self.s_enhance * self.lr_required_shape[0], @@ -118,16 +118,16 @@ def __init__( msg = ( f'The required low-res shape {self.lr_required_shape} is ' 'inconsistent with the shape of the raw data ' - f'{self.lr_data.shape}' + f'{self.data.low_res.shape}' ) assert all( req_s <= true_s for req_s, true_s in zip( - self.lr_required_shape, self.lr_data.shape + self.lr_required_shape, self.data.low_res.shape ) ), msg - self.hr_lat_lon = self.hr_data.lat_lon[ + self.hr_lat_lon = self.data.high_res.lat_lon[ slice(self.hr_required_shape[0]), slice(self.hr_required_shape[1]) ] self.lr_lat_lon = spatial_coarsening( @@ -137,49 +137,51 @@ def __init__( self.update_lr_data() self.update_hr_data() - super().__init__(data=(self.lr_data, self.hr_data)) + super().__init__(data=(self.data.low_res, self.data.high_res)) if run_qa: self.check_regridded_lr_data() if lr_cache_kwargs is not None: - Cacher(self.lr_data, lr_cache_kwargs) + Cacher(self.data.low_res, lr_cache_kwargs) if hr_cache_kwargs is not None: - Cacher(self.hr_data, hr_cache_kwargs) + Cacher(self.data.high_res, hr_cache_kwargs) def update_hr_data(self): """Set the high 
resolution data attribute and check if hr_data.shape is divisible by s_enhance. If not, take the largest shape that can be.""" msg = ( - f'hr_data.shape: {self.hr_data.shape[:3]} is not ' + f'hr_data.shape: {self.data.high_res.shape[:3]} is not ' f'divisible by s_enhance: {self.s_enhance}. Using shape: ' f'{self.hr_required_shape} instead.' ) - need_new_shape = self.hr_data.shape[:3] != self.hr_required_shape[:3] + need_new_shape = ( + self.data.high_res.shape[:3] != self.hr_required_shape[:3] + ) if need_new_shape: logger.warning(msg) warn(msg) hr_data_new = {} - for f in self.hr_data.features: + for f in self.data.high_res.features: hr_slices = [slice(sh) for sh in self.hr_required_shape] - hr = self.hr_data.to_dataarray().sel(variable=f).data + hr = self.data.high_res.to_dataarray().sel(variable=f).data hr_data_new[f] = hr[tuple(hr_slices)] hr_coords_new = { Dimension.LATITUDE: self.hr_lat_lon[..., 0], Dimension.LONGITUDE: self.hr_lat_lon[..., 1], - Dimension.TIME: self.hr_data.indexes['time'][ + Dimension.TIME: self.data.high_res.indexes['time'][ : self.hr_required_shape[2] ], } logger.info( - 'Updating self.hr_data with new shape: ' + 'Updating self.data.high_res with new shape: ' f'{self.hr_required_shape[:3]}' ) - self.hr_data = self.hr_data.update_ds({ + self.data.high_res = self.data.high_res.update_ds({ **hr_coords_new, **hr_data_new, }) @@ -191,7 +193,9 @@ def get_regridder(self): data=self.lr_lat_lon.reshape((-1, 2)), ) return Regridder( - self.lr_data.meta, target_meta, max_workers=self.regrid_workers + self.data.low_res.meta, + target_meta, + max_workers=self.regrid_workers, ) def update_lr_data(self): @@ -203,20 +207,20 @@ def update_lr_data(self): regridder = self.get_regridder() lr_data_new = {} - for f in self.lr_data.features: - lr = self.lr_data.to_dataarray().sel(variable=f).data + for f in self.data.low_res.features: + lr = self.data.low_res.to_dataarray().sel(variable=f).data lr = lr[..., : self.lr_required_shape[2]] lr_data_new[f] = 
regridder(lr).reshape(self.lr_required_shape) lr_coords_new = { Dimension.LATITUDE: self.lr_lat_lon[..., 0], Dimension.LONGITUDE: self.lr_lat_lon[..., 1], - Dimension.TIME: self.lr_data.indexes[Dimension.TIME][ + Dimension.TIME: self.data.low_res.indexes[Dimension.TIME][ : self.lr_required_shape[2] ], } - logger.info('Updating self.lr_data with regridded data.') - self.lr_data = self.lr_data.update_ds({ + logger.info('Updating self.data.low_res with regridded data.') + self.data.low_res = self.data.low_res.update_ds({ **lr_coords_new, **lr_data_new, }) @@ -225,8 +229,8 @@ def check_regridded_lr_data(self): """Check for NaNs after regridding and do NN fill if needed.""" fill_feats = [] logger.info('Checking for NaNs after regridding') - qa_info = self.lr_data.qa(stats=['nan_perc']) - for f in self.lr_data.features: + qa_info = self.data.low_res.qa(stats=['nan_perc']) + for f in self.data.low_res.features: nan_perc = qa_info[f]['nan_perc'] if nan_perc > 0: msg = f'{f} data has {nan_perc:.3f}% NaN values!' @@ -244,6 +248,6 @@ def check_regridded_lr_data(self): f'features = {fill_feats}' ) logger.info(msg) - self.lr_data = self.lr_data.interpolate_na( + self.data.low_res = self.data.low_res.interpolate_na( features=fill_feats, method='nearest' ) diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index b6be139d71..e6b02802f1 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -15,6 +15,7 @@ uniform_time_sampler, ) from sup3r.preprocessing.utilities import compute_if_dask, lowered +from sup3r.utilities.utilities import RANDOM_GENERATOR logger = logging.getLogger(__name__) @@ -29,6 +30,7 @@ def __init__( sample_shape: Optional[tuple] = None, batch_size: int = 16, feature_sets: Optional[dict] = None, + proxy_obs_kwargs: Optional[dict] = None, mode: str = 'lazy', ): """ @@ -52,20 +54,46 @@ def __init__( samples and then stacking. 
feature_sets : Optional[dict] Optional dictionary describing how the full set of features is - split between ``lr_only_features`` and ``hr_exo_features``. - - features : list | tuple - List of full set of features to use for sampling. If no entry - is provided then all data_vars from data will be used. - lr_only_features : list | tuple - List of feature names or patt*erns that should only be - included in the low-res training set and not the high-res - observations. + split between ``lr_features``, ``hr_exo_features``, and + ``hr_out_features``. + + lr_features : list | tuple + List of feature names or patt*erns to use as low-resolution + model inputs. If no entry is provided then all available + features from the data will be used. + hr_out_features : list | tuple + List of feature names or patt*erns that should be output + by the generative model and available as ground truth targets. + If no entry is provided then all features in lr_features will + be used. hr_exo_features : list | tuple - List of feature names or patt*erns that should be included - in the high-resolution observation but not expected to be - output from the generative model. An example is high-res - topography that is to be injected mid-network. + List of feature names or patt*erns that should be available + as high-resolution model inputs (like topography or + observations). These are injected into the model mid-network + to condition output on high-resolution information. The model + configuration should have the appropriate layers to use these + features. e.g. ``Sup3rConcat`` for topography injection, + ``Sup3rObsModel`` or ``Sup3rCrossAttention`` for obs injection. + If no entry is provided then hr_exo_features will be empty. + + *To include sparse features as inputs or targets the features + must have an "_obs" suffix. + proxy_obs_kwargs : dict | None + Optional dictionary of keyword arguments to pass to the proxy + observation generator. 
This is only used when training with proxy + observations. Keys can include ``onshore_obs_frac`` and + ``offshore_obs_frac`` which specify the fraction of the batch that + should be treated as onshore and offshore observations, + respectively. For example, ``proxy_obs_kwargs={'onshore_obs_frac': + {'spatial': 0.1, 'temporal': 0.2}, 'offshore_obs_frac': {'spatial': + 0.05, 'temporal': 0.1}}`` would specify that for the onshore + region observations cover 10% of the spatial domain and 20% of the + temporal domain, while for the offshore region observations cover + 5% of the spatial domain and 10% of the temporal domain. Instead of + a single float, these can also be lists to specify a lower and + upper bound for the spatial and temporal fractions, in which case + the actual fraction for each batch will be sampled uniformly + between these bounds. mode : str Mode for sampling data. Options are 'lazy' or 'eager'. 'eager' mode pre-loads all data into memory as numpy arrays for faster access. @@ -74,14 +102,54 @@ def __init__( """ super().__init__(data=data) feature_sets = feature_sets or {} - self.features = feature_sets.get('features', self.data.features) - self._lr_only_features = feature_sets.get('lr_only_features', []) + self._lr_features = feature_sets.get('lr_features', self.data.features) self._hr_exo_features = feature_sets.get('hr_exo_features', []) + self._hr_out_features = feature_sets.get('hr_out_features', []) + self.proxy_obs_kwargs = proxy_obs_kwargs or {} self.mode = mode self.sample_shape = sample_shape or (10, 10, 1) self.batch_size = batch_size - self.lr_features = self.features self.preflight() + self.check_feature_consistency() + + @property + def use_proxy_obs(self): + """Whether to use proxy observations. When True, proxy observation + features are generated by masking the corresponding gridded ground + truth data and are appended to the samples. 
The obs features are + specified by the ``obs_features`` argument and should have a + corresponding source feature in the data features that is used for + sampling. For example, an obs feature named ``temperature_obs`` would + be generated from the gridded ground truth feature named + ``temperature``. + """ + check = bool(self.proxy_obs_kwargs) + check = check or ( + len(self.obs_features) > 0 + and all(f not in self.features for f in self.obs_features) + ) + return check + + @property + def onshore_obs_frac(self): + """Fraction of onshore observations to include in each batch when using + proxy observations. This can be a single float or a dictionary with + keys 'spatial' and 'temporal' to specify the fraction for each domain. + If a dictionary is provided, the actual fraction for each batch will be + sampled uniformly between the specified spatial and temporal fractions. + """ + return self.proxy_obs_kwargs.get('onshore_obs_frac', {}) + + @property + def offshore_obs_frac(self): + """Fraction of offshore observations to include in each batch when + using proxy observations. This can be a single float or a dictionary + with keys 'spatial' and 'temporal' to specify the fraction for each + domain. If a dictionary is provided, the actual fraction for each + batch will be sampled uniformly between the specified spatial and + temporal fractions. + """ + return self.proxy_obs_kwargs.get('offshore_obs_frac', {}) def get_sample_index(self, n_obs=None): """Randomly gets spatiotemporal sample index. 
@@ -105,11 +173,15 @@ def get_sample_index(self, n_obs=None): time_slice = uniform_time_sampler( self.shape, self.sample_shape[2] * n_obs ) - return (*spatial_slice, time_slice, self.features) + feats = ( + self.features + if not self.use_proxy_obs + else self.features[: -len(self.obs_features)] + ) + return (*spatial_slice, time_slice, feats) def preflight(self): - """Check if the sample_shape is larger than the requested raster - size""" + """Perform shape and feature checks.""" good_shape = ( self.sample_shape[0] <= self.data.shape[0] and self.sample_shape[1] <= self.data.shape[1] @@ -144,6 +216,45 @@ def preflight(self): logger.info('Received mode = "eager".') _ = self.compute() + def check_feature_consistency(self): + """Check that the feature sets are consistent with each other and the + obs features are configured correctly.""" + if self.use_proxy_obs and not all( + f in self.hr_features for f in self.obs_features + ): + msg = ( + 'When using proxy observations, all obs features must be ' + 'included either in hr_out_features or hr_exo_features.' + ) + raise ValueError(msg) + + if self.use_proxy_obs and any( + f in self.data.features for f in self.obs_features + ): + msg = ( + f'Obs features {self.obs_features} cannot be in the data ' + f'features {self.data.features} when using proxy observations.' 
+ ) + raise ValueError(msg) + + if len(self.obs_features) > 0: + msg = ( + f'Obs features {self.obs_features} must come at the end of ' + f'the hr_exo_features {self.hr_exo_features}' + ) + assert list(self.obs_features) == list( + self.hr_exo_features[-len(self.obs_features) :] + ), msg + + if len(self.hr_exo_features) > 0: + msg = ( + f'hr_exo_features {self.hr_exo_features} must come at the end ' + f'of the full high-res feature set: {self.hr_features}' + ) + assert list(self.hr_exo_features) == list( + self.hr_features[-len(self.hr_exo_features) :] + ), msg + @property def sample_shape(self) -> tuple: """Shape of the data sample to select when ``__next__()`` is called.""" @@ -256,8 +367,10 @@ def _fast_batch(self): out = self.data.sample(self.get_sample_index(n_obs=self.batch_size)) out = self._compute_samples(out) if isinstance(out, tuple): - return tuple(self._reshape_samples(o) for o in out) - return self._reshape_samples(out) + out = tuple(self._reshape_samples(o) for o in out) + else: + out = self._reshape_samples(out) + return self._append_obs_features(out) def _slow_batch(self): """Get batch of samples with random time slices.""" @@ -266,23 +379,104 @@ def _slow_batch(self): for _ in range(self.batch_size) ] out = self._compute_samples(out) - return self._stack_samples(out) + out = self._stack_samples(out) + return self._append_obs_features(out) def _fast_batch_possible(self): return self.batch_size * self.sample_shape[2] <= self.data.shape[2] + @property + def obs_features_ind(self): + """Get the source feature indices in ``features`` for each obs + feature. Each obs feature named ``_obs`` maps to the + corresponding ```` in the features. + + Returns + ------- + list[int] + Indices into ``features`` for each obs feature source. 
+ """ + if len(self.obs_features) == 0: + return [] + + if self.use_proxy_obs: + return [ + self.hr_features.index(f.replace('_obs', '')) + for f in self.obs_features + ] + else: + return [self.hr_features.index(f) for f in self.obs_features] + + def _get_proxy_obs(self, hi_res): + """Generate proxy observation data by masking the gridded high-res + data. Unobserved locations are set to NaN. + + Parameters + ---------- + hi_res : np.ndarray + High resolution batch data with shape: + (batch_size, spatial_1, spatial_2, temporal, n_features) + + Returns + ------- + obs : np.ndarray + Observation data with NaN for unobserved locations. Shape: + (batch_size, spatial_1, spatial_2, temporal, n_obs_features) + """ + obs_mask = self._get_full_obs_mask(hi_res) + obs = hi_res[..., self.obs_features_ind].copy() + obs[obs_mask[..., : obs.shape[-1]]] = np.nan + return obs + + def _append_obs_features(self, samples): + """Append proxy observation features to the batch samples when + ``use_proxy_obs=True``. The obs features are generated by masking + the corresponding gridded ground truth features. + + Parameters + ---------- + samples : np.ndarray | tuple[np.ndarray, ...] + Batch samples from the data source. For single datasets, shape + is (batch_size, s1, s2, t, n_features). For dual datasets, + this is a tuple of arrays. + + Returns + ------- + samples : np.ndarray | tuple[np.ndarray, ...] + Same as input, but with obs features appended to the last + dimension if proxy obs are enabled. + """ + if not self.use_proxy_obs: + return samples + + if isinstance(samples, tuple): + # For dual datasets, obs features are appended to the high-res + # member (last element) + hr = samples[-1] + obs = self._get_proxy_obs(hr) + hr = np.concatenate([hr, obs], axis=-1) + return (*samples[:-1], hr) + + obs = self._get_proxy_obs(samples) + return np.concatenate([samples, obs], axis=-1) + def __next__(self): """Get next batch of samples. 
This retrieves n_samples = batch_size with shape = sample_shape from the `.data` (a xr.Dataset or Sup3rDataset) through the Sup3rX accessor. + When ``use_proxy_obs=True`` and ``obs_features`` are configured, proxy + observation features are generated by masking the corresponding + gridded ground truth data and are appended to the samples. + Returns ------- samples : tuple(np.ndarray | da.core.Array) | np.ndarray | da.core.Array Either a tuple or single array of samples. This is a tuple when this method is sampling from a ``Sup3rDataset`` with two data - members - """ # pylint: disable=line-too-long # noqa + members. When proxy obs are enabled, obs features are appended + to the feature dimension. + """ # noqa: E501 if self._fast_batch_possible(): return self._fast_batch() return self._slow_batch() @@ -311,10 +505,31 @@ def _parse_features(self, unparsed_feats): return lowered(parsed_feats) @property - def lr_only_features(self): - """List of feature names or patt*erns that should only be included in - the low-res training set and not the high-res observations.""" - return self._parse_features(self._lr_only_features) + def lr_features(self): + """List of feature names or patt*erns to use as low-resolution model + inputs. If no entry is provided then all available features from the + data will be used.""" + return self._parse_features(self._lr_features) + + @property + def hr_features(self): + """List of feature names or patt*erns that should be available as + either high-resolution model inputs (like topography or observations) + or as ground truth targets. If no entry is provided then all available + features from data will be used.""" + out = [ + f for f in self.hr_out_features if f not in self.hr_exo_features + ] + out += self.hr_exo_features + return out + + @property + def hr_out_features(self): + """List of feature names or patt*erns that should be output by the + generative model. 
If no entry is provided then all features in
+        lr_features will be used."""
+        hr_out = self._parse_features(self._hr_out_features)
+        return self.lr_features if len(hr_out) == 0 else hr_out
 
     @property
     def hr_exo_features(self):
@@ -322,63 +537,141 @@ def __init__(
         for training e.g., mid-network high-res topo injection. These must
         come at the end of the high-res feature set. These can also be input
         to the model as low-res features."""
-        self._hr_exo_features = self._parse_features(self._hr_exo_features)
-
-        if len(self._hr_exo_features) > 0:
-            msg = (
-                f'High-res train-only features "{self._hr_exo_features}" '
-                f'do not come at the end of the full high-res feature set: '
-                f'{self.features}'
-            )
-            last_feat = self.features[-len(self._hr_exo_features) :]
-            assert list(self._hr_exo_features) == list(last_feat), msg
-
-        return self._hr_exo_features
+        return self._parse_features(self._hr_exo_features)
 
     @property
-    def hr_out_features(self):
-        """Get a list of high-resolution features that are intended to be
-        output by the GAN. Does not include high-resolution exogenous
-        features"""
-
-        out = []
-        for feature in self.features:
-            lr_only = any(
-                fnmatch(feature.lower(), pattern.lower())
-                for pattern in self.lr_only_features
-            )
-            ignore = lr_only or feature in self.hr_exo_features
-            if not ignore:
-                out.append(feature)
-
-        if len(out) == 0:
-            msg = (
-                f'It appears that all handler features "{self.features}" '
-                'were specified as `hr_exo_features` or `lr_only_features` '
-                'and therefore there are no output features!'
-            )
-            logger.error(msg)
-            raise RuntimeError(msg)
-
-        return lowered(out)
+    def obs_features(self):
+        """List of feature names or patt*erns that should be treated as
+        observations. These features will be included in the high-res data but
+        not the low-res data and won't necessarily be expected to be output by
+        the generative model. 
These are different from the `hr_exo_features` in + that they are intended to be used as observation features with NaN + values where observations are not available.""" + return [f for f in self.hr_features if '_obs' in f] @property def hr_features_ind(self): """Get the high-resolution feature channel indices that should be - included for training. Any high-resolution features that are only - included in the data handler to be coarsened for the low-res input are - removed""" - hr_features = list(self.hr_out_features) + list(self.hr_exo_features) - if list(self.features) == hr_features: - return np.arange(len(self.features)) - return [ - i - for i, feature in enumerate(self.features) - if feature in hr_features - ] + included for training. This includes hr_out_features and + hr_exo_features, Any high-resolution features that are only included in + the data handler to be coarsened for the low-res input are removed. + """ + return [self.features.index(f) for f in self.hr_features] @property - def hr_features(self): - """Get the high-resolution features corresponding to - `hr_features_ind`""" - return [self.features[ind].lower() for ind in self.hr_features_ind] + def lr_features_ind(self): + """Get the low-resolution feature channel indices that should be + included for training. This includes lr_features. + """ + return [self.features.index(f) for f in self.lr_features] + + @property + def features(self): + """Get the full set of features that should be included for training. + This is the union of lr_features, hr_out_features, hr_exo_features, and + obs_features. This is the set of features that will be sampled from the + data.""" + feats = self.lr_features + feats += [f for f in self.hr_out_features if f not in feats] + feats += [f for f in self.hr_exo_features if f not in feats] + return feats + + def _get_single_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): + """Get observation mask for a given spatial and temporal obs + fraction for a single batch entry. 
+ + Parameters + ---------- + hi_res : np.ndarray + True high resolution data for a single batch entry. + spatial_frac : float + Fraction of the spatial domain that should be treated as + observations. This is a value between 0 and 1. + time_frac : float, optional + Fraction of the temporal domain that should be treated as + observations. This is a value between 0 and 1. Default is 1.0 + + Returns + ------- + np.ndarray + Mask which is True for locations that are not observed and False + for locations that are observed. + (spatial_1, spatial_2, n_features) + (spatial_1, spatial_2, n_temporal, n_features) + """ + mask_shape = [*hi_res.shape[:-1], len(self.hr_out_features)] + s_mask = RANDOM_GENERATOR.uniform(size=mask_shape[1:3]) <= spatial_frac + s_mask = s_mask[..., None, None] + t_mask = RANDOM_GENERATOR.uniform(size=mask_shape[-2]) <= time_frac + t_mask = t_mask[None, None, ..., None] + mask = ~(s_mask & t_mask) + return np.repeat(mask, mask_shape[-1], axis=-1) + + def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): + """Get observation mask for a given spatial and temporal obs + fraction for an entire batch. This is divided between spatial and + temporal fractions because often the spatial fraction is significantly + lower than the temporal fraction in practice, e.g. for a given spatial + location there might be observations for most of the time period but + only a small fraction of the spatial domain is observed. + + Parameters + ---------- + hi_res : np.ndarray + True high resolution data for the entire batch. + spatial_frac : float | list + Fraction of the spatial domain that should be treated as + observations. This is a value between 0 and 1 or a list with + lower and upper bounds for the spatial fraction. + time_frac : float | list, optional + Fraction of the temporal domain that should be treated as + observations. This is a value between 0 and 1 or a list with + lower and upper bounds for the temporal fraction. 
Default is 1.0 + + Returns + ------- + np.ndarray + Mask which is True for locations that are not observed and False + for locations that are observed. + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + """ + s_range = ( + spatial_frac + if isinstance(spatial_frac, (list, tuple)) + else [spatial_frac, spatial_frac] + ) + t_range = ( + time_frac + if isinstance(time_frac, (list, tuple)) + else [time_frac, time_frac] + ) + s_fracs = RANDOM_GENERATOR.uniform(*s_range, size=hi_res.shape[0]) + t_fracs = RANDOM_GENERATOR.uniform(*t_range, size=hi_res.shape[0]) + s_fracs = np.clip(s_fracs, 0, 1) + t_fracs = np.clip(t_fracs, 0, 1) + mask = np.stack( + [ + self._get_single_obs_mask(hi_res, s, t) + for s, t in zip(s_fracs, t_fracs) + ], + axis=0, + ) + return mask + + def _get_full_obs_mask(self, hi_res): + """Define observation mask for the current batch. This differs from + ``_get_obs_mask`` by defining a composite mask based on separate + onshore and offshore masks. 
This is because there is often more + observation data available onshore than offshore.""" + on_sf = self.onshore_obs_frac.get('spatial', 0.0) + on_tf = self.onshore_obs_frac.get('time', 1.0) + obs_mask = self._get_obs_mask(hi_res, on_sf, on_tf) + if 'topography' in self.hr_features and self.offshore_obs_frac: + topo_idx = self.hr_features.index('topography') + topo = hi_res[..., topo_idx] + off_sf = self.offshore_obs_frac.get('spatial', 0.0) + off_tf = self.offshore_obs_frac.get('time', 1.0) + offshore_mask = self._get_obs_mask(hi_res, off_sf, off_tf) + obs_mask = np.where(topo[..., None] > 0, obs_mask, offshore_mask) + return obs_mask diff --git a/sup3r/preprocessing/samplers/cc.py b/sup3r/preprocessing/samplers/cc.py index 85af6d88c1..57fead2dbf 100644 --- a/sup3r/preprocessing/samplers/cc.py +++ b/sup3r/preprocessing/samplers/cc.py @@ -54,17 +54,30 @@ def __init__( Temporal enhancement factor feature_sets : Optional[dict] Optional dictionary describing how the full set of features is - split between ``lr_only_features`` and ``hr_exo_features``. - - lr_only_features : list | tuple - List of feature names or patt*erns that should only be - included in the low-res training set and not the high-res - observations. + split between ``lr_features``, ``hr_exo_features``, and + ``hr_out_features``. + + lr_features : list | tuple + List of feature names or patt*erns to use as low-resolution + model inputs. If no entry is provided then all available + features from the data will be used. + hr_out_features : list | tuple + List of feature names or patt*erns that should be output + by the generative model and available as ground truth targets. + If no entry is provided then all features in lr_features will + be used. hr_exo_features : list | tuple - List of feature names or patt*erns that should be included - in the high-resolution observation but not expected to be - output from the generative model. An example is high-res - topography that is to be injected mid-network. 
+ List of feature names or patt*erns that should be available + as high-resolution model inputs (like topography or + observations). These are injected into the model mid-network + to condition output on high-resolution information. The model + configuration should have the appropriate layers to use these + features. e.g. ``Sup3rConcat`` for topography injection, + ``Sup3rObsModel`` or ``Sup3rCrossAttention`` for obs injection. + If no entry is provided then hr_exo_features will be empty. + + *To include sparse features as inputs or targets the features + must have an "_obs" suffix. mode : str Mode for sampling data. Options are 'lazy' or 'eager'. 'eager' mode pre-loads all data into memory as numpy arrays for faster access. @@ -100,21 +113,21 @@ def __init__( mode=mode, ) - def check_for_consistent_shapes(self): + def check_shape_consistency(self): """Make sure container shapes and sample shapes are compatible with enhancement factors.""" enhanced_shape = ( - self.lr_data.shape[0] * self.s_enhance, - self.lr_data.shape[1] * self.s_enhance, - self.lr_data.shape[2] * (1 if self.t_enhance == 1 else 24), + self.data.low_res.shape[0] * self.s_enhance, + self.data.low_res.shape[1] * self.s_enhance, + self.data.low_res.shape[2] * (1 if self.t_enhance == 1 else 24), ) msg = ( - f'hr_data.shape {self.hr_data.shape} and enhanced ' + f'hr_data.shape {self.data.high_res.shape} and enhanced ' f'lr_data.shape {enhanced_shape} are not compatible with ' f'the given enhancement factors t_enhance = {self.t_enhance}, ' f's_enhance = {self.s_enhance}' ) - assert self.hr_data.shape[:3] == enhanced_shape, msg + assert self.data.high_res.shape[:3] == enhanced_shape, msg def reduce_high_res_sub_daily(self, high_res, csr_ind=0): """Take an hourly high-res observation and reduce the temporal axis diff --git a/sup3r/preprocessing/samplers/dc.py b/sup3r/preprocessing/samplers/dc.py index a48c80df57..2312ce9ced 100644 --- a/sup3r/preprocessing/samplers/dc.py +++ 
b/sup3r/preprocessing/samplers/dc.py @@ -56,8 +56,30 @@ def __init__( efficient than getting N = batch_size samples and then stacking. feature_sets : Optional[dict] Optional dictionary describing how the full set of features is - split between `lr_only_features` and `hr_exo_features`. See - :class:`~sup3r.preprocessing.Sampler` + split between ``lr_features``, ``hr_exo_features``, and + ``hr_out_features``. + + lr_features : list | tuple + List of feature names or patt*erns to use as low-resolution + model inputs. If no entry is provided then all available + features from the data will be used. + hr_out_features : list | tuple + List of feature names or patt*erns that should be output + by the generative model and available as ground truth targets. + If no entry is provided then all features in lr_features will + be used. + hr_exo_features : list | tuple + List of feature names or patt*erns that should be available + as high-resolution model inputs (like topography or + observations). These are injected into the model mid-network + to condition output on high-resolution information. The model + configuration should have the appropriate layers to use these + features. e.g. ``Sup3rConcat`` for topography injection, + ``Sup3rObsModel`` or ``Sup3rCrossAttention`` for obs injection. + If no entry is provided then hr_exo_features will be empty. + + *To include sparse features as inputs or targets the features + must have an "_obs" suffix. mode : str Loading mode for sampling. 
See :class:`~sup3r.preprocessing.Sampler` diff --git a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index 8eb8a446c0..b81e2fff11 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -6,7 +6,6 @@ from typing import Optional from sup3r.preprocessing.base import Sup3rDataset -from sup3r.preprocessing.utilities import lowered from .base import Sampler from .utilities import uniform_box_sampler, uniform_time_sampler @@ -29,6 +28,7 @@ def __init__( s_enhance: int = 1, t_enhance: int = 1, feature_sets: Optional[dict] = None, + proxy_obs_kwargs: Optional[dict] = None, mode: str = 'lazy', ): """ @@ -36,7 +36,7 @@ def __init__( ---------- data : Sup3rDataset A :class:`~sup3r.preprocessing.base.Sup3rDataset` instance with - low-res and high-res data members, and optionally an obs member. + low-res and high-res data members. sample_shape : tuple Size of arrays to sample from the high-res data. The sample shape for the low-res sampler will be determined from the enhancement @@ -47,17 +47,46 @@ def __init__( Temporal enhancement factor feature_sets : Optional[dict] Optional dictionary describing how the full set of features is - split between `lr_only_features` and `hr_exo_features`. + split between ``lr_features``, ``hr_exo_features``, and + ``hr_out_features``. - lr_only_features : list | tuple - List of feature names or patt*erns that should only be - included in the low-res training set and not the high-res - observations. + lr_features : list | tuple + List of feature names or patt*erns to use as low-resolution + model inputs. If no entry is provided then all available + features from the data will be used. + hr_out_features : list | tuple + List of feature names or patt*erns that should be output + by the generative model and available as ground truth targets. + If no entry is provided then all features in lr_features will + be used. 
hr_exo_features : list | tuple - List of feature names or patt*erns that should be included - in the high-resolution observation but not expected to be - output from the generative model. An example is high-res - topography that is to be injected mid-network. + List of feature names or patt*erns that should be available + as high-resolution model inputs (like topography or + observations). These are injected into the model mid-network + to condition output on high-resolution information. The model + configuration should have the appropriate layers to use these + features. e.g. ``Sup3rConcat`` for topography injection, + ``Sup3rObsModel`` or ``Sup3rCrossAttention`` for obs injection. + If no entry is provided then hr_exo_features will be empty. + + *To include sparse features as inputs or targets the features + must have an "_obs" suffix. + proxy_obs_kwargs : dict | None + Optional dictionary of keyword arguments to pass to the proxy + observation generator. This is only used when training with proxy + observations. Keys can include ``onshore_obs_frac`` and + ``offshore_obs_frac`` which specify the fraction of the batch that + should be treated as onshore and offshore observations, + respectively. For example, ``proxy_obs_kwargs={'onshore_obs_frac': + {'spatial': 0.1, 'temporal': 0.2}, 'offshore_obs_frac': {'spatial': + 0.05, 'temporal': 0.1}}`` would specify that for the onshore + region observations cover 10% of the spatial domain and 20% of the + temporal domain, while for the offshore region observations cover + 5% of the spatial domain and 10% of the temporal domain. Instead of + a single float, these can also be lists to specify a lower and + upper bound for the spatial and temporal fractions, in which case + the actual fraction for each batch will be sampled uniformly + between these bounds. mode : str Mode for sampling data. Options are 'lazy' or 'eager'. 'eager' mode pre-loads all data into memory as numpy arrays for faster access. 
@@ -66,38 +95,40 @@ def __init__( """ msg = ( f'{self.__class__.__name__} requires a Sup3rDataset object ' - 'with `.low_res` and `.high_res` data members, and optionally an ' - '`.obs` member, in that order' + 'with `.low_res` and `.high_res` data members, in that order' ) - dnames = ['low_res', 'high_res', 'obs'][: len(data)] + dnames = ['low_res', 'high_res'] check = ( hasattr(data, dname) and getattr(data, dname) == data[i] for i, dname in enumerate(dnames) ) assert check, msg - super().__init__( - data=data, - sample_shape=sample_shape, - batch_size=batch_size, - mode=mode, + self.data = data + feature_sets = feature_sets or {} + self._lr_features = feature_sets.get( + 'lr_features', self.data.low_res.features + ) + self._hr_exo_features = feature_sets.get('hr_exo_features', []) + self._hr_out_features = feature_sets.get( + 'hr_out_features', self.data.high_res.features ) - self.lr_data, self.hr_data = self.data.low_res, self.data.high_res + self.proxy_obs_kwargs = proxy_obs_kwargs or {} + self.mode = mode + self.sample_shape = sample_shape or (10, 10, 1) + self.batch_size = batch_size + self.lr_sample_shape = ( self.hr_sample_shape[0] // s_enhance, self.hr_sample_shape[1] // s_enhance, self.hr_sample_shape[2] // t_enhance, ) - feature_sets = feature_sets or {} - self._lr_only_features = feature_sets.get('lr_only_features', []) - self._hr_exo_features = feature_sets.get('hr_exo_features', []) - self.features = self.get_features(feature_sets) - self.lr_features = [ - f for f in self.features if f in self.lr_data.features - ] self.s_enhance = s_enhance self.t_enhance = t_enhance - self.check_for_consistent_shapes() + + self.preflight() + self.check_shape_consistency() + self.check_feature_consistency() post_init_args = { 'lr_sample_shape': self.lr_sample_shape, 'hr_sample_shape': self.hr_sample_shape, @@ -106,33 +137,20 @@ def __init__( } self.post_init_log(post_init_args) - def get_features(self, feature_sets): - """Return default set of features composed from 
data vars in low res - and high res data objects or the value provided through the - feature_sets dictionary.""" - features = [] - _ = [ - features.append(f) - for f in [*self.lr_data.features, *self.hr_data.features] - if f not in features and f not in lowered(self._hr_exo_features) - ] - features += lowered(self._hr_exo_features) - return feature_sets.get('features', features) - - def check_for_consistent_shapes(self): + def check_shape_consistency(self): """Make sure container shapes are compatible with enhancement factors.""" enhanced_shape = ( - self.lr_data.shape[0] * self.s_enhance, - self.lr_data.shape[1] * self.s_enhance, - self.lr_data.shape[2] * self.t_enhance, + self.data.low_res.shape[0] * self.s_enhance, + self.data.low_res.shape[1] * self.s_enhance, + self.data.low_res.shape[2] * self.t_enhance, ) msg = ( - f'hr_data.shape {self.hr_data.shape[:-1]} and enhanced ' + f'hr_data.shape {self.data.high_res.shape[:-1]} and enhanced ' f'lr_data.shape {enhanced_shape} are not compatible with ' 'the given enhancement factors' ) - assert self.hr_data.shape[:-1] == enhanced_shape, msg + assert self.data.high_res.shape[:-1] == enhanced_shape, msg def get_sample_index(self, n_obs=None): """Get paired sample index, consisting of index for the low res sample @@ -141,10 +159,10 @@ def get_sample_index(self, n_obs=None): includes observation data.""" n_obs = n_obs or self.batch_size spatial_slice = uniform_box_sampler( - self.lr_data.shape, self.lr_sample_shape[:2] + self.data.low_res.shape, self.lr_sample_shape[:2] ) time_slice = uniform_time_sampler( - self.lr_data.shape, self.lr_sample_shape[2] * n_obs + self.data.low_res.shape, self.lr_sample_shape[2] * n_obs ) lr_index = (*spatial_slice, time_slice, self.lr_features) hr_index = [ @@ -155,8 +173,11 @@ def get_sample_index(self, n_obs=None): slice(s.start * self.t_enhance, s.stop * self.t_enhance) for s in lr_index[2:-1] ] - obs_index = (*hr_index, self.hr_out_features) - hr_index = (*hr_index, self.hr_features) + 
hr_feats = ( + self.hr_features[: -len(self.obs_features)] + if self.use_proxy_obs + else self.hr_features + ) + hr_index = (*hr_index, hr_feats) - sample_index = (lr_index, hr_index, obs_index) - return sample_index[: len(self.data)] + return (lr_index, hr_index) diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py index 437045725d..e7c950ddbd 100644 --- a/sup3r/utilities/loss_metrics.py +++ b/sup3r/utilities/loss_metrics.py @@ -13,30 +13,29 @@ class Sup3rLoss(tf.keras.losses.Loss): """Base class for custom sup3r loss metrics. This is meant to be used as a base class for loss metrics that require specific input features.""" - def __init__(self, input_features='all', obs_features=None): + def __init__(self, gen_features='all', true_features=None): """Initialize the loss with given input features Parameters ---------- - input_features : list | str - List of input features that the loss metric will be calculated on. - If 'all', the loss will be calculated on all features. Otherwise, - the loss will be calculated on the features specified in the list. - The order of features in the list will be checked to determine the - order of features in the input tensors. - obs_features : list | None - Optional list of observation features to use as targets for the - loss metric. This is typically used in a physics based loss - when the ground truth data is sparse (e.g. observation points). - In this case a physics constraint is applied where there are no - observations, and an additional content loss is calculated for - points where observations are available. The order of features - in the list will be checked to determine the order of features in - the input tensors. + gen_features : list | str + List of generator output features that the loss metric will be + calculated on. If 'all', the loss will be calculated on all + generator features. Otherwise, the loss will be calculated on the + features specified in the list. 
The order of features in the list + will be checked to determine the order of features in the generator + output tensor. + true_features : list | str + List of true features that the loss metric will be calculated on. + If None, this will be the same as gen_features. The order of + features in the list will be checked to determine the order of + features in the ground truth tensor. """ super().__init__() - self.input_features = input_features - self.obs_features = obs_features + self.gen_features = gen_features + self.true_features = ( + true_features if true_features is not None else gen_features + ) def tf_derivative(x, axis=1): @@ -754,19 +753,19 @@ class MaterialDerivativeLoss(Sup3rLoss): LOSS_METRIC = MeanAbsoluteError() - def __init__(self, input_features): - super().__init__(input_features=input_features) + def __init__(self, gen_features): + super().__init__(gen_features=gen_features) self.u_inds = [ - i for i, f in enumerate(input_features) if f.startswith('u_') + i for i, f in enumerate(gen_features) if f.startswith('u_') ] self.v_inds = [ - i for i, f in enumerate(input_features) if f.startswith('v_') + i for i, f in enumerate(gen_features) if f.startswith('v_') ] self.u_heights = [ - f.split('_')[1] for f in input_features if f.startswith('u_') + f.split('_')[1] for f in gen_features if f.startswith('u_') ] self.v_heights = [ - f.split('_')[1] for f in input_features if f.startswith('v_') + f.split('_')[1] for f in gen_features if f.startswith('v_') ] assert len(self.u_inds) == len(self.v_inds), ( 'The number of u and v components must be equal for ' @@ -795,7 +794,7 @@ def _compute_md(self, x, feature): """ # df/dt height = feature.split('_')[1] - fidx = self.input_features.index(feature) + fidx = self.gen_features.index(feature) uidx = self.u_inds[self.u_heights.index(height)] vidx = self.v_inds[self.v_heights.index(height)] x_div = tf_derivative(x[..., fidx], axis=3) @@ -835,10 +834,10 @@ def __call__(self, x1, x2): assert len(x1.shape) == 5 and 
len(x2.shape) == 5, msg x1_div = tf.stack([ - self._compute_md(x1, feature) for feature in self.input_features + self._compute_md(x1, feature) for feature in self.gen_features ]) x2_div = tf.stack([ - self._compute_md(x2, feature) for feature in self.input_features + self._compute_md(x2, feature) for feature in self.gen_features ]) return self.LOSS_METRIC(x1_div, x2_div) @@ -855,12 +854,13 @@ class GeothermalPhysicsLoss(Sup3rLoss): def __call__(self, x1, x2): """Geothermal physics loss""" - check = x1.shape[-1] == len(self.input_features) - check &= x2.shape[-1] == len(self.input_features) + check = x1.shape[-1] == len(self.gen_features) + check &= x2.shape[-1] == len(self.true_features) msg = ( - f'Number of features in `x1`: {x1.shape[-1]}, `x2`: ' - f'{x2.shape[-1]} must match the length of `input_features`: ' - f'{len(self.input_features)}' + f'Number of features in `x1`: {x1.shape[-1]} must match the ' + f'length of `gen_features`: {len(self.gen_features)}, `x2`: ' + f'{x2.shape[-1]} must match the length of `true_features`: ' + f'{len(self.true_features)}' ) assert check, msg @@ -878,13 +878,13 @@ class GeothermalPhysicsLossWithObs(Sup3rLoss): def __call__(self, x1, x2): """Geothermal physics loss""" - check = x1.shape[-1] == len(self.input_features) - check &= x2.shape[-1] == len(self.obs_features) + check = x1.shape[-1] == len(self.gen_features) + check &= x2.shape[-1] == len(self.true_features) msg = ( - f'Number of features in `x1`: {x1.shape[-1]}, must match the ' - f'length of `input_features`: {len(self.input_features)}, and ' - f'number of features in `x2`: {x2.shape[-1]}, must match the ' - f'length of `obs_features`: {len(self.obs_features)}' + f'Number of features in `x1`: {x1.shape[-1]} must match the ' + f'length of `gen_features`: {len(self.gen_features)}, `x2`: ' + f'{x2.shape[-1]} must match the length of `true_features`: ' + f'{len(self.true_features)}' ) assert check, msg diff --git a/tests/batch_handlers/test_bh_h5_cc.py 
b/tests/batch_handlers/test_bh_h5_cc.py index 942e52eed4..f349d6eceb 100644 --- a/tests/batch_handlers/test_bh_h5_cc.py +++ b/tests/batch_handlers/test_bh_h5_cc.py @@ -104,7 +104,10 @@ def test_solar_batching_spatial(): s_enhance=2, t_enhance=1, sample_shape=(20, 20, 1), - feature_sets={'lr_only_features': ['clearsky_ghi', 'ghi']}, + feature_sets={ + 'lr_features': FEATURES_S, + 'hr_out_features': ['clearsky_ratio'], + }, ) for batch in batcher: @@ -162,7 +165,10 @@ def test_solar_multi_day_coarse_data(): s_enhance=4, t_enhance=3, sample_shape=(20, 20, 9), - feature_sets={'lr_only_features': ['clearsky_ghi', 'ghi']}, + feature_sets={ + 'lr_features': FEATURES_S, + 'hr_out_features': ['clearsky_ratio'], + }, ) for batch in batcher: @@ -176,7 +182,10 @@ def test_solar_multi_day_coarse_data(): # run another test with u/v on low res side but not high res features = ['clearsky_ratio', 'u', 'v', 'ghi', 'clearsky_ghi'] - feature_sets = {'lr_only_features': ['u', 'v', 'clearsky_ghi', 'ghi']} + feature_sets = { + 'lr_features': features, + 'hr_out_features': ['clearsky_ratio'], + } handler = DataHandlerH5SolarCC(pytest.FP_NSRDB, features, **dh_kwargs) batcher = BatchHandlerTesterCC( @@ -329,7 +338,10 @@ def test_surf_min_max_vars(): s_enhance=1, t_enhance=24, sample_shape=(20, 20, 72), - feature_sets={'lr_only_features': ['*_min_*', '*_max_*']}, + feature_sets={ + 'lr_features': surf_features, + 'hr_out_features': ['temperature_2m', 'relativehumidity_2m'], + }, mode='eager', ) diff --git a/tests/batch_queues/test_bq_general.py b/tests/batch_queues/test_bq_general.py index 52b2ffb257..2fd0607034 100644 --- a/tests/batch_queues/test_bq_general.py +++ b/tests/batch_queues/test_bq_general.py @@ -184,7 +184,10 @@ def test_pair_batch_queue_with_lr_only_features(): s_enhance=2, t_enhance=2, batch_size=4, - feature_sets={'lr_only_features': lr_only_features}, + feature_sets={ + 'lr_features': lr_features, + 'hr_out_features': FEATURES, + }, ) for lr, hr in zip(lr_containers, 
hr_containers) ] diff --git a/tests/forward_pass/test_forward_pass_obs.py b/tests/forward_pass/test_forward_pass_obs.py index cd56ce70b2..f35793ad5d 100644 --- a/tests/forward_pass/test_forward_pass_obs.py +++ b/tests/forward_pass/test_forward_pass_obs.py @@ -9,7 +9,7 @@ import pytest from rex import Outputs -from sup3r.models import Sup3rGanWithObs +from sup3r.models import Sup3rGan from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy from sup3r.utilities.pytest.helpers import make_fake_dset from sup3r.utilities.utilities import RANDOM_GENERATOR @@ -108,24 +108,19 @@ def h5_obs_file(tmpdir_factory): @pytest.mark.parametrize('obs_file', ['nc_obs_file', 'h5_obs_file']) -def test_fwp_with_obs( - input_file, obs_file, gen_config_with_concat_masked, request -): +def test_fwp_with_obs(input_file, obs_file, gen_config_with_obs_2d, request): """Test a special model trained to condition output on input observations.""" obs_file = request.getfixturevalue(obs_file) - Sup3rGanWithObs.seed() - - model = Sup3rGanWithObs( - gen_config_with_concat_masked(), - pytest.S_FP_DISC, - onshore_obs_frac={'spatial': 0.1}, - loss_obs_weight=0.1, - learning_rate=1e-4, + Sup3rGan.seed() + + model = Sup3rGan( + gen_config_with_obs_2d(), pytest.S_FP_DISC, learning_rate=1e-4 ) model.meta['input_resolution'] = {'spatial': '16km', 'temporal': '3600min'} model.meta['lr_features'] = ['u_10m', 'v_10m'] + model.meta['hr_exo_features'] = ['u_10m_obs', 'v_10m_obs'] model.meta['hr_out_features'] = ['u_10m', 'v_10m'] model.meta['s_enhance'] = 2 model.meta['t_enhance'] = 1 @@ -151,10 +146,7 @@ def test_fwp_with_obs( ] }, } - _ = model.generate( - np.ones((6, 10, 10, 2)), - exogenous_data=exo_tmp - ) + _ = model.generate(np.ones((6, 10, 10, 2)), exogenous_data=exo_tmp) model_dir = os.path.join(td, 'test') model.save(model_dir) @@ -182,7 +174,7 @@ def test_fwp_with_obs( handler = ForwardPassStrategy( input_file, model_kwargs=model_kwargs, - model_class='Sup3rGanWithObs', + 
model_class='Sup3rGan', fwp_chunk_shape=fwp_chunk_shape, input_handler_kwargs=input_handler_kwargs, spatial_pad=0, diff --git a/tests/samplers/test_feature_sets.py b/tests/samplers/test_feature_sets.py index 7c87bb72b6..ade600c007 100644 --- a/tests/samplers/test_feature_sets.py +++ b/tests/samplers/test_feature_sets.py @@ -8,24 +8,27 @@ @pytest.mark.parametrize( - ['features', 'lr_only_features', 'hr_exo_features'], + ['features', 'lr_features', 'hr_exo_features', 'hr_out_features'], [ - (['V_100m'], ['V_100m'], []), - (['U_100m'], ['V_100m'], ['V_100m']), - (['U_100m'], [], ['U_100m']), - (['U_100m', 'V_100m'], [], ['U_100m']), - (['U_100m', 'V_100m'], [], ['V_100m', 'U_100m']), + (['V_100m'], ['V_100m'], [], []), + (['U_100m'], ['V_100m'], ['V_100m'], []), + (['U_100m'], [], ['U_100m'], []), + (['U_100m', 'V_100m'], [], ['U_100m'], []), + (['U_100m', 'V_100m'], [], ['V_100m', 'U_100m'], []), ], ) -def test_feature_errors(features, lr_only_features, hr_exo_features): +def test_feature_errors( + features, lr_features, hr_exo_features, hr_out_features +): """Each of these feature combinations should raise an error due to no features left in hr output or bad ordering""" sampler = Sampler( DummyData(data_shape=(20, 20, 10), features=features), sample_shape=(5, 5, 4), feature_sets={ - 'lr_only_features': lr_only_features, + 'lr_features': lr_features, 'hr_exo_features': hr_exo_features, + 'hr_out_features': hr_out_features, }, ) @@ -33,6 +36,40 @@ def test_feature_errors(features, lr_only_features, hr_exo_features): _ = sampler.lr_features _ = sampler.hr_out_features _ = sampler.hr_exo_features + _ = sampler.obs_features + + +@pytest.mark.parametrize( + ['features', 'lr_features', 'hr_exo_features', 'hr_out_features'], + [ + (['V_100m', 'topography'], [], ['topography'], ['V_100m_obs']), + ( + ['V_100m', 'V_100m_obs', 'topography'], + [], + ['topography'], + ['V_100m_obs'], + ), + ], +) +def test_sampler_feature_sets( + features, lr_features, hr_exo_features, 
hr_out_features +): + """Each of these feature combinations should pass without raising an + error.""" + sampler = Sampler( + DummyData(data_shape=(20, 20, 10), features=features), + sample_shape=(5, 5, 4), + feature_sets={ + 'lr_features': lr_features, + 'hr_exo_features': hr_exo_features, + 'hr_out_features': hr_out_features, + }, + ) + + _ = sampler.lr_features + _ = sampler.hr_out_features + _ = sampler.hr_exo_features + _ = sampler.obs_features @pytest.mark.parametrize( @@ -87,7 +124,7 @@ def test_mixed_lr_hr_features(lr_features, hr_features, hr_exo_features): @pytest.mark.parametrize( - ['features', 'lr_only_features', 'hr_exo_features'], + ['lr_features', 'hr_exo_features', 'hr_out_features'], [ ( [ @@ -102,8 +139,8 @@ def test_mixed_lr_hr_features(lr_features, hr_features, hr_exo_features): 'dewpoint_temperature', 'topography', ], - ['pressure', 'kx', 'dewpoint_temperature'], ['topography'], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], ), ( [ @@ -119,8 +156,8 @@ def test_mixed_lr_hr_features(lr_features, hr_features, hr_exo_features): 'topography', 'srl', ], - ['pressure', 'kx', 'dewpoint_temperature'], ['topography', 'srl'], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], ), ( [ @@ -134,8 +171,8 @@ def test_mixed_lr_hr_features(lr_features, hr_features, hr_exo_features): 'kx', 'dewpoint_temperature', ], - ['pressure', 'kx', 'dewpoint_temperature'], [], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], ), ( [ @@ -148,12 +185,12 @@ def test_mixed_lr_hr_features(lr_features, hr_features, hr_exo_features): 'topography', 'srl', ], - [], ['topography', 'srl'], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], ), ], ) -def test_dual_feature_sets(features, lr_only_features, hr_exo_features): +def test_dual_feature_sets(lr_features, hr_exo_features, hr_out_features): """Each of these feature combinations should work fine with the dual sampler""" @@ -161,21 +198,136 @@ def test_dual_feature_sets(features, 
lr_only_features, hr_exo_features): lr_containers = [ DummyData( data_shape=(10, 10, 20), - features=[f.lower() for f in features], + features=[f.lower() for f in lr_features], + ), + DummyData( + data_shape=(12, 12, 15), + features=[f.lower() for f in lr_features], + ), + ] + hr_containers = [ + DummyData( + data_shape=(20, 20, 40), + features=[f.lower() for f in lr_features], + ), + DummyData( + data_shape=(24, 24, 30), + features=[f.lower() for f in lr_features], + ), + ] + sampler_pairs = [ + DualSampler( + Sup3rDataset(low_res=lr.data, high_res=hr.data), + hr_sample_shape, + s_enhance=2, + t_enhance=2, + feature_sets={ + 'lr_features': lr_features, + 'hr_exo_features': hr_exo_features, + 'hr_out_features': hr_out_features, + }, + ) + for lr, hr in zip(lr_containers, hr_containers) + ] + + for pair in sampler_pairs: + _ = pair.lr_features + _ = pair.hr_out_features + _ = pair.hr_exo_features + + +@pytest.mark.parametrize( + ['lr_features', 'hr_exo_features', 'hr_out_features'], + [ + ( + [ + 'u_10m', + 'u_100m', + 'u_200m', + 'u_80m', + 'u_120m', + 'u_140m', + 'pressure', + 'kx', + 'dewpoint_temperature', + 'topography', + ], + ['topography', 'u_100m_obs', 'v_100m_obs'], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], + ), + ( + [ + 'u_10m', + 'u_100m', + 'u_200m', + 'u_80m', + 'u_120m', + 'u_140m', + 'pressure', + 'kx', + 'dewpoint_temperature', + 'topography', + 'srl', + ], + ['topography', 'srl', 'u_100m_obs', 'v_100m_obs'], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], + ), + ( + [ + 'u_10m', + 'u_100m', + 'u_200m', + 'u_80m', + 'u_120m', + 'u_140m', + 'pressure', + 'kx', + 'dewpoint_temperature', + ], + [], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], + ), + ( + [ + 'u_10m', + 'u_100m', + 'u_200m', + 'u_80m', + 'u_120m', + 'u_140m', + 'topography', + 'srl', + ], + ['topography', 'srl', 'u_100m_obs', 'v_100m_obs'], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], + ), + ], +) +def 
test_dual_feature_sets_with_obs( + lr_features, hr_exo_features, hr_out_features +): + """Each of these feature combinations should work fine with the dual + sampler when obs features are included""" + + hr_sample_shape = (8, 8, 10) + lr_containers = [ + DummyData( + data_shape=(10, 10, 20), + features=[f.lower() for f in lr_features], ), DummyData( data_shape=(12, 12, 15), - features=[f.lower() for f in features], + features=[f.lower() for f in lr_features], ), ] hr_containers = [ DummyData( data_shape=(20, 20, 40), - features=[f.lower() for f in features], + features=[f.lower() for f in lr_features], ), DummyData( data_shape=(24, 24, 30), - features=[f.lower() for f in features], + features=[f.lower() for f in lr_features], ), ] sampler_pairs = [ @@ -184,10 +336,12 @@ def test_dual_feature_sets(features, lr_only_features, hr_exo_features): hr_sample_shape, s_enhance=2, t_enhance=2, + proxy_obs_kwargs={'onshore_obs_frac': {'spatial': 0.1}}, feature_sets={ - 'features': features, - 'lr_only_features': lr_only_features, - 'hr_exo_features': hr_exo_features}, + 'lr_features': lr_features, + 'hr_exo_features': hr_exo_features, + 'hr_out_features': hr_out_features, + }, ) for lr, hr in zip(lr_containers, hr_containers) ] @@ -196,3 +350,4 @@ def test_dual_feature_sets(features, lr_only_features, hr_exo_features): _ = pair.lr_features _ = pair.hr_out_features _ = pair.hr_exo_features + _ = pair.obs_features diff --git a/tests/training/test_train_exo.py b/tests/training/test_train_exo.py index 02b283bf8a..86cb157663 100644 --- a/tests/training/test_train_exo.py +++ b/tests/training/test_train_exo.py @@ -19,24 +19,24 @@ @pytest.mark.parametrize( - ('CustomLayer', 'features', 'lr_only_features', 'mode'), + ('CustomLayer', 'lr_features', 'hr_out_features', 'mode'), [ - ('Sup3rAdder', FEATURES_W, ['temperature_100m'], 'lazy'), - ('Sup3rConcat', FEATURES_W, ['temperature_100m'], 'lazy'), - ('Sup3rAdder', FEATURES_W[1:], [], 'lazy'), - ('Sup3rConcat', FEATURES_W[1:], [], 
'lazy'), - ('Sup3rConcat', FEATURES_W[1:], [], 'eager'), + ('Sup3rAdder', FEATURES_W, FEATURES_W[1:-1], 'lazy'), + ('Sup3rConcat', FEATURES_W, FEATURES_W[1:-1], 'lazy'), + ('Sup3rAdder', FEATURES_W[1:], FEATURES_W[1:-1], 'lazy'), + ('Sup3rConcat', FEATURES_W[1:], FEATURES_W[1:-1], 'lazy'), + ('Sup3rConcat', FEATURES_W[1:], FEATURES_W[1:-1], 'eager'), ], ) def test_wind_hi_res_topo( - CustomLayer, features, lr_only_features, mode, gen_config_with_topo + CustomLayer, lr_features, hr_out_features, mode, gen_config_with_topo ): """Test a special wind model for non cc with the custom Sup3rAdder or Sup3rConcat layer that adds/concatenates hi-res topography in the middle of the network.""" kwargs = { 'file_paths': pytest.FP_WTK, - 'features': features, + 'features': lr_features, 'target': TARGET_W, 'shape': SHAPE, } @@ -54,7 +54,8 @@ def test_wind_hi_res_topo( t_enhance=1, sample_shape=(20, 20, 1), feature_sets={ - 'lr_only_features': lr_only_features, + 'lr_features': lr_features, + 'hr_out_features': hr_out_features, 'hr_exo_features': ['topography'], }, mode=mode, @@ -80,16 +81,18 @@ def test_wind_hi_res_topo( out_dir=os.path.join(td, 'test_{epoch}'), ) - assert model.lr_features == [f.lower() for f in features] - assert model.hr_out_features == ['u_100m', 'v_100m'] + assert model.lr_features == [f.lower() for f in lr_features] + assert model.hr_out_features == [f.lower() for f in hr_out_features] assert model.hr_exo_features == ['topography'] assert 'test_0' in os.listdir(td) - assert model.meta['hr_out_features'] == ['u_100m', 'v_100m'] + assert model.meta['hr_out_features'] == [ + f.lower() for f in hr_out_features + ] assert model.meta['class'] == 'Sup3rGan' assert 'topography' in batcher.hr_exo_features assert 'topography' not in model.hr_out_features - x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, len(features))) + x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, len(lr_features))) hi_res_topo = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 1)) with 
pytest.raises(RuntimeError): @@ -108,4 +111,4 @@ def test_wind_hi_res_topo( assert y.shape[0] == x.shape[0] assert y.shape[1] == x.shape[1] * 2 assert y.shape[2] == x.shape[2] * 2 - assert y.shape[3] == len(features) - len(lr_only_features) - 1 + assert y.shape[3] == len(hr_out_features) diff --git a/tests/training/test_train_exo_cc.py b/tests/training/test_train_exo_cc.py index 30225c8d4f..6d5622895e 100644 --- a/tests/training/test_train_exo_cc.py +++ b/tests/training/test_train_exo_cc.py @@ -19,16 +19,16 @@ @pytest.mark.parametrize( - ('CustomLayer', 'features', 'lr_only_features'), + ('CustomLayer', 'lr_features', 'hr_out_features'), [ - ('Sup3rAdder', FEATURES_W, ['temperature_100m']), - ('Sup3rConcat', FEATURES_W, ['temperature_100m']), - ('Sup3rAdder', FEATURES_W[1:], []), - ('Sup3rConcat', FEATURES_W[1:], []), + ('Sup3rAdder', FEATURES_W, FEATURES_W[1:-1]), + ('Sup3rConcat', FEATURES_W, FEATURES_W[1:-1]), + ('Sup3rAdder', FEATURES_W[1:], FEATURES_W[1:-1]), + ('Sup3rConcat', FEATURES_W[1:], FEATURES_W[1:-1]), ], ) def test_wind_hi_res_topo( - CustomLayer, features, lr_only_features, gen_config_with_topo + CustomLayer, lr_features, hr_out_features, gen_config_with_topo ): """Test a special wind cc model with the custom Sup3rAdder or Sup3rConcat layer that adds/concatenates hi-res topography in the middle of the @@ -36,7 +36,7 @@ def test_wind_hi_res_topo( handler = DataHandlerH5WindCC( pytest.FP_WTK, - features=features, + features=lr_features, target=TARGET_W, shape=SHAPE, time_slice=slice(None, None, 2), @@ -50,7 +50,8 @@ def test_wind_hi_res_topo( s_enhance=2, sample_shape=(20, 20), feature_sets={ - 'lr_only_features': lr_only_features, + 'lr_features': lr_features, + 'hr_out_features': hr_out_features, 'hr_exo_features': ['topography'], }, mode='eager', @@ -73,16 +74,16 @@ def test_wind_hi_res_topo( out_dir=os.path.join(td, 'test_{epoch}'), ) - assert model.lr_features == lowered(features) - assert model.hr_out_features == ['u_100m', 'v_100m'] + assert 
model.lr_features == lowered(lr_features) + assert model.hr_out_features == lowered(hr_out_features) assert model.hr_exo_features == ['topography'] assert 'test_0' in os.listdir(td) - assert model.meta['hr_out_features'] == ['u_100m', 'v_100m'] + assert model.meta['hr_out_features'] == lowered(hr_out_features) assert model.meta['class'] == 'Sup3rGan' assert 'topography' in batcher.hr_exo_features assert 'topography' not in model.hr_out_features - x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, len(features))) + x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, len(lr_features))) hi_res_topo = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 1)) with pytest.raises(RuntimeError): @@ -101,4 +102,4 @@ def test_wind_hi_res_topo( assert y.shape[0] == x.shape[0] assert y.shape[1] == x.shape[1] * 2 assert y.shape[2] == x.shape[2] * 2 - assert y.shape[3] == x.shape[3] - len(lr_only_features) - 1 + assert y.shape[3] == len(hr_out_features) diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index 5ed5a6bbf5..3bd3304c5e 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -252,7 +252,7 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): 'loss_func', [ {'SlicedWassersteinLoss': {}}, - {'GeothermalPhysicsLoss': {'input_features': ['u_100m']}}, + {'GeothermalPhysicsLoss': {'gen_features': ['u_100m']}}, ], ) def test_train_with_custom_loss(loss_func, n_epoch=8): diff --git a/tests/training/test_train_solar.py b/tests/training/test_train_solar.py index 83860808d3..27ddaff4dd 100644 --- a/tests/training/test_train_solar.py +++ b/tests/training/test_train_solar.py @@ -51,7 +51,10 @@ def test_solar_cc_model(hr_steps): s_enhance=1, t_enhance=8, sample_shape=(20, 20, hr_steps), - feature_sets={'lr_only_features': ['clearsky_ghi', 'ghi']}, + feature_sets={ + 'lr_features': FEATURES_S, + 'hr_out_features': ['clearsky_ratio'], + }, ) fp_gen = os.path.join(CONFIG_DIR, 'sup3rcc/gen_solar_1x_8x_1f.json') @@ 
-126,7 +129,10 @@ def test_solar_cc_model_spatial(): s_enhance=5, t_enhance=1, sample_shape=(20, 20), - feature_sets={'lr_only_features': ['clearsky_ghi', 'ghi']}, + feature_sets={ + 'lr_features': FEATURES_S, + 'hr_out_features': ['clearsky_ratio'], + }, ) fp_gen = os.path.join(CONFIG_DIR, 'sup3rcc/gen_solar_5x_1x_1f.json') @@ -178,7 +184,10 @@ def test_solar_custom_loss(): s_enhance=1, t_enhance=8, sample_shape=(5, 5, 24), - feature_sets={'lr_only_features': ['clearsky_ghi', 'ghi']}, + feature_sets={ + 'lr_features': FEATURES_S, + 'hr_out_features': ['clearsky_ratio'], + }, ) fp_gen = os.path.join(CONFIG_DIR, 'sup3rcc/gen_solar_1x_8x_1f.json') diff --git a/tests/training/test_train_with_obs.py b/tests/training/test_train_with_obs.py index f4d0628552..d8c421968a 100644 --- a/tests/training/test_train_with_obs.py +++ b/tests/training/test_train_with_obs.py @@ -8,7 +8,7 @@ import numpy as np import pytest -from sup3r.models import Sup3rGanWithObs +from sup3r.models import Sup3rGan from sup3r.preprocessing import ( BatchHandler, Container, @@ -59,22 +59,29 @@ def test_train_cond_obs(gen_config, sample_shape, t_enhance, fp_disc, request): s_enhance=2, t_enhance=t_enhance, sample_shape=sample_shape, + proxy_obs_kwargs={'onshore_obs_frac': {'spatial': 0.1}}, + feature_sets={ + 'lr_features': FEATURES_W, + 'hr_exo_features': [f'{feat}_obs' for feat in FEATURES_W], + 'hr_out_features': FEATURES_W, + }, ) - Sup3rGanWithObs.seed() + Sup3rGan.seed() - model = Sup3rGanWithObs( + model = Sup3rGan( gen_config, fp_disc, - onshore_obs_frac={'spatial': 0.1}, - loss_obs_weight=0.1, learning_rate=1e-4, + loss={ + 'GeothermalPhysicsLossWithObs': { + 'gen_features': FEATURES_W, + 'true_features': [f'{feat}_obs' for feat in FEATURES_W], + }, + 'GeothermalPhysicsLoss': {'gen_features': FEATURES_W}, + }, ) - model.meta['hr_out_features'] = ['u_10m', 'v_10m'] - test_mask = model._get_full_obs_mask(np.zeros((1, 20, 20, 1, 1))).numpy() - frac = 1 - test_mask.sum() / test_mask.size - assert 
np.abs(0.1 - frac) < test_mask.size / (2 * np.sqrt(test_mask.size)) - assert model.obs_features == ['u_10m_obs', 'v_10m_obs'] + model.meta['hr_out_features'] = FEATURES_W with tempfile.TemporaryDirectory() as td: model_kwargs = { 'input_resolution': {'spatial': '16km', 'temporal': '3600min'}, @@ -91,6 +98,8 @@ def test_train_cond_obs(gen_config, sample_shape, t_enhance, fp_disc, request): loaded = model.load(os.path.join(td, 'test_2')) loaded.train(batcher, **model_kwargs) + assert model.obs_features == [f'{feat}_obs' for feat in FEATURES_W] + if t_enhance == 1: x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, len(FEATURES_W))) u10m_obs = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 1)) @@ -185,7 +194,11 @@ def test_train_just_obs(gen_config, sample_shape, t_enhance, fp_disc, request): s_enhance=2, t_enhance=t_enhance, n_batches=2, - feature_sets={'lr_only_features': FEATURES_W}, + feature_sets={ + 'lr_features': FEATURES_W, + 'hr_exo_features': [f'{feat}_obs' for feat in FEATURES_W], + 'hr_out_features': [f'{feat}_obs' for feat in FEATURES_W], + }, mode='lazy', ) @@ -193,16 +206,15 @@ def test_train_just_obs(gen_config, sample_shape, t_enhance, fp_disc, request): assert not np.isnan(batch.high_res).all() assert np.isnan(batch.high_res).any() - Sup3rGanWithObs.seed() - model = Sup3rGanWithObs( + Sup3rGan.seed() + model = Sup3rGan( gen_config, fp_disc, - use_proxy_obs=False, learning_rate=1e-4, loss={ 'GeothermalPhysicsLossWithObs': { - 'input_features': [f'{feat}_obs' for feat in FEATURES_W], - 'obs_features': [f'{feat}_obs' for feat in FEATURES_W], + 'gen_features': [f'{feat}_obs' for feat in FEATURES_W], + 'true_features': [f'{feat}_obs' for feat in FEATURES_W], } }, ) diff --git a/tests/utilities/test_loss_metrics.py b/tests/utilities/test_loss_metrics.py index b32e35b441..ab4eef1420 100644 --- a/tests/utilities/test_loss_metrics.py +++ b/tests/utilities/test_loss_metrics.py @@ -271,7 +271,7 @@ def test_md_loss(): y = x.copy() md_loss = MaterialDerivativeLoss( - 
input_features=['u_100m', 'v_100m'] + gen_features=['u_100m', 'v_100m'] ) u_div = md_loss._compute_md(x, feature='u_100m') v_div = md_loss._compute_md(x, feature='v_100m') @@ -301,7 +301,7 @@ def test_multiterm_loss(): y = x.copy() md_loss = MaterialDerivativeLoss( - input_features=['u_100m', 'v_100m', 'temp_100m'] + gen_features=['u_100m', 'v_100m', 'temp_100m'] ) mae_loss = MeanAbsoluteError() fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') @@ -310,7 +310,7 @@ def test_multiterm_loss(): model.meta['hr_out_features'] = ['u_100m', 'v_100m', 'temp_100m'] multi_loss = model.get_loss_fun({ 'MaterialDerivativeLoss': { - 'input_features': ['u_100m', 'v_100m', 'temp_100m'] + 'gen_features': ['u_100m', 'v_100m', 'temp_100m'] }, 'MeanAbsoluteError': {}, 'term_weights': [0.2, 0.8], From bba5e9160ea282e8c6659d5c2687888c2fe92fa5 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 12 Mar 2026 16:15:24 -0600 Subject: [PATCH 05/26] `sparse_disc` arg added to Sup3rGan - True if discriminator can handle sparse data --- sup3r/models/abstract.py | 1 + sup3r/models/base.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 66d145c19d..9600b30688 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -49,6 +49,7 @@ def __init__(self): self._gen = None self._means = None self._stdevs = None + self._sparse_disc = None self._train_record = pd.DataFrame() self._val_record = pd.DataFrame() diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 108c8e49bf..bc7fcd1ecf 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -38,6 +38,7 @@ def __init__( stdevs=None, default_device=None, name=None, + sparse_disc=False, ): """ Parameters @@ -99,6 +100,11 @@ def __init__( "/gpu:0" or "/cpu:0" name : str | None Optional name for the GAN. + sparse_disc : bool + Whether the discriminator can accept sparse features as input. 
+ If False, the discriminator will only receive the dense features + as input. If True, the discriminator will receive both dense and + sparse features as input. """ super().__init__() @@ -130,6 +136,7 @@ def __init__( self._means = means self._stdevs = stdevs + self._sparse_disc = sparse_disc def save(self, out_dir): """Save the GAN with its sub-networks to a directory. @@ -299,12 +306,9 @@ def _tf_discriminate(self, hi_res): Discriminator output logits """ - # TODO: We currently assume the discriminator is convolutional so we - # remove the sparse obs data. Change this once we support - # non-convolutional discriminators hr = ( hi_res - if len(self.obs_features) == 0 + if len(self.obs_features) == 0 or self._sparse_disc else hi_res[..., : -len(self.obs_features)] ) if hr.shape[-1] == 0: From 129b99b00543b28cd142d3ea3631e0e9be796a2a Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 13 Mar 2026 07:10:01 -0600 Subject: [PATCH 06/26] test fixes --- sup3r/models/abstract.py | 18 ++++++++ sup3r/models/conditional.py | 17 ++++---- sup3r/models/interface.py | 2 - sup3r/preprocessing/samplers/base.py | 27 +++++++++++- tests/forward_pass/test_forward_pass_exo.py | 7 ++++ tests/rasterizers/test_dual.py | 14 +++---- tests/samplers/test_feature_sets.py | 43 ++++++++------------ tests/training/test_train_conditional_exo.py | 5 ++- tests/training/test_train_exo_dc.py | 12 ++++-- 9 files changed, 95 insertions(+), 50 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 9600b30688..3fd7930f4a 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -198,6 +198,24 @@ def set_norm_stats(self, new_means, new_stdevs): pprint.pformat(self._stdevs, indent=2), ) + @staticmethod + def check_batch_handler_attrs(batch_handler): + """Not all batch handlers have the following attributes. 
So we perform + some sanitation before sending to `set_model_params`""" + return { + k: getattr(batch_handler, k, None) + for k in [ + 'smoothing', + 'lr_features', + 'hr_exo_features', + 'hr_out_features', + 'hr_features', + 'obs_features', + 'smoothed_features', + ] + if hasattr(batch_handler, k) + } + def norm_input(self, low_res): """Normalize low resolution data being input to the generator. diff --git a/sup3r/models/conditional.py b/sup3r/models/conditional.py index a37cd8ff9c..ff22409181 100644 --- a/sup3r/models/conditional.py +++ b/sup3r/models/conditional.py @@ -423,15 +423,14 @@ def train( self._init_tensorboard_writer(out_dir) self.set_norm_stats(batch_handler.means, batch_handler.stds) - self.set_model_params( - input_resolution=input_resolution, - s_enhance=batch_handler.s_enhance, - t_enhance=batch_handler.t_enhance, - smoothing=batch_handler.smoothing, - lr_features=batch_handler.lr_features, - hr_exo_features=batch_handler.hr_exo_features, - hr_out_features=batch_handler.hr_out_features, - smoothed_features=batch_handler.smoothed_features, + params = self.check_batch_handler_attrs(batch_handler) + lower_models = getattr(batch_handler, 'lower_models', {}) + for model in [self, *lower_models.values()]: + model.set_model_params( + input_resolution=input_resolution, + s_enhance=batch_handler.s_enhance, + t_enhance=batch_handler.t_enhance, + **params, ) epochs = list(range(n_epoch)) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index 1aae2b0995..eaca179c19 100644 --- a/sup3r/models/interface.py +++ b/sup3r/models/interface.py @@ -461,8 +461,6 @@ def set_model_params(self, **kwargs): 'smoothing', ) keys = [k for k in keys if k in kwargs] - if 'hr_out_features' in kwargs: - self.meta['hr_out_features'] = kwargs['hr_out_features'] for var in keys: val = self.meta.get(var, None) diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index e6b02802f1..7e526d0b42 100644 --- 
a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -237,7 +237,9 @@ def check_feature_consistency(self): ) raise ValueError(msg) - if len(self.obs_features) > 0: + if len(self.obs_features) > 0 and any( + f in self.hr_exo_features for f in self.obs_features + ): msg = ( f'Obs features {self.obs_features} must come at the end of ' f'the hr_exo_features {self.hr_exo_features}' @@ -255,6 +257,29 @@ def check_feature_consistency(self): self.hr_features[-len(self.hr_exo_features) :] ), msg + assert all(f in self.data.features for f in self.lr_features), ( + f'All lr_features {self.lr_features} must be in the data features ' + f'{self.data.features}.' + ) + assert all(f in self.data.features for f in self.hr_out_features), ( + f'All hr_out_features {self.hr_out_features} must be in the data ' + f'features {self.data.features}.' + ) + if not self.use_proxy_obs: + assert all( + f in self.data.features for f in self.hr_exo_features + ), ( + f'All hr_exo_features {self.hr_exo_features} must be in the ' + f'data features {self.data.features} when not using proxy ' + 'observations.' + ) + else: + feats = set(self.hr_exo_features) - set(self.obs_features) + assert all(f in self.data.features for f in feats), ( + f'All non-obs hr_exo_features {feats} must be in the data ' + f'features {self.data.features} when using proxy observations.' 
+ ) + @property def sample_shape(self) -> tuple: """Shape of the data sample to select when ``__next__()`` is called.""" diff --git a/tests/forward_pass/test_forward_pass_exo.py b/tests/forward_pass/test_forward_pass_exo.py index 9ccde0e106..212b9882c5 100644 --- a/tests/forward_pass/test_forward_pass_exo.py +++ b/tests/forward_pass/test_forward_pass_exo.py @@ -479,6 +479,7 @@ def test_fwp_single_step_wind_hi_res_topo(input_files, plot=False): model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] model.meta['hr_out_features'] = ['u_100m', 'v_100m'] + model.meta['hr_exo_features'] = ['topography'] model.meta['s_enhance'] = 2 model.meta['t_enhance'] = 2 model.meta['input_resolution'] = {'spatial': '8km', 'temporal': '60min'} @@ -561,6 +562,7 @@ def test_fwp_multi_step_wind_hi_res_topo(input_files, gen_config_with_topo): gen_config_with_topo('Sup3rConcat'), fp_disc, learning_rate=1e-4 ) s1_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] + s1_model.meta['hr_exo_features'] = ['topography'] s1_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] s1_model.meta['s_enhance'] = 2 s1_model.meta['t_enhance'] = 1 @@ -586,6 +588,7 @@ def test_fwp_multi_step_wind_hi_res_topo(input_files, gen_config_with_topo): gen_config_with_topo('Sup3rConcat'), fp_disc, learning_rate=1e-4 ) s2_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] + s2_model.meta['hr_exo_features'] = ['topography'] s2_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] s2_model.meta['s_enhance'] = 2 s2_model.meta['t_enhance'] = 1 @@ -663,6 +666,7 @@ def test_fwp_wind_hi_res_topo_plus_linear(input_files, gen_config_with_topo): gen_config_with_topo('Sup3rConcat'), fp_disc, learning_rate=1e-4 ) s_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] + s_model.meta['hr_exo_features'] = ['topography'] s_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] s_model.meta['s_enhance'] = 2 s_model.meta['t_enhance'] = 1 @@ 
-896,6 +900,7 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza( gen_config_with_topo('Sup3rConcat'), fp_disc, learning_rate=1e-4 ) s1_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography', 'sza'] + s1_model.meta['hr_exo_features'] = ['topography'] s1_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] s1_model.meta['s_enhance'] = 2 s1_model.meta['t_enhance'] = 1 @@ -929,6 +934,7 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza( gen_config_with_topo('Sup3rConcat'), fp_disc, learning_rate=1e-4 ) s2_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography', 'sza'] + s2_model.meta['hr_exo_features'] = ['topography'] s2_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] s2_model.meta['s_enhance'] = 2 s2_model.meta['t_enhance'] = 1 @@ -941,6 +947,7 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza( fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(gen_t_model, fp_disc, learning_rate=1e-4) st_model.meta['lr_features'] = ['u_100m', 'v_100m', 'sza'] + st_model.meta['hr_exo_features'] = ['sza'] st_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] st_model.meta['s_enhance'] = 3 st_model.meta['t_enhance'] = 2 diff --git a/tests/rasterizers/test_dual.py b/tests/rasterizers/test_dual.py index 725c2fdb05..6499096a2f 100644 --- a/tests/rasterizers/test_dual.py +++ b/tests/rasterizers/test_dual.py @@ -34,10 +34,10 @@ def test_dual_rasterizer_shapes(full_shape=(20, 20)): s_enhance=2, t_enhance=1, ) - assert pair_rasterizer.lr_data.shape == ( - pair_rasterizer.hr_data.shape[0] // 2, - pair_rasterizer.hr_data.shape[1] // 2, - *pair_rasterizer.hr_data.shape[2:], + assert pair_rasterizer.data.low_res.shape == ( + pair_rasterizer.data.high_res.shape[0] // 2, + pair_rasterizer.data.high_res.shape[1] // 2, + *pair_rasterizer.data.high_res.shape[2:], ) @@ -70,7 +70,7 @@ def test_dual_nan_fill(full_shape=(20, 20)): t_enhance=1, ) - assert not np.isnan(pair_rasterizer.lr_data.as_array()).any() + assert not 
np.isnan(pair_rasterizer.data.low_res.as_array()).any() def test_regrid_caching(full_shape=(20, 20)): @@ -110,9 +110,9 @@ def test_regrid_caching(full_shape=(20, 20)): assert np.array_equal( lr_container_new.data[FEATURES][...], - pair_rasterizer.lr_data[FEATURES][...], + pair_rasterizer.data.low_res[FEATURES][...], ) assert np.array_equal( hr_container_new.data[FEATURES][...], - pair_rasterizer.hr_data[FEATURES][...], + pair_rasterizer.data.high_res[FEATURES][...], ) diff --git a/tests/samplers/test_feature_sets.py b/tests/samplers/test_feature_sets.py index ade600c007..b829fee852 100644 --- a/tests/samplers/test_feature_sets.py +++ b/tests/samplers/test_feature_sets.py @@ -10,11 +10,8 @@ @pytest.mark.parametrize( ['features', 'lr_features', 'hr_exo_features', 'hr_out_features'], [ - (['V_100m'], ['V_100m'], [], []), + (['V_100m'], ['V_100m'], [], ['U_100m']), (['U_100m'], ['V_100m'], ['V_100m'], []), - (['U_100m'], [], ['U_100m'], []), - (['U_100m', 'V_100m'], [], ['U_100m'], []), - (['U_100m', 'V_100m'], [], ['V_100m', 'U_100m'], []), ], ) def test_feature_errors( @@ -22,42 +19,34 @@ def test_feature_errors( ): """Each of these feature combinations should raise an error due to no features left in hr output or bad ordering""" - sampler = Sampler( - DummyData(data_shape=(20, 20, 10), features=features), - sample_shape=(5, 5, 4), - feature_sets={ - 'lr_features': lr_features, - 'hr_exo_features': hr_exo_features, - 'hr_out_features': hr_out_features, - }, - ) - with pytest.raises((RuntimeError, AssertionError)): - _ = sampler.lr_features - _ = sampler.hr_out_features - _ = sampler.hr_exo_features - _ = sampler.obs_features + _ = Sampler( + DummyData(data_shape=(20, 20, 10), features=features), + sample_shape=(5, 5, 4), + feature_sets={ + 'lr_features': lr_features, + 'hr_exo_features': hr_exo_features, + 'hr_out_features': hr_out_features, + }, + ) @pytest.mark.parametrize( - ['features', 'lr_features', 'hr_exo_features', 'hr_out_features'], + ['lr_features', 
'hr_exo_features', 'hr_out_features'], [ - (['V_100m', 'topography'], [], ['topography'], ['V_100m_obs']), + (['V_100m', 'topography'], ['topography'], ['V_100m_obs']), ( - ['V_100m', 'V_100m_obs', 'topography'], - [], - ['topography'], + ['V_100m', 'topography'], + ['topography', 'V_100m_obs'], ['V_100m_obs'], ), ], ) -def test_sampler_feature_sets( - features, lr_features, hr_exo_features, hr_out_features -): +def test_sampler_feature_sets(lr_features, hr_exo_features, hr_out_features): """Each of these feature combinations should pass without raising an error.""" sampler = Sampler( - DummyData(data_shape=(20, 20, 10), features=features), + DummyData(data_shape=(20, 20, 10), features=lr_features), sample_shape=(5, 5, 4), feature_sets={ 'lr_features': lr_features, diff --git a/tests/training/test_train_conditional_exo.py b/tests/training/test_train_conditional_exo.py index 60e61d96bd..84b9697723 100644 --- a/tests/training/test_train_conditional_exo.py +++ b/tests/training/test_train_conditional_exo.py @@ -170,7 +170,10 @@ def test_wind_non_cc_hi_res_st_topo_mom2( lower_models={1: model_mom1}, n_batches=n_batches, sample_shape=(12, 12, 24), - feature_sets={'hr_exo_features': ['topography']}, + feature_sets={ + 'lr_features': ['u_100m', 'v_100m', 'topography'], + 'hr_out_features': ['u_100m', 'v_100m'], + 'hr_exo_features': ['topography']}, mode='eager' ) diff --git a/tests/training/test_train_exo_dc.py b/tests/training/test_train_exo_dc.py index 34ccfca83b..fea1a192dc 100644 --- a/tests/training/test_train_exo_dc.py +++ b/tests/training/test_train_exo_dc.py @@ -43,8 +43,11 @@ def test_wind_dc_hi_res_topo(CustomLayer): n_batches=1, s_enhance=2, sample_shape=(20, 20, 8), - feature_sets={'hr_exo_features': ['topography']}, - ) + feature_sets={ + 'lr_features': ['u_100m', 'v_100m', 'topography'], + 'hr_out_features': ['u_100m', 'v_100m'], + 'hr_exo_features': ['topography']}, + ) batcher = BatchHandlerTesterDC( train_containers=[handler], @@ -55,7 +58,10 @@ def 
test_wind_dc_hi_res_topo(CustomLayer): n_batches=1, s_enhance=2, sample_shape=(10, 10, 8), - feature_sets={'hr_exo_features': ['topography']}, + feature_sets={ + 'lr_features': ['u_100m', 'v_100m', 'topography'], + 'hr_out_features': ['u_100m', 'v_100m'], + 'hr_exo_features': ['topography']}, ) gen_model = [ From 8966ee6f06d2ed088bb6a78add5b340d905ab630 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 13 Mar 2026 09:02:13 -0600 Subject: [PATCH 07/26] test fixes --- sup3r/models/abstract.py | 31 ---------------------------- sup3r/preprocessing/samplers/base.py | 12 +++++++---- tests/samplers/test_feature_sets.py | 3 ++- tests/training/test_train_gan.py | 4 ++++ 4 files changed, 14 insertions(+), 36 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 3fd7930f4a..145d8724b0 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -456,37 +456,6 @@ def get_hr_exo_input(self, hi_res): exo = dict(zip(self.hr_exo_features, tf.unstack(exo, axis=-1))) return exo - def _extract_obs(self, hi_res_true): - """Extract observation features from the end of hi_res_true. - Observation features are appended after hr_out + hr_exo features - by the Sampler when ``use_proxy_obs=True``, or included directly - when real obs are in the data. - - Parameters - ---------- - hi_res_true : tf.Tensor - Ground truth high resolution data, possibly with obs features - appended at the end. - - Returns - ------- - obs_data : dict - Dictionary of observation feature data. Keys are obs feature - names, values are tensors. Also includes ``'mask'`` key with - boolean mask (True where observations are NaN / missing). - Empty dict if no obs features are present. 
- """ - if len(self.obs_features) == 0: - return {} - obs = hi_res_true[..., -len(self.obs_features) :] - obs_mask = tf.math.is_nan(obs) - obs_expanded = tf.expand_dims(obs, axis=-2) - obs_data = dict( - zip(self.obs_features, tf.unstack(obs_expanded, axis=-1)) - ) - obs_data['mask'] = obs_mask - return obs_data - def _combine_loss_input(self, hi_res_true, hi_res_gen): """Combine exogenous feature data from hi_res_true with hi_res_gen for loss calculation diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index 7e526d0b42..5353a4ae8e 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -257,17 +257,21 @@ def check_feature_consistency(self): self.hr_features[-len(self.hr_exo_features) :] ), msg - assert all(f in self.data.features for f in self.lr_features), ( + assert all( + f in lowered(self.data.features) for f in self.lr_features + ), ( f'All lr_features {self.lr_features} must be in the data features ' f'{self.data.features}.' ) - assert all(f in self.data.features for f in self.hr_out_features), ( + assert all( + f in lowered(self.data.features) for f in self.hr_out_features + ), ( f'All hr_out_features {self.hr_out_features} must be in the data ' f'features {self.data.features}.' ) if not self.use_proxy_obs: assert all( - f in self.data.features for f in self.hr_exo_features + f in lowered(self.data.features) for f in self.hr_exo_features ), ( f'All hr_exo_features {self.hr_exo_features} must be in the ' f'data features {self.data.features} when not using proxy ' @@ -275,7 +279,7 @@ def check_feature_consistency(self): ) else: feats = set(self.hr_exo_features) - set(self.obs_features) - assert all(f in self.data.features for f in feats), ( + assert all(f in lowered(self.data.features) for f in feats), ( f'All non-obs hr_exo_features {feats} must be in the data ' f'features {self.data.features} when using proxy observations.' 
) diff --git a/tests/samplers/test_feature_sets.py b/tests/samplers/test_feature_sets.py index b829fee852..8eafad85d4 100644 --- a/tests/samplers/test_feature_sets.py +++ b/tests/samplers/test_feature_sets.py @@ -45,8 +45,9 @@ def test_feature_errors( def test_sampler_feature_sets(lr_features, hr_exo_features, hr_out_features): """Each of these feature combinations should pass without raising an error.""" + feats = set(lr_features) | set(hr_exo_features) | set(hr_out_features) sampler = Sampler( - DummyData(data_shape=(20, 20, 10), features=lr_features), + DummyData(data_shape=(20, 20, 10), features=feats), sample_shape=(5, 5, 4), feature_sets={ 'lr_features': lr_features, diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index 3bd3304c5e..3bd154132b 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -210,6 +210,10 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): learning_rate=lr, loss={'MeanAbsoluteError': {}, 'MeanSquaredError': {}}, ) + dummy.meta['lr_features'] = model.meta['lr_features'] + dummy.meta['hr_features'] = model.meta['hr_features'] + dummy.meta['hr_exo_features'] = model.meta['hr_exo_features'] + dummy.meta['hr_out_features'] = model.meta['hr_out_features'] for batch in batch_handler: out_og = model._tf_generate(batch.low_res) From 65a681dbba35deeb13264ad2c7f6cd9b899edeb9 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 13 Mar 2026 11:54:08 -0600 Subject: [PATCH 08/26] test fixes --- sup3r/models/interface.py | 20 +++++++++++++++++--- sup3r/preprocessing/collections/base.py | 18 +++++++++++++++++- sup3r/preprocessing/collections/stats.py | 14 +++++++------- sup3r/preprocessing/samplers/base.py | 5 ++--- tests/utilities/test_loss_metrics.py | 1 + 5 files changed, 44 insertions(+), 14 deletions(-) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index eaca179c19..990b09811a 100644 --- a/sup3r/models/interface.py +++ 
b/sup3r/models/interface.py @@ -387,13 +387,23 @@ def hr_out_features(self): def obs_features(self): """Get list of exogenous observation feature names the model uses. These come from the names of the ``Sup3rObs..`` layers.""" - return self.meta.get('obs_features', []) + default = [f for f in self.hr_features if '_obs' in f] + return self.meta.get('obs_features', default) @property def hr_exo_features(self): """Get list of gapless exogenous high-resolution feature names the model uses, like topography.""" - return self.meta.get('hr_exo_features', []) + check = self.get_layer_features() + out = self.meta.get('hr_exo_features', []) + if set(out) != set(check): + msg = ( + f'Model meta hr_exo_features {out} does not match features ' + f'{check} found in model layers.' + ) + logger.warning(msg) + warn(msg) + return out @property def hr_features(self): @@ -401,7 +411,11 @@ def hr_features(self): the high-resolution data during training. This includes both output and exogenous features. """ - return self.meta.get('hr_features', []) + default = [ + f for f in self.hr_out_features if f not in self.hr_exo_features + ] + default += self.hr_exo_features + return self.meta.get('hr_features', default) @property def smoothing(self): diff --git a/sup3r/preprocessing/collections/base.py b/sup3r/preprocessing/collections/base.py index f9949cb282..83524817c8 100644 --- a/sup3r/preprocessing/collections/base.py +++ b/sup3r/preprocessing/collections/base.py @@ -37,10 +37,13 @@ def __init__( self.data = tuple(c.data for c in containers) self.containers = containers self._features: list = [] + self._data_features: list = [] @property def features(self): - """Get all features contained in data.""" + """Get all features "available" in containers. 
This can include proxy + observations that are not actually in the data but created dynamically + during sampling.""" if not self._features: _ = [ self._features.append(f) @@ -49,6 +52,19 @@ def features(self): ] return self._features + @property + def data_features(self): + """Get all features contained in data.""" + if not self._data_features: + _ = [ + self._data_features.append(f) + for f in np.concatenate([ + c.data.features for c in self.containers + ]) + if f not in self._data_features + ] + return self._data_features + @property def container_weights(self): """Get weights used to sample from different containers based on diff --git a/sup3r/preprocessing/collections/stats.py b/sup3r/preprocessing/collections/stats.py index 1f49fdad81..958864d690 100644 --- a/sup3r/preprocessing/collections/stats.py +++ b/sup3r/preprocessing/collections/stats.py @@ -51,7 +51,7 @@ def __init__(self, containers, means=None, stds=None): def _get_stat(self, stat_type, needed_features='all'): """Get either mean or std for all features and all containers.""" all_feats = ( - self.features if needed_features == 'all' else needed_features + self.data_features if needed_features == 'all' else needed_features ) hr_feats = set(self.containers[0].high_res.features).intersection( all_feats @@ -98,12 +98,12 @@ def _init_stats_dict(self, stats): if ( isinstance(stats, dict) and stats != {} - and any(f not in stats for f in self.features) + and any(f not in stats for f in self.data_features) ): msg = ( - f'Not all features ({self.features}) are found in the given ' - f'stats dictionary {stats}. This is obviously from a prior ' - 'run so you better be sure these stats carry over.' + f'Not all features ({self.data_features}) are found in the ' + f'given stats dictionary {stats}. This is obviously from a ' + 'prior run so you better be sure these stats carry over.' 
) logger.warning(msg) warn(msg) @@ -113,7 +113,7 @@ def get_means(self, means): """Dictionary of means for each feature, computed across all data handlers.""" means = self._init_stats_dict(means) - needed_features = set(self.features) - set(means) + needed_features = set(self.data_features) - set(means) if any(needed_features): logger.info(f'Getting means for {needed_features}.') cmeans = [ @@ -132,7 +132,7 @@ def get_stds(self, stds): """Dictionary of standard deviations for each feature, computed across all data handlers.""" stds = self._init_stats_dict(stds) - needed_features = set(self.features) - set(stds) + needed_features = set(self.data_features) - set(stds) if any(needed_features): logger.info(f'Getting stds for {needed_features}.') cstds = [ diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index 5353a4ae8e..859d941334 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -597,9 +597,8 @@ def lr_features_ind(self): @property def features(self): """Get the full set of features that should be included for training. - This is the union of lr_features, hr_out_features, hr_exo_features, and - obs_features. This is the set of features that will be sampled from the - data.""" + This is the union of lr_features, hr_out_features and hr_exo_features. 
+ """ feats = self.lr_features feats += [f for f in self.hr_out_features if f not in feats] feats += [f for f in self.hr_exo_features if f not in feats] diff --git a/tests/utilities/test_loss_metrics.py b/tests/utilities/test_loss_metrics.py index ab4eef1420..c4ba3f96ab 100644 --- a/tests/utilities/test_loss_metrics.py +++ b/tests/utilities/test_loss_metrics.py @@ -308,6 +308,7 @@ def test_multiterm_loss(): fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) model.meta['hr_out_features'] = ['u_100m', 'v_100m', 'temp_100m'] + model.meta['hr_features'] = ['u_100m', 'v_100m', 'temp_100m'] multi_loss = model.get_loss_fun({ 'MaterialDerivativeLoss': { 'gen_features': ['u_100m', 'v_100m', 'temp_100m'] From 0e87c2547b464d38171401e94170f16bf9f2938c Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 14 Mar 2026 11:01:59 -0600 Subject: [PATCH 09/26] refact: remove unused batch handler attribute checks and improve feature consistency validation refact: vectorized obs mask creation --- sup3r/models/abstract.py | 20 +----- sup3r/models/base.py | 24 +------- sup3r/models/conditional.py | 10 +-- sup3r/models/interface.py | 87 ++++++++++++++------------ sup3r/preprocessing/samplers/base.py | 74 ++++++++-------------- tests/conftest.py | 88 +++++++++++++++++++++++++++ tests/training/test_train_with_obs.py | 62 +++++++++++++++++++ 7 files changed, 227 insertions(+), 138 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 145d8724b0..f7092fb53b 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -198,24 +198,6 @@ def set_norm_stats(self, new_means, new_stdevs): pprint.pformat(self._stdevs, indent=2), ) - @staticmethod - def check_batch_handler_attrs(batch_handler): - """Not all batch handlers have the following attributes. 
So we perform - some sanitation before sending to `set_model_params`""" - return { - k: getattr(batch_handler, k, None) - for k in [ - 'smoothing', - 'lr_features', - 'hr_exo_features', - 'hr_out_features', - 'hr_features', - 'obs_features', - 'smoothed_features', - ] - if hasattr(batch_handler, k) - } - def norm_input(self, low_res): """Normalize low resolution data being input to the generator. @@ -1190,7 +1172,7 @@ def _run_exo_layer(cls, layer, input_array, hi_res_exo): return layer(input_array, hr_exo, extras) return layer(input_array, hr_exo) - @tf.function + #@tf.function def _tf_generate(self, low_res, hi_res_exo=None): """Use the generator model to generate high res data from low res input diff --git a/sup3r/models/base.py b/sup3r/models/base.py index bc7fcd1ecf..9de1e046b2 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -615,24 +615,6 @@ def update_adversarial_weights( return weight_gen_advers - @staticmethod - def check_batch_handler_attrs(batch_handler): - """Not all batch handlers have the following attributes. 
So we perform - some sanitation before sending to `set_model_params`""" - return { - k: getattr(batch_handler, k, None) - for k in [ - 'smoothing', - 'lr_features', - 'hr_exo_features', - 'hr_out_features', - 'hr_features', - 'obs_features', - 'smoothed_features', - ] - if hasattr(batch_handler, k) - } - def train( self, batch_handler, @@ -736,12 +718,8 @@ def train( self._init_tensorboard_writer(out_dir) self.set_norm_stats(batch_handler.means, batch_handler.stds) - params = self.check_batch_handler_attrs(batch_handler) self.set_model_params( - input_resolution=input_resolution, - s_enhance=batch_handler.s_enhance, - t_enhance=batch_handler.t_enhance, - **params, + input_resolution=input_resolution, batch_handler=batch_handler ) epochs = list(range(n_epoch)) diff --git a/sup3r/models/conditional.py b/sup3r/models/conditional.py index ff22409181..c6d1f4e7ff 100644 --- a/sup3r/models/conditional.py +++ b/sup3r/models/conditional.py @@ -423,15 +423,11 @@ def train( self._init_tensorboard_writer(out_dir) self.set_norm_stats(batch_handler.means, batch_handler.stds) - params = self.check_batch_handler_attrs(batch_handler) lower_models = getattr(batch_handler, 'lower_models', {}) for model in [self, *lower_models.values()]: model.set_model_params( - input_resolution=input_resolution, - s_enhance=batch_handler.s_enhance, - t_enhance=batch_handler.t_enhance, - **params, - ) + input_resolution=input_resolution, batch_handler=batch_handler + ) epochs = list(range(n_epoch)) @@ -443,7 +439,7 @@ def train( t0 = time.time() logger.info( - 'Training model ' 'for {} epochs starting at epoch {}'.format( + 'Training model for {} epochs starting at epoch {}'.format( n_epoch, epochs[0] ) ) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index 990b09811a..3d3d6c06d8 100644 --- a/sup3r/models/interface.py +++ b/sup3r/models/interface.py @@ -233,11 +233,28 @@ def _ensure_valid_enhancement_factors(self): f'conflict with user provided values (s_enhance={s_enhance}, ' 
f't_enhance={t_enhance})' ) - check = layer_se == s_enhance or layer_te == t_enhance + check = layer_se == s_enhance and layer_te == t_enhance if not check: logger.error(msg) raise RuntimeError(msg) + def _ensure_feature_consistency(self): + """Ensure that the exogenous features specified in meta are consistent + with the features found in the model layers""" + features = [] + if hasattr(self, '_gen'): + for layer in self._gen.layers: + if isinstance(layer, SUP3R_LAYERS): + feats = getattr(layer, 'features', [layer.name]) + features.extend(feats) + if set(self.hr_exo_features) != set(features): + msg = ( + f'Model meta hr_exo_features {self.hr_exo_features} does not ' + f'match features {features} found in model layers.' + ) + logger.error(msg) + raise RuntimeError(msg) + @property def output_resolution(self): """Resolution of output data. Given as a dictionary @@ -352,18 +369,6 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None): hi_res = np.concatenate((hi_res, exo_output), axis=-1) return hi_res - def get_layer_features(self): - """Get the features that are input to mid-network layers. These are - typically gapless high-resolution features like topography or sparse - observations""" - features = [] - if hasattr(self, '_gen'): - for layer in self._gen.layers: - if isinstance(layer, SUP3R_LAYERS): - feats = getattr(layer, 'features', [layer.name]) - features.extend(feats) - return features - @property @abstractmethod def meta(self): @@ -394,16 +399,7 @@ def obs_features(self): def hr_exo_features(self): """Get list of gapless exogenous high-resolution feature names the model uses, like topography.""" - check = self.get_layer_features() - out = self.meta.get('hr_exo_features', []) - if set(out) != set(check): - msg = ( - f'Model meta hr_exo_features {out} does not match features ' - f'{check} found in model layers.' 
- ) - logger.warning(msg) - warn(msg) - return out + return self.meta.get('hr_exo_features', []) @property def hr_features(self): @@ -450,20 +446,25 @@ def version_record(self): """ return VERSION_RECORD - def set_model_params(self, **kwargs): + def set_model_params(self, batch_handler=None, **kwargs): """Set parameters used for training the model Parameters ---------- + batch_handler : object | None + Object that contains attributes used to set model meta parameters. + This is used during training to set parameters that are needed for + generation and also to check for consistency with previously set + parameters when loading a model from disk. kwargs : dict - Keyword arguments including 'input_resolution', - 'lr_features', 'hr_exo_features', 'hr_out_features', 'hr_features', - 'obs_features', 'smoothed_features', 's_enhance', 't_enhance', - 'smoothing' + Keyword arguments that are used to set model meta parameters. This + is used during training to set parameters that are needed for + generation and also to check for consistency with previously set + parameters when loading a model from disk. 
""" - keys = ( - 'input_resolution', + bh_keys = [ + 'smoothing', 'lr_features', 'hr_exo_features', 'hr_out_features', @@ -472,24 +473,32 @@ def set_model_params(self, **kwargs): 'smoothed_features', 's_enhance', 't_enhance', - 'smoothing', - ) - keys = [k for k in keys if k in kwargs] + ] - for var in keys: - val = self.meta.get(var, None) + meta_params = {} + if batch_handler is not None: + meta_params.update({ + k: getattr(batch_handler, k, None) + for k in bh_keys + if hasattr(batch_handler, k) + }) + + meta_params.update(kwargs) + for key, var in meta_params.items(): + val = self.meta.get(key, None) if val is None: - self.meta[var] = kwargs[var] - elif val != kwargs[var]: + self.meta[key] = var + elif val != var: msg = ( - 'Model was previously trained with {var}={} but ' - 'received new {var}={}'.format(val, kwargs[var], var=var) + 'Model was previously trained with {key}={} but ' + 'received new {key}={}'.format(val, var, key=key) ) logger.warning(msg) warn(msg) self._ensure_valid_enhancement_factors() self._ensure_valid_input_resolution() + self._ensure_feature_consistency() def save_params(self, out_dir): """ diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index 859d941334..8ed2bef224 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -425,16 +425,13 @@ def obs_features_ind(self): list[int] Indices into ``features`` for each obs feature source. 
""" - if len(self.obs_features) == 0: - return [] - - if self.use_proxy_obs: - return [ - self.hr_features.index(f.replace('_obs', '')) - for f in self.obs_features - ] - else: - return [self.hr_features.index(f) for f in self.obs_features] + check_feats = ( + [f.replace('_obs', '') for f in self.obs_features] + if self.use_proxy_obs + else self.obs_features + ) + + return [self.features.index(f) for f in check_feats] def _get_proxy_obs(self, hi_res): """Generate proxy observation data by masking the gridded high-res @@ -604,37 +601,6 @@ def features(self): feats += [f for f in self.hr_exo_features if f not in feats] return feats - def _get_single_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): - """Get observation mask for a given spatial and temporal obs - fraction for a single batch entry. - - Parameters - ---------- - hi_res : np.ndarray - True high resolution data for a single batch entry. - spatial_frac : float - Fraction of the spatial domain that should be treated as - observations. This is a value between 0 and 1. - time_frac : float, optional - Fraction of the temporal domain that should be treated as - observations. This is a value between 0 and 1. Default is 1.0 - - Returns - ------- - np.ndarray - Mask which is True for locations that are not observed and False - for locations that are observed. - (spatial_1, spatial_2, n_features) - (spatial_1, spatial_2, n_temporal, n_features) - """ - mask_shape = [*hi_res.shape[:-1], len(self.hr_out_features)] - s_mask = RANDOM_GENERATOR.uniform(size=mask_shape[1:3]) <= spatial_frac - s_mask = s_mask[..., None, None] - t_mask = RANDOM_GENERATOR.uniform(size=mask_shape[-2]) <= time_frac - t_mask = t_mask[None, None, ..., None] - mask = ~(s_mask & t_mask) - return np.repeat(mask, mask_shape[-1], axis=-1) - def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): """Get observation mask for a given spatial and temporal obs fraction for an entire batch. 
This is divided between spatial and @@ -674,18 +640,26 @@ def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): if isinstance(time_frac, (list, tuple)) else [time_frac, time_frac] ) - s_fracs = RANDOM_GENERATOR.uniform(*s_range, size=hi_res.shape[0]) - t_fracs = RANDOM_GENERATOR.uniform(*t_range, size=hi_res.shape[0]) + n_obs, n_spatial_1, n_spatial_2, n_temporal = hi_res.shape[:-1] + n_features = len(self.hr_out_features) + + s_fracs = RANDOM_GENERATOR.uniform(*s_range, size=n_obs) + t_fracs = RANDOM_GENERATOR.uniform(*t_range, size=n_obs) s_fracs = np.clip(s_fracs, 0, 1) t_fracs = np.clip(t_fracs, 0, 1) - mask = np.stack( - [ - self._get_single_obs_mask(hi_res, s, t) - for s, t in zip(s_fracs, t_fracs) - ], - axis=0, + + s_mask = RANDOM_GENERATOR.uniform( + size=(n_obs, n_spatial_1, n_spatial_2) ) - return mask + s_mask = s_mask <= s_fracs[:, None, None] + s_mask = s_mask[..., None, None] + + t_mask = RANDOM_GENERATOR.uniform(size=(n_obs, n_temporal)) + t_mask = t_mask <= t_fracs[:, None] + t_mask = t_mask[:, None, None, :, None] + + mask = ~(s_mask & t_mask) + return np.repeat(mask, n_features, axis=-1) def _get_full_obs_mask(self, hi_res): """Define observation mask for the current batch. 
This differs from diff --git a/tests/conftest.py b/tests/conftest.py index 71925443bb..52193d4716 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -236,6 +236,94 @@ def func(): return func +@pytest.fixture(scope='package') +def gen_config_with_obs_3d_topo(): + """Get generator config with observation layers and topo layer.""" + + def func(): + return [ + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 2, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + { + 'class': 'SpatioTemporalExpansion', + 'temporal_mult': 2 + }, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 64, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 64, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + { + 'class': 'SpatioTemporalExpansion', + 'spatial_mult': 2 + }, + {'class': 'Activation', 'activation': 'relu'}, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 2, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + {'class': 'Sup3rConcatObs', 'name': 'u_10m_obs'}, + {'class': 'Sup3rConcatObs', 'name': 'v_10m_obs'}, + {'class': 'Sup3rConcat', 'name': 'topography'}, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 2, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 
'Cropping3D', 'cropping': 2}, + ] + + return func + + @pytest.fixture(scope='package') def gen_config_with_topo(): """Get generator config with custom topo layer.""" diff --git a/tests/training/test_train_with_obs.py b/tests/training/test_train_with_obs.py index d8c421968a..f6cd8d359a 100644 --- a/tests/training/test_train_with_obs.py +++ b/tests/training/test_train_with_obs.py @@ -234,3 +234,65 @@ def test_train_just_obs(gen_config, sample_shape, t_enhance, fp_disc, request): tloss = model.history['train_geothermal_physics_loss_with_obs'].values assert np.sum(np.diff(tloss)) < 0 + + +def test_train_obs_with_topo(request): + """Test training with topo and obs. Make sure exo features are + properly concatenated.""" + + gen_config = 'gen_config_with_obs_3d_topo' + gen_config = request.getfixturevalue(gen_config)() + kwargs = { + 'file_paths': pytest.FP_WTK, + 'features': [*FEATURES_W, 'topography'], + 'target': TARGET_W, + 'shape': SHAPE, + } + + train_handler = DataHandler(**kwargs, time_slice=slice(None, 3000, 10)) + + val_handler = DataHandler(**kwargs, time_slice=slice(3000, None, 10)) + batcher = BatchHandler( + [train_handler], + [val_handler], + batch_size=2, + n_batches=1, + s_enhance=2, + t_enhance=2, + sample_shape=(20, 20, 10), + proxy_obs_kwargs={'onshore_obs_frac': {'spatial': 0.1}}, + feature_sets={ + 'lr_features': FEATURES_W, + 'hr_exo_features': [ + 'topography', + *[f'{feat}_obs' for feat in FEATURES_W], + ], + 'hr_out_features': FEATURES_W, + }, + ) + + Sup3rGan.seed() + + model = Sup3rGan( + gen_config, + pytest.ST_FP_DISC, + learning_rate=1e-4, + loss={ + 'GeothermalPhysicsLossWithObs': { + 'gen_features': FEATURES_W, + 'true_features': [f'{feat}_obs' for feat in FEATURES_W], + } + }, + ) + with tempfile.TemporaryDirectory() as td: + model_kwargs = { + 'input_resolution': {'spatial': '16km', 'temporal': '3600min'}, + 'n_epoch': 3, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': None, + 'out_dir': 
os.path.join(td, 'test_{epoch}'), + } + + model.train(batcher, **model_kwargs) From 6a5d458326a27cbe7468c746dc6c802acbf9af5f Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 14 Mar 2026 11:30:38 -0600 Subject: [PATCH 10/26] test: add tests for sampler behavior with proxy observations --- sup3r/models/abstract.py | 2 +- tests/samplers/test_with_obs.py | 152 ++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 tests/samplers/test_with_obs.py diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index f7092fb53b..1141512827 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -1172,7 +1172,7 @@ def _run_exo_layer(cls, layer, input_array, hi_res_exo): return layer(input_array, hr_exo, extras) return layer(input_array, hr_exo) - #@tf.function + @tf.function def _tf_generate(self, low_res, hi_res_exo=None): """Use the generator model to generate high res data from low res input diff --git a/tests/samplers/test_with_obs.py b/tests/samplers/test_with_obs.py new file mode 100644 index 0000000000..d2cf9a4d3c --- /dev/null +++ b/tests/samplers/test_with_obs.py @@ -0,0 +1,152 @@ +"""Test sampler behavior with proxy observations.""" + +import numpy as np +import pytest + +from sup3r.preprocessing import DualSampler, Sampler +from sup3r.preprocessing.base import Sup3rDataset +from sup3r.utilities.pytest.helpers import DummyData + +BASE_FEATURES = ['u_100m', 'v_100m'] +OBS_FEATURES = ['u_100m_obs', 'v_100m_obs'] + + +def _make_sampler( + sampler_cls, + hr_shape, + sample_shape, + batch_size, + proxy_obs_kwargs, + hr_features=None, +): + """Create either Sampler or DualSampler with proxy obs feature sets.""" + hr_features = hr_features or BASE_FEATURES + feature_sets = { + 'lr_features': BASE_FEATURES, + 'hr_out_features': BASE_FEATURES, + 'hr_exo_features': [ + *[f for f in hr_features if f == 'topography'], + *OBS_FEATURES, + ], + } + + if sampler_cls is Sampler: + data = DummyData(data_shape=hr_shape, 
features=hr_features) + return Sampler( + data, + sample_shape=sample_shape, + batch_size=batch_size, + proxy_obs_kwargs=proxy_obs_kwargs, + feature_sets=feature_sets, + ) + + lr_shape = (hr_shape[0] // 2, hr_shape[1] // 2, hr_shape[2]) + lr = DummyData(data_shape=lr_shape, features=BASE_FEATURES).data.high_res + hr = DummyData(data_shape=hr_shape, features=hr_features).data.high_res + data = Sup3rDataset(low_res=lr, high_res=hr) + return DualSampler( + data, + sample_shape=sample_shape, + batch_size=batch_size, + s_enhance=2, + t_enhance=1, + proxy_obs_kwargs=proxy_obs_kwargs, + feature_sets=feature_sets, + ) + + +def _get_hr_batch(sampler): + """Extract the high-res batch for Sampler and DualSampler outputs.""" + batch = next(sampler) + return batch[-1] if isinstance(batch, tuple) else batch + + +@pytest.mark.parametrize('sampler_cls', [Sampler, DualSampler]) +@pytest.mark.parametrize( + 'sample_shape, obs_fracs, expected', + [ + ((30, 30, 1), {'spatial': 0.4, 'time': 1.0}, 0.4), + ((30, 30, 12), {'spatial': 0.4, 'time': 0.5}, 0.2), + ], +) +def test_proxy_obs_appended_and_fraction( + sampler_cls, sample_shape, obs_fracs, expected +): + """Proxy obs channels are appended and sampled at configured fraction.""" + sampler = _make_sampler( + sampler_cls=sampler_cls, + hr_shape=(60, 60, 500), + sample_shape=sample_shape, + batch_size=20, + proxy_obs_kwargs={'onshore_obs_frac': obs_fracs}, + ) + + batch = _get_hr_batch(sampler) + obs = batch[..., -2:] + + assert batch.shape[-1] == 4 + assert obs.shape[-1] == 2 + + observed_frac = np.isfinite(obs[..., 0]).mean() + assert np.isclose(observed_frac, expected, atol=0.05) + + +@pytest.mark.parametrize('sampler_cls', [Sampler, DualSampler]) +def test_proxy_obs_fraction_bounds_with_ranges(sampler_cls): + """Observed fraction stays within expected range for sampled fractions.""" + s_range = [0.1, 0.3] + t_range = [0.2, 0.6] + sampler = _make_sampler( + sampler_cls=sampler_cls, + hr_shape=(80, 80, 500), + sample_shape=(40, 40, 
20), + batch_size=8, + proxy_obs_kwargs={ + 'onshore_obs_frac': {'spatial': s_range, 'time': t_range} + }, + ) + + batch = _get_hr_batch(sampler) + obs = batch[..., -2:] + observed_by_sample = np.isfinite(obs[..., 0]).mean(axis=(1, 2, 3)) + + lower = s_range[0] * t_range[0] + upper = s_range[1] * t_range[1] + assert np.all(observed_by_sample >= (lower - 0.02)) + assert np.all(observed_by_sample <= (upper + 0.02)) + + +@pytest.mark.parametrize('sampler_cls', [Sampler, DualSampler]) +def test_proxy_obs_onshore_offshore_topography_fractions(sampler_cls): + """Onshore and offshore obs fractions are applied by topography mask.""" + sampler = _make_sampler( + sampler_cls=sampler_cls, + hr_shape=(80, 80, 500), + sample_shape=(40, 40, 12), + batch_size=8, + proxy_obs_kwargs={ + 'onshore_obs_frac': {'spatial': 0.8, 'time': 1.0}, + 'offshore_obs_frac': {'spatial': 0.1, 'time': 1.0}, + }, + hr_features=[*BASE_FEATURES, 'topography'], + ) + + topo_var = sampler.data.high_res['topography'] + topo = np.ones(topo_var.shape, dtype=np.float32) + topo[:, : topo.shape[1] // 2, :] = -1.0 + sampler.data.high_res['topography'] = (topo_var.dims, topo) + + batch = _get_hr_batch(sampler) + topo_idx = sampler.hr_features.index('topography') + topo_sample = batch[..., topo_idx] + obs = batch[..., -2:] + + onshore = topo_sample > 0 + offshore = ~onshore + + onshore_frac = np.isfinite(obs[..., 0][onshore]).mean() + offshore_frac = np.isfinite(obs[..., 0][offshore]).mean() + + assert np.isclose(onshore_frac, 0.8, atol=0.12) + assert np.isclose(offshore_frac, 0.1, atol=0.08) + assert onshore_frac > offshore_frac From 03e74935d12e454d404bfd0dbbc60cc4bfe32def Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 14 Mar 2026 16:05:22 -0600 Subject: [PATCH 11/26] make sure single sampler and dual sampler index obs and hr features correctly. 
--- sup3r/preprocessing/batch_queues/base.py | 2 +- sup3r/preprocessing/collections/base.py | 18 +----- sup3r/preprocessing/collections/stats.py | 10 +-- sup3r/preprocessing/samplers/base.py | 80 +++++++++++------------- sup3r/preprocessing/samplers/dual.py | 37 ++++++++--- tests/samplers/test_with_obs.py | 16 ++--- 6 files changed, 81 insertions(+), 82 deletions(-) diff --git a/sup3r/preprocessing/batch_queues/base.py b/sup3r/preprocessing/batch_queues/base.py index ee22439666..cf972f886c 100644 --- a/sup3r/preprocessing/batch_queues/base.py +++ b/sup3r/preprocessing/batch_queues/base.py @@ -27,7 +27,7 @@ class SingleBatchQueue(AbstractBatchQueue): @property def queue_shape(self): """Shape of objects stored in the queue.""" - return [(self.batch_size, *self.hr_sample_shape, len(self.features))] + return [(self.batch_size, *self.hr_shape)] def transform( self, diff --git a/sup3r/preprocessing/collections/base.py b/sup3r/preprocessing/collections/base.py index 83524817c8..f9949cb282 100644 --- a/sup3r/preprocessing/collections/base.py +++ b/sup3r/preprocessing/collections/base.py @@ -37,13 +37,10 @@ def __init__( self.data = tuple(c.data for c in containers) self.containers = containers self._features: list = [] - self._data_features: list = [] @property def features(self): - """Get all features "available" in containers. 
This can include proxy - observations that are not actually in the data but created dynamically - during sampling.""" + """Get all features contained in data.""" if not self._features: _ = [ self._features.append(f) @@ -52,19 +49,6 @@ def features(self): ] return self._features - @property - def data_features(self): - """Get all features contained in data.""" - if not self._data_features: - _ = [ - self._data_features.append(f) - for f in np.concatenate([ - c.data.features for c in self.containers - ]) - if f not in self._data_features - ] - return self._data_features - @property def container_weights(self): """Get weights used to sample from different containers based on diff --git a/sup3r/preprocessing/collections/stats.py b/sup3r/preprocessing/collections/stats.py index 958864d690..1d51c3b659 100644 --- a/sup3r/preprocessing/collections/stats.py +++ b/sup3r/preprocessing/collections/stats.py @@ -51,7 +51,7 @@ def __init__(self, containers, means=None, stds=None): def _get_stat(self, stat_type, needed_features='all'): """Get either mean or std for all features and all containers.""" all_feats = ( - self.data_features if needed_features == 'all' else needed_features + self.features if needed_features == 'all' else needed_features ) hr_feats = set(self.containers[0].high_res.features).intersection( all_feats @@ -98,10 +98,10 @@ def _init_stats_dict(self, stats): if ( isinstance(stats, dict) and stats != {} - and any(f not in stats for f in self.data_features) + and any(f not in stats for f in self.features) ): msg = ( - f'Not all features ({self.data_features}) are found in the ' + f'Not all features ({self.features}) are found in the ' f'given stats dictionary {stats}. This is obviously from a ' 'prior run so you better be sure these stats carry over.' 
) @@ -113,7 +113,7 @@ def get_means(self, means): """Dictionary of means for each feature, computed across all data handlers.""" means = self._init_stats_dict(means) - needed_features = set(self.data_features) - set(means) + needed_features = set(self.features) - set(means) if any(needed_features): logger.info(f'Getting means for {needed_features}.') cmeans = [ @@ -132,7 +132,7 @@ def get_stds(self, stds): """Dictionary of standard deviations for each feature, computed across all data handlers.""" stds = self._init_stats_dict(stds) - needed_features = set(self.data_features) - set(stds) + needed_features = set(self.features) - set(stds) if any(needed_features): logger.info(f'Getting stds for {needed_features}.') cstds = [ diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index 8ed2bef224..82b12d19d5 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -126,7 +126,7 @@ def use_proxy_obs(self): check = bool(self.proxy_obs_kwargs) check = check or ( len(self.obs_features) > 0 - and all(f not in self.features for f in self.obs_features) + and all(f not in self.hr_features for f in self.obs_features) ) return check @@ -174,9 +174,9 @@ def get_sample_index(self, n_obs=None): self.shape, self.sample_shape[2] * n_obs ) feats = ( - self.features + self.hr_features if not self.use_proxy_obs - else self.features[: -len(self.obs_features)] + else self.hr_features[: -len(self.obs_features)] ) return (*spatial_slice, time_slice, feats) @@ -414,25 +414,6 @@ def _slow_batch(self): def _fast_batch_possible(self): return self.batch_size * self.sample_shape[2] <= self.data.shape[2] - @property - def obs_features_ind(self): - """Get the source feature indices in ``features`` for each obs - feature. Each obs feature named ``_obs`` maps to the - corresponding ```` in the features. - - Returns - ------- - list[int] - Indices into ``features`` for each obs feature source. 
- """ - check_feats = ( - [f.replace('_obs', '') for f in self.obs_features] - if self.use_proxy_obs - else self.obs_features - ) - - return [self.features.index(f) for f in check_feats] - def _get_proxy_obs(self, hi_res): """Generate proxy observation data by masking the gridded high-res data. Unobserved locations are set to NaN. @@ -520,7 +501,7 @@ def _parse_features(self, unparsed_feats): if any('*' in fn for fn in parsed_feats): out = [] - for feature in self.features: + for feature in self.hr_features: match = any( fnmatch(feature.lower(), pattern.lower()) for pattern in parsed_feats @@ -539,15 +520,17 @@ def lr_features(self): @property def hr_features(self): - """List of feature names or patt*erns that should be available as - either high-resolution model inputs (like topography or observations) - or as ground truth targets. If no entry is provided then all available - features from data will be used.""" - out = [ - f for f in self.hr_out_features if f not in self.hr_exo_features - ] - out += self.hr_exo_features - return out + """List of feature names or patt*erns that should be available in the + high-resolution data. For a non-dual sampler this is all features, + since even features only provided to the model as low-resolution still + need to be coarsened from the high-resolution data. This is in contrast + to dual samplers + (:class:`~sup3r.preprocessing.samplers.dual.DualSampler`), where there + are separate high-resolution and low-resolution data members.""" + feats = self.lr_features + feats += [f for f in self.hr_out_features if f not in feats] + feats += [f for f in self.hr_exo_features if f not in feats] + return feats @property def hr_out_features(self): @@ -578,28 +561,41 @@ def obs_features(self): @property def hr_features_ind(self): """Get the high-resolution feature channel indices that should be - included for training. This includes hr_out_features and + included for loss calculations. 
This includes hr_out_features and hr_exo_features, Any high-resolution features that are only included in the data handler to be coarsened for the low-res input are removed. """ - return [self.features.index(f) for f in self.hr_features] + hr_feats = [ + f for f in self.hr_out_features if f not in self.hr_exo_features + ] + hr_feats += self.hr_exo_features + return [self.hr_features.index(f) for f in hr_feats] @property def lr_features_ind(self): """Get the low-resolution feature channel indices that should be included for training. This includes lr_features. """ - return [self.features.index(f) for f in self.lr_features] + return [self.hr_features.index(f) for f in self.lr_features] @property - def features(self): - """Get the full set of features that should be included for training. - This is the union of lr_features, hr_out_features and hr_exo_features. + def obs_features_ind(self): + """Get the source feature indices in ``features`` for each obs + feature. Each obs feature named ``_obs`` maps to the + corresponding ```` in the features. + + Returns + ------- + list[int] + Indices into ``features`` for each obs feature source. """ - feats = self.lr_features - feats += [f for f in self.hr_out_features if f not in feats] - feats += [f for f in self.hr_exo_features if f not in feats] - return feats + check_feats = ( + [f.replace('_obs', '') for f in self.obs_features] + if self.use_proxy_obs + else self.obs_features + ) + + return [self.hr_features.index(f) for f in check_feats] def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): """Get observation mask for a given spatial and temporal obs diff --git a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index b81e2fff11..e540eaeacc 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -77,16 +77,16 @@ def __init__( observations. 
Keys can include ``onshore_obs_frac`` and ``offshore_obs_frac`` which specify the fraction of the batch that should be treated as onshore and offshore observations, - respectively. For example, ``proxy_obs_kwargs={'onshore_obs_frac': - {'spatial': 0.1, 'temporal': 0.2}, 'offshore_obs_frac': {'spatial': - 0.05, 'temporal': 0.1}}`` would specify that for the onshore - region observations cover 10% of the spatial domain and 20% of the - temporal domain, while for the offshore region observations cover - 5% of the spatial domain and 10% of the temporal domain. Instead of - a single float, these can also be lists to specify a lower and - upper bound for the spatial and temporal fractions, in which case - the actual fraction for each batch will be sampled uniformly - between these bounds. + respectively. For example, ``proxy_obs_kwargs={ 'onshore_obs_frac': + { 'spatial': 0.1, 'temporal': 0.2}, 'offshore_obs_frac': { + 'spatial': 0.05, 'temporal': 0.1} }`` would specify that for the + onshore region observations cover 10% of the spatial domain and 20% + of the temporal domain, while for the offshore region observations + cover 5% of the spatial domain and 10% of the temporal domain. + Instead of a single float, these can also be lists to specify a + lower and upper bound for the spatial and temporal fractions, in + which case the actual fraction for each batch will be sampled + uniformly between these bounds. mode : str Mode for sampling data. Options are 'lazy' or 'eager'. 'eager' mode pre-loads all data into memory as numpy arrays for faster access. @@ -152,6 +152,23 @@ def check_shape_consistency(self): ) assert self.data.high_res.shape[:-1] == enhanced_shape, msg + @property + def hr_features(self): + """List of feature names or patt*erns that should be available in the + high-resolution data. For dual samplers, this includes only the + features that are specifically designated as high-resolution outputs or + exogenous inputs. 
For a non-dual sampler + (:class:`~sup3r.preprocessing.samplers.base.Sampler`), this is all + features, since even features only provided to the model as + low-resolution still need to be coarsened from the high-resolution + data. + """ + out = [ + f for f in self.hr_out_features if f not in self.hr_exo_features + ] + out += self.hr_exo_features + return out + def get_sample_index(self, n_obs=None): """Get paired sample index, consisting of index for the low res sample and the index for the high res sample with the same spatiotemporal diff --git a/tests/samplers/test_with_obs.py b/tests/samplers/test_with_obs.py index d2cf9a4d3c..6c722fad7b 100644 --- a/tests/samplers/test_with_obs.py +++ b/tests/samplers/test_with_obs.py @@ -7,7 +7,8 @@ from sup3r.preprocessing.base import Sup3rDataset from sup3r.utilities.pytest.helpers import DummyData -BASE_FEATURES = ['u_100m', 'v_100m'] +LR_FEATURES = ['u_100m', 'v_100m', 'temperature_2m'] +HR_OUT_FEATURES = ['u_100m', 'v_100m'] OBS_FEATURES = ['u_100m_obs', 'v_100m_obs'] @@ -20,10 +21,10 @@ def _make_sampler( hr_features=None, ): """Create either Sampler or DualSampler with proxy obs feature sets.""" - hr_features = hr_features or BASE_FEATURES + hr_features = hr_features or LR_FEATURES feature_sets = { - 'lr_features': BASE_FEATURES, - 'hr_out_features': BASE_FEATURES, + 'lr_features': LR_FEATURES, + 'hr_out_features': HR_OUT_FEATURES, 'hr_exo_features': [ *[f for f in hr_features if f == 'topography'], *OBS_FEATURES, @@ -41,7 +42,7 @@ def _make_sampler( ) lr_shape = (hr_shape[0] // 2, hr_shape[1] // 2, hr_shape[2]) - lr = DummyData(data_shape=lr_shape, features=BASE_FEATURES).data.high_res + lr = DummyData(data_shape=lr_shape, features=LR_FEATURES).data.high_res hr = DummyData(data_shape=hr_shape, features=hr_features).data.high_res data = Sup3rDataset(low_res=lr, high_res=hr) return DualSampler( @@ -84,7 +85,8 @@ def test_proxy_obs_appended_and_fraction( batch = _get_hr_batch(sampler) obs = batch[..., -2:] - assert 
batch.shape[-1] == 4 + expected_channels = 5 if sampler_cls is Sampler else 4 + assert batch.shape[-1] == expected_channels assert obs.shape[-1] == 2 observed_frac = np.isfinite(obs[..., 0]).mean() @@ -128,7 +130,7 @@ def test_proxy_obs_onshore_offshore_topography_fractions(sampler_cls): 'onshore_obs_frac': {'spatial': 0.8, 'time': 1.0}, 'offshore_obs_frac': {'spatial': 0.1, 'time': 1.0}, }, - hr_features=[*BASE_FEATURES, 'topography'], + hr_features=[*LR_FEATURES, 'topography'], ) topo_var = sampler.data.high_res['topography'] From 473a95d10cb26b09a99dbb09ef60c9f04ce89c42 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 14 Mar 2026 16:22:37 -0600 Subject: [PATCH 12/26] fix: hr_feature ordering with exo features --- sup3r/preprocessing/samplers/base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index 82b12d19d5..5c91e5e01c 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -527,8 +527,12 @@ def hr_features(self): to dual samplers (:class:`~sup3r.preprocessing.samplers.dual.DualSampler`), where there are separate high-resolution and low-resolution data members.""" - feats = self.lr_features - feats += [f for f in self.hr_out_features if f not in feats] + feats = [f for f in self.lr_features if f not in self.hr_exo_features] + feats += [ + f + for f in self.hr_out_features + if f not in feats and f not in self.hr_exo_features + ] feats += [f for f in self.hr_exo_features if f not in feats] return feats From 19ca9414bf0a57ed453082089b6d3cdd3cd72438 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 15 Mar 2026 09:35:17 -0600 Subject: [PATCH 13/26] Refactor model parameter setting and update feature handling - Replaced direct assignment of model meta attributes with a new method `set_model_params` for better encapsulation and clarity. 
- Updated references from `hr_features` to `hr_source_features` in various classes to ensure consistency in feature handling. - Modified tests to utilize the new parameter setting method, improving readability and maintainability. - Ensured that all relevant tests are updated to reflect changes in feature management and model configuration. --- sup3r/models/interface.py | 77 ++-------------- sup3r/preprocessing/batch_queues/base.py | 19 ++-- sup3r/preprocessing/samplers/base.py | 56 +++++++----- sup3r/preprocessing/samplers/dual.py | 30 +++---- tests/bias/test_bias_correction.py | 11 ++- tests/bias/test_presrat_bias_correction.py | 11 ++- tests/bias/test_qdm_bias_correction.py | 62 +++++++------ tests/forward_pass/test_forward_pass.py | 42 +++++---- tests/forward_pass/test_forward_pass_exo.py | 64 ++++++++------ tests/forward_pass/test_forward_pass_obs.py | 14 +-- tests/forward_pass/test_multi_step.py | 97 +++++++++++++++------ tests/forward_pass/test_surface_model.py | 4 +- tests/output/test_qa.py | 11 ++- tests/pipeline/test_cli.py | 54 ++++++------ tests/pipeline/test_pipeline.py | 48 +++++----- tests/samplers/test_with_obs.py | 2 +- tests/training/test_train_gan.py | 12 +-- tests/training/test_train_with_obs.py | 18 +++- tests/utilities/test_loss_metrics.py | 7 +- 19 files changed, 340 insertions(+), 299 deletions(-) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index 3d3d6c06d8..12c92b4c6f 100644 --- a/sup3r/models/interface.py +++ b/sup3r/models/interface.py @@ -95,45 +95,16 @@ def is_4d(self): """Check if model expects spatial only input""" return self.input_dims == 4 - def get_s_enhance_from_layers(self): - """Compute factor by which model will enhance spatial resolution from - layer attributes. 
Used in model training during high res coarsening""" - s_enhance = None - if hasattr(self, '_gen'): - s_enhancements = [ - getattr(layer, '_spatial_mult', 1) - for layer in self._gen.layers - ] - s_enhance = int(np.prod(s_enhancements)) - return s_enhance - - def get_t_enhance_from_layers(self): - """Compute factor by which model will enhance temporal resolution from - layer attributes. Used in model training during high res coarsening""" - t_enhance = None - if hasattr(self, '_gen'): - t_enhancements = [ - getattr(layer, '_temporal_mult', 1) - for layer in self._gen.layers - ] - t_enhance = int(np.prod(t_enhancements)) - return t_enhance - @property def s_enhance(self): """Factor by which model will enhance spatial resolution. Used in model training during high res coarsening and also in forward pass routine to determine shape of needed exogenous data""" models = getattr(self, 'models', [self]) - s_enhances = [m.meta.get('s_enhance', None) for m in models] - s_enhance = ( - self.get_s_enhance_from_layers() - if any(s is None for s in s_enhances) - else int(np.prod(s_enhances)) - ) + s_enhances = [m.meta.get('s_enhance', 1) for m in models] if len(models) == 1 and isinstance(self.meta, dict): - self.meta['s_enhance'] = s_enhance - return s_enhance + self.meta['s_enhance'] = np.prod(s_enhances) + return np.prod(s_enhances) @property def t_enhance(self): @@ -141,15 +112,10 @@ def t_enhance(self): model training during high res coarsening and also in forward pass routine to determine shape of needed exogenous data""" models = getattr(self, 'models', [self]) - t_enhances = [m.meta.get('t_enhance', None) for m in models] - t_enhance = ( - self.get_t_enhance_from_layers() - if any(t is None for t in t_enhances) - else int(np.prod(t_enhances)) - ) + t_enhances = [m.meta.get('t_enhance', 1) for m in models] if len(models) == 1 and isinstance(self.meta, dict): - self.meta['t_enhance'] = t_enhance - return t_enhance + self.meta['t_enhance'] = np.prod(t_enhances) + return 
np.prod(t_enhances) @property def s_enhancements(self): @@ -215,29 +181,6 @@ def _ensure_valid_input_resolution(self): logger.error(msg) raise RuntimeError(msg) - def _ensure_valid_enhancement_factors(self): - """Ensure user provided enhancement factors are the same as those - computed from layer attributes""" - t_enhance = self.meta.get('t_enhance', None) - s_enhance = self.meta.get('s_enhance', None) - if s_enhance is None or t_enhance is None: - return - - layer_se = self.get_s_enhance_from_layers() - layer_te = self.get_t_enhance_from_layers() - layer_se = layer_se if layer_se is not None else self.meta['s_enhance'] - layer_te = layer_te if layer_te is not None else self.meta['t_enhance'] - msg = ( - f'Enhancement factors computed from layer attributes ' - f'(s_enhance={layer_se}, t_enhance={layer_te}) ' - f'conflict with user provided values (s_enhance={s_enhance}, ' - f't_enhance={t_enhance})' - ) - check = layer_se == s_enhance and layer_te == t_enhance - if not check: - logger.error(msg) - raise RuntimeError(msg) - def _ensure_feature_consistency(self): """Ensure that the exogenous features specified in meta are consistent with the features found in the model layers""" @@ -407,11 +350,11 @@ def hr_features(self): the high-resolution data during training. This includes both output and exogenous features. 
""" - default = [ + out = [ f for f in self.hr_out_features if f not in self.hr_exo_features ] - default += self.hr_exo_features - return self.meta.get('hr_features', default) + out += self.hr_exo_features + return out @property def smoothing(self): @@ -468,7 +411,6 @@ def set_model_params(self, batch_handler=None, **kwargs): 'lr_features', 'hr_exo_features', 'hr_out_features', - 'hr_features', 'obs_features', 'smoothed_features', 's_enhance', @@ -496,7 +438,6 @@ def set_model_params(self, batch_handler=None, **kwargs): logger.warning(msg) warn(msg) - self._ensure_valid_enhancement_factors() self._ensure_valid_input_resolution() self._ensure_feature_consistency() diff --git a/sup3r/preprocessing/batch_queues/base.py b/sup3r/preprocessing/batch_queues/base.py index cf972f886c..c45f8c81cf 100644 --- a/sup3r/preprocessing/batch_queues/base.py +++ b/sup3r/preprocessing/batch_queues/base.py @@ -13,21 +13,18 @@ class SingleBatchQueue(AbstractBatchQueue): - """Base BatchQueue class for single dataset containers - - Note - ---- - Here we use `len(self.features)` for the last dimension of samples, since - samples in :class:`SingleBatchQueue` queues are coarsened to produce - low-res samples, and then the `lr_only_features` are removed with - `hr_features_ind`. 
In contrast, for samples in :class:`DualBatchQueue` - queues there are low / high res pairs and the high-res only stores the - `hr_features`""" + """Base BatchQueue class for single dataset containers""" @property def queue_shape(self): """Shape of objects stored in the queue.""" - return [(self.batch_size, *self.hr_shape)] + return [ + ( + self.batch_size, + *self.hr_sample_shape, + len(self.hr_source_features), + ) + ] def transform( self, diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index 5c91e5e01c..d9f2e1fd1e 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -126,7 +126,9 @@ def use_proxy_obs(self): check = bool(self.proxy_obs_kwargs) check = check or ( len(self.obs_features) > 0 - and all(f not in self.hr_features for f in self.obs_features) + and all( + f not in self.hr_source_features for f in self.obs_features + ) ) return check @@ -174,9 +176,9 @@ def get_sample_index(self, n_obs=None): self.shape, self.sample_shape[2] * n_obs ) feats = ( - self.hr_features + self.hr_source_features if not self.use_proxy_obs - else self.hr_features[: -len(self.obs_features)] + else self.hr_source_features[: -len(self.obs_features)] ) return (*spatial_slice, time_slice, feats) @@ -220,7 +222,7 @@ def check_feature_consistency(self): """Check that the feature sets are consistent with each other and the obs features are configured correctly.""" if self.use_proxy_obs and not all( - f in self.hr_features for f in self.obs_features + f in self.hr_source_features for f in self.obs_features ): msg = ( 'When using proxy observations, all obs features must be ' @@ -251,10 +253,10 @@ def check_feature_consistency(self): if len(self.hr_exo_features) > 0: msg = ( f'hr_exo_features {self.hr_exo_features} must come at the end ' - f'of the full high-res feature set: {self.hr_features}' + f'of the full high-res feature set: {self.hr_source_features}' ) assert list(self.hr_exo_features) == list( - 
self.hr_features[-len(self.hr_exo_features) :] + self.hr_source_features[-len(self.hr_exo_features) :] ), msg assert all( @@ -501,7 +503,7 @@ def _parse_features(self, unparsed_feats): if any('*' in fn for fn in parsed_feats): out = [] - for feature in self.hr_features: + for feature in self.hr_source_features: match = any( fnmatch(feature.lower(), pattern.lower()) for pattern in parsed_feats @@ -519,12 +521,12 @@ def lr_features(self): return self._parse_features(self._lr_features) @property - def hr_features(self): - """List of feature names or patt*erns that should be available in the - high-resolution data. For a non-dual sampler this is all features, - since even features only provided to the model as low-resolution still - need to be coarsened from the high-resolution data. This is in contrast - to dual samplers + def hr_source_features(self): + """List of feature names or patt*erns that should be available natively + as high-resolution. For a non-dual sampler this is all features, since + even features only provided to the model as low-resolution still need + to be coarsened from the high-resolution data. This is in contrast to + dual samplers (:class:`~sup3r.preprocessing.samplers.dual.DualSampler`), where there are separate high-resolution and low-resolution data members.""" feats = [f for f in self.lr_features if f not in self.hr_exo_features] @@ -536,6 +538,18 @@ def hr_features(self): feats += [f for f in self.hr_exo_features if f not in feats] return feats + @property + def hr_features(self): + """List of feature names or patt*erns that the model is shown at + high-resolution. This does not include features that are only shown to + the model after coarsening. 
Thus, this includes hr_out_features and + and hr_exo_features.""" + out = [ + f for f in self.hr_out_features if f not in self.hr_exo_features + ] + out += self.hr_exo_features + return out + @property def hr_out_features(self): """List of feature names or patt*erns that should be output by the @@ -560,7 +574,7 @@ def obs_features(self): the generative model. These are different from the `hr_exo_features` in that they are intended to be used as observation features with NaN values where observations are not available.""" - return [f for f in self.hr_features if '_obs' in f] + return [f for f in self.hr_source_features if '_obs' in f] @property def hr_features_ind(self): @@ -569,18 +583,14 @@ def hr_features_ind(self): hr_exo_features, Any high-resolution features that are only included in the data handler to be coarsened for the low-res input are removed. """ - hr_feats = [ - f for f in self.hr_out_features if f not in self.hr_exo_features - ] - hr_feats += self.hr_exo_features - return [self.hr_features.index(f) for f in hr_feats] + return [self.hr_source_features.index(f) for f in self.hr_features] @property def lr_features_ind(self): """Get the low-resolution feature channel indices that should be included for training. This includes lr_features. 
""" - return [self.hr_features.index(f) for f in self.lr_features] + return [self.hr_source_features.index(f) for f in self.lr_features] @property def obs_features_ind(self): @@ -599,7 +609,7 @@ def obs_features_ind(self): else self.obs_features ) - return [self.hr_features.index(f) for f in check_feats] + return [self.hr_source_features.index(f) for f in check_feats] def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): """Get observation mask for a given spatial and temporal obs @@ -669,8 +679,8 @@ def _get_full_obs_mask(self, hi_res): on_sf = self.onshore_obs_frac.get('spatial', 0.0) on_tf = self.onshore_obs_frac.get('time', 1.0) obs_mask = self._get_obs_mask(hi_res, on_sf, on_tf) - if 'topography' in self.hr_features and self.offshore_obs_frac: - topo_idx = self.hr_features.index('topography') + if 'topography' in self.hr_source_features and self.offshore_obs_frac: + topo_idx = self.hr_source_features.index('topography') topo = hi_res[..., topo_idx] off_sf = self.offshore_obs_frac.get('spatial', 0.0) off_tf = self.offshore_obs_frac.get('time', 1.0) diff --git a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index e540eaeacc..4b0a865a35 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -137,6 +137,15 @@ def __init__( } self.post_init_log(post_init_args) + @property + def hr_source_features(self): + """Features available natively at high-resolution.""" + out = [ + f for f in self.hr_out_features if f not in self.hr_exo_features + ] + out += self.hr_exo_features + return out + def check_shape_consistency(self): """Make sure container shapes are compatible with enhancement factors.""" @@ -152,23 +161,6 @@ def check_shape_consistency(self): ) assert self.data.high_res.shape[:-1] == enhanced_shape, msg - @property - def hr_features(self): - """List of feature names or patt*erns that should be available in the - high-resolution data. 
For dual samplers, this includes only the - features that are specifically designated as high-resolution outputs or - exogenous inputs. For a non-dual sampler - (:class:`~sup3r.preprocessing.samplers.base.Sampler`), this is all - features, since even features only provided to the model as - low-resolution still need to be coarsened from the high-resolution - data. - """ - out = [ - f for f in self.hr_out_features if f not in self.hr_exo_features - ] - out += self.hr_exo_features - return out - def get_sample_index(self, n_obs=None): """Get paired sample index, consisting of index for the low res sample and the index for the high res sample with the same spatiotemporal @@ -191,9 +183,9 @@ def get_sample_index(self, n_obs=None): for s in lr_index[2:-1] ] hr_feats = ( - self.hr_features[: -len(self.obs_features)] + self.hr_source_features[: -len(self.obs_features)] if self.use_proxy_obs - else self.hr_features + else self.hr_source_features ) hr_index = (*hr_index, hr_feats) diff --git a/tests/bias/test_bias_correction.py b/tests/bias/test_bias_correction.py index 34c204f146..0062f3f2e1 100644 --- a/tests/bias/test_bias_correction.py +++ b/tests/bias/test_bias_correction.py @@ -471,10 +471,13 @@ def test_fwp_integration(): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, len(features)))) - model.meta['lr_features'] = features - model.meta['hr_out_features'] = features - model.meta['s_enhance'] = 3 - model.meta['t_enhance'] = 4 + model.set_model_params( + lr_features=features, + hr_out_features=features, + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) with tempfile.TemporaryDirectory() as td: bias_fp = os.path.join(td, 'bc.h5') diff --git a/tests/bias/test_presrat_bias_correction.py b/tests/bias/test_presrat_bias_correction.py index c63ee4da96..62987ca422 100644 --- a/tests/bias/test_presrat_bias_correction.py +++ b/tests/bias/test_presrat_bias_correction.py @@ 
-757,10 +757,13 @@ def test_fwp_integration(tmp_path, presrat_params, fp_fut_cc): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, len(features)))) - model.meta['lr_features'] = features - model.meta['hr_out_features'] = features - model.meta['s_enhance'] = 3 - model.meta['t_enhance'] = 4 + model.set_model_params( + lr_features=features, + hr_out_features=features, + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) out_dir = os.path.join(tmp_path, 'st_gan') model.save(out_dir) diff --git a/tests/bias/test_qdm_bias_correction.py b/tests/bias/test_qdm_bias_correction.py index d5722d43bf..e3aa709634 100644 --- a/tests/bias/test_qdm_bias_correction.py +++ b/tests/bias/test_qdm_bias_correction.py @@ -139,9 +139,9 @@ def test_qdm_bc(fp_fut_cc): # Each location can be all finite or all NaN, but not both for v in out: tmp = np.isfinite(out[v].reshape(-1, out[v].shape[-1])) - assert np.all( - np.all(tmp, axis=1) == ~np.all(~tmp, axis=1) - ), f'For each location of {v} it should be all finite or nonte' + assert np.all(np.all(tmp, axis=1) == ~np.all(~tmp, axis=1)), ( + f'For each location of {v} it should be all finite or nonte' + ) def test_parallel(fp_fut_cc): @@ -178,9 +178,9 @@ def test_parallel(fp_fut_cc): for k in out_s: assert k in out_p, f'Missing {k} in parallel run' - assert np.allclose( - out_s[k], out_p[k], equal_nan=True - ), f'Different results for {k}' + assert np.allclose(out_s[k], out_p[k], equal_nan=True), ( + f'Different results for {k}' + ) def test_fill_nan(fp_fut_cc): @@ -202,14 +202,14 @@ def test_fill_nan(fp_fut_cc): out = c.run(fill_extend=False) # Ignore non `params` parameters, such as window_center params = (v for v in out if v.endswith('params')) - assert np.all( - [np.isnan(out[v]).any() for v in params] - ), 'Assume at least one NaN value for each param' + assert np.all([np.isnan(out[v]).any() for v in params]), ( + 'Assume at least one NaN 
value for each param' + ) out = c.run() - assert np.all( - [np.isfinite(v).all() for v in out.values()] - ), 'All NaN values where supposed to be filled' + assert np.all([np.isfinite(v).all() for v in out.values()]), ( + 'All NaN values where supposed to be filled' + ) def test_save_file(tmp_path, fp_fut_cc): @@ -350,7 +350,8 @@ def test_bc_identity(tmp_path, fp_fut_cc, dist_params): idx = ~(np.isnan(original) | np.isnan(corrected)) assert np.allclose( - compute_if_dask(original)[idx], compute_if_dask(corrected)[idx]) + compute_if_dask(original)[idx], compute_if_dask(corrected)[idx] + ) def test_bc_identity_absolute(tmp_path, fp_fut_cc, dist_params): @@ -375,7 +376,8 @@ def test_bc_identity_absolute(tmp_path, fp_fut_cc, dist_params): idx = ~(np.isnan(original) | np.isnan(corrected)) assert np.allclose( - compute_if_dask(original)[idx], compute_if_dask(corrected)[idx]) + compute_if_dask(original)[idx], compute_if_dask(corrected)[idx] + ) def test_bc_model_constant(tmp_path, fp_fut_cc, dist_params): @@ -400,7 +402,8 @@ def test_bc_model_constant(tmp_path, fp_fut_cc, dist_params): idx = ~(np.isnan(original) | np.isnan(corrected)) assert np.allclose( - compute_if_dask(corrected)[idx] - compute_if_dask(original)[idx], -10) + compute_if_dask(corrected)[idx] - compute_if_dask(original)[idx], -10 + ) def test_bc_trend(tmp_path, fp_fut_cc, dist_params): @@ -425,7 +428,8 @@ def test_bc_trend(tmp_path, fp_fut_cc, dist_params): idx = ~(np.isnan(original) | np.isnan(corrected)) assert np.allclose( - compute_if_dask(corrected)[idx] - compute_if_dask(original)[idx], 10) + compute_if_dask(corrected)[idx] - compute_if_dask(original)[idx], 10 + ) def test_bc_trend_same_hist(tmp_path, fp_fut_cc, dist_params): @@ -449,7 +453,8 @@ def test_bc_trend_same_hist(tmp_path, fp_fut_cc, dist_params): idx = ~(np.isnan(original) | np.isnan(corrected)) assert np.allclose( - compute_if_dask(original)[idx], compute_if_dask(corrected)[idx]) + compute_if_dask(original)[idx], 
compute_if_dask(corrected)[idx] + ) def test_fwp_integration(tmp_path): @@ -493,10 +498,13 @@ def test_fwp_integration(tmp_path): Sup3rGan.seed() model = Sup3rGan(pytest.ST_FP_GEN, pytest.ST_FP_DISC, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, len(features)))) - model.meta['lr_features'] = features - model.meta['hr_out_features'] = features - model.meta['s_enhance'] = 3 - model.meta['t_enhance'] = 4 + model.set_model_params( + lr_features=features, + hr_out_features=features, + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) bias_fp = os.path.join(tmp_path, 'bc.h5') out_dir = os.path.join(tmp_path, 'st_gan') @@ -568,12 +576,12 @@ def test_fwp_integration(tmp_path): bc_chunk = bc_fwp.get_input_chunk(ichunk) chunk = fwp.get_input_chunk(ichunk) delta = bc_chunk.input_data - chunk.input_data - assert np.allclose( - delta[..., 0], -2.72, atol=1e-03 - ), 'U reference offset is -1' - assert np.allclose( - delta[..., 1], 2.72, atol=1e-03 - ), 'V reference offset is 1' + assert np.allclose(delta[..., 0], -2.72, atol=1e-03), ( + 'U reference offset is -1' + ) + assert np.allclose(delta[..., 1], 2.72, atol=1e-03), ( + 'V reference offset is 1' + ) kwargs = { 'model_kwargs': strat.model_kwargs, diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 486b574bba..8a320aaa60 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -73,7 +73,7 @@ def test_fwp_nc_cc(): out_pattern=out_files, input_handler_name='DataHandlerNCforCC', pass_workers=None, - invert_uv=False + invert_uv=False, ) forward_pass = ForwardPass(strat) forward_pass.run(strat, node_index=0) @@ -173,7 +173,7 @@ def test_fwp_spatial_only(input_files): out_pattern=out_files, pass_workers=1, output_workers=1, - invert_uv=False + invert_uv=False, ) forward_pass = ForwardPass(strat) assert strat.output_workers == 1 @@ -227,7 +227,7 @@ def test_fwp_nc(input_files): }, 
out_pattern=out_files, pass_workers=1, - invert_uv=False + invert_uv=False, ) forward_pass = ForwardPass(strat) assert forward_pass.strategy.pass_workers == 1 @@ -443,14 +443,12 @@ def test_fwp_chunking(input_files): 'time_slice': time_slice, }, ) - data_chunked = np.zeros( - ( - shape[0] * s_enhance, - shape[1] * s_enhance, - raw_tsteps * t_enhance, - len(model.hr_out_features), - ) - ) + data_chunked = np.zeros(( + shape[0] * s_enhance, + shape[1] * s_enhance, + raw_tsteps * t_enhance, + len(model.hr_out_features), + )) handlerNC = DataHandler( input_files, FEATURES, target=target, shape=shape ) @@ -564,19 +562,25 @@ def test_fwp_multi_step_model(input_files): fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s_model.meta['lr_features'] = ['u_100m', 'v_100m'] - s_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] - assert s_model.s_enhance == 2 - assert s_model.t_enhance == 1 + s_model.set_model_params( + lr_features=['u_100m', 'v_100m'], + hr_out_features=['u_100m', 'v_100m'], + s_enhance=2, + t_enhance=1, + input_resolution={'spatial': '6km', 'temporal': '40min'}, + ) _ = s_model.generate(np.ones((4, 10, 10, 2))) fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['lr_features'] = ['u_100m', 'v_100m'] - st_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] - assert st_model.s_enhance == 3 - assert st_model.t_enhance == 4 + st_model.set_model_params( + lr_features=['u_100m', 'v_100m'], + hr_out_features=['u_100m', 'v_100m'], + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '3km', 'temporal': '40min'}, + ) _ = st_model.generate(np.ones((4, 10, 10, 6, 2))) with tempfile.TemporaryDirectory() as td: diff --git a/tests/forward_pass/test_forward_pass_exo.py 
b/tests/forward_pass/test_forward_pass_exo.py index 212b9882c5..436b30807d 100644 --- a/tests/forward_pass/test_forward_pass_exo.py +++ b/tests/forward_pass/test_forward_pass_exo.py @@ -54,25 +54,23 @@ def test_fwp_multi_step_model_topo_exoskip(input_files): fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s1_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] - s1_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] - s1_model.meta['s_enhance'] = 2 - s1_model.meta['t_enhance'] = 1 - s1_model.meta['input_resolution'] = { - 'spatial': '48km', - 'temporal': '60min', - } + s1_model.set_model_params( + lr_features=['u_100m', 'v_100m', 'topography'], + hr_out_features=['u_100m', 'v_100m'], + s_enhance=2, + t_enhance=1, + input_resolution={'spatial': '48km', 'temporal': '60min'}, + ) _ = s1_model.generate(np.ones((4, 10, 10, 3))) s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s2_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] - s2_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] - s2_model.meta['s_enhance'] = 2 - s2_model.meta['t_enhance'] = 1 - s2_model.meta['input_resolution'] = { - 'spatial': '24km', - 'temporal': '60min', - } + s2_model.set_model_params( + lr_features=['u_100m', 'v_100m', 'topography'], + hr_out_features=['u_100m', 'v_100m'], + s_enhance=2, + t_enhance=1, + input_resolution={'spatial': '24km', 'temporal': '60min'}, + ) _ = s2_model.generate(np.ones((4, 10, 10, 3))) fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') @@ -281,14 +279,13 @@ def test_fwp_multi_step_model_topo_noskip(input_files): fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] - 
st_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] - st_model.meta['s_enhance'] = 3 - st_model.meta['t_enhance'] = 4 - st_model.meta['input_resolution'] = { - 'spatial': '12km', - 'temporal': '60min', - } + st_model.set_model_params( + lr_features=['u_100m', 'v_100m', 'topography'], + hr_out_features=['u_100m', 'v_100m'], + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) with tempfile.TemporaryDirectory() as td: @@ -1044,8 +1041,14 @@ def test_solar_multistep_exo(gen_config_with_topo): model1 = Sup3rGan(fp_gen, fp_disc) _ = model1.generate(np.ones((4, 10, 10, len(features1)))) model1.set_norm_stats({'clearsky_ratio': 0.7}, {'clearsky_ratio': 0.04}) - model1.meta['input_resolution'] = {'spatial': '8km', 'temporal': '40min'} - model1.set_model_params(lr_features=features1, hr_out_features=features1) + model1.set_model_params( + lr_features=features1, + hr_out_features=features1, + hr_exo_features=[], + s_enhance=2, + t_enhance=1, + input_resolution={'spatial': '8km', 'temporal': '40min'}, + ) features2 = ['U_200m', 'V_200m', 'topography'] @@ -1076,6 +1079,8 @@ def test_solar_multistep_exo(gen_config_with_topo): lr_features=features2, hr_out_features=features2[:-1], hr_exo_features=features2[-1:], + s_enhance=2, + t_enhance=1, ) features_in_3 = ['clearsky_ratio', 'U_200m', 'V_200m'] @@ -1088,9 +1093,12 @@ def test_solar_multistep_exo(gen_config_with_topo): {'U_200m': 4.2, 'V_200m': 5.6, 'clearsky_ratio': 0.7}, {'U_200m': 1.1, 'V_200m': 1.3, 'clearsky_ratio': 0.04}, ) - model3.meta['input_resolution'] = {'spatial': '2km', 'temporal': '40min'} model3.set_model_params( - lr_features=features_in_3, hr_out_features=features_out_3 + lr_features=features_in_3, + hr_out_features=features_out_3, + s_enhance=1, + t_enhance=8, + input_resolution={'spatial': '2km', 'temporal': '40min'}, ) with tempfile.TemporaryDirectory() as td: diff --git 
a/tests/forward_pass/test_forward_pass_obs.py b/tests/forward_pass/test_forward_pass_obs.py index f35793ad5d..333c38a3da 100644 --- a/tests/forward_pass/test_forward_pass_obs.py +++ b/tests/forward_pass/test_forward_pass_obs.py @@ -118,12 +118,14 @@ def test_fwp_with_obs(input_file, obs_file, gen_config_with_obs_2d, request): model = Sup3rGan( gen_config_with_obs_2d(), pytest.S_FP_DISC, learning_rate=1e-4 ) - model.meta['input_resolution'] = {'spatial': '16km', 'temporal': '3600min'} - model.meta['lr_features'] = ['u_10m', 'v_10m'] - model.meta['hr_exo_features'] = ['u_10m_obs', 'v_10m_obs'] - model.meta['hr_out_features'] = ['u_10m', 'v_10m'] - model.meta['s_enhance'] = 2 - model.meta['t_enhance'] = 1 + model.set_model_params( + input_resolution={'spatial': '16km', 'temporal': '3600min'}, + lr_features=['u_10m', 'v_10m'], + hr_exo_features=['u_10m_obs', 'v_10m_obs'], + hr_out_features=['u_10m', 'v_10m'], + s_enhance=2, + t_enhance=1, + ) with tempfile.TemporaryDirectory() as td: exo_tmp = { diff --git a/tests/forward_pass/test_multi_step.py b/tests/forward_pass/test_multi_step.py index f77b0258f0..4977eca9eb 100644 --- a/tests/forward_pass/test_multi_step.py +++ b/tests/forward_pass/test_multi_step.py @@ -33,10 +33,16 @@ def test_multi_step_model(features): model1 = Sup3rGan(fp_gen1, fp_disc) model2 = Sup3rGan(fp_gen2, fp_disc) - model1.meta['input_resolution'] = {'spatial': '27km', 'temporal': '64min'} - model2.meta['input_resolution'] = {'spatial': '9km', 'temporal': '16min'} - model1.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) - model2.set_model_params(lr_features=features, hr_out_features=features) + model1.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '27km', 'temporal': '64min'}, + ) + model2.set_model_params( + lr_features=features, + hr_out_features=features, + input_resolution={'spatial': '9km', 'temporal': '16min'}, + ) _ = model1.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) 
_ = model2.generate(np.ones((4, 10, 10, 6, len(features)))) @@ -93,12 +99,21 @@ def test_multi_step_norm(norm_option): {'u_100m': 0.1, 'v_100m': 0.8}, {'u_100m': 0.04, 'v_100m': 0.02} ) - model1.meta['input_resolution'] = {'spatial': '27km', 'temporal': '64min'} - model2.meta['input_resolution'] = {'spatial': '9km', 'temporal': '16min'} - model3.meta['input_resolution'] = {'spatial': '3km', 'temporal': '4min'} - model1.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) - model2.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) - model3.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) + model1.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '27km', 'temporal': '64min'}, + ) + model2.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '9km', 'temporal': '16min'}, + ) + model3.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '3km', 'temporal': '4min'}, + ) _ = model1.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) _ = model2.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) @@ -146,11 +161,16 @@ def test_spatial_then_temporal_gan(): {'u_100m': 0.3, 'v_100m': 0.9}, {'u_100m': 0.02, 'v_100m': 0.07} ) - model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '40min'} - model2.meta['input_resolution'] = {'spatial': '6km', 'temporal': '40min'} - - model1.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) - model2.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) + model1.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '12km', 'temporal': '40min'}, + ) + model2.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '6km', 'temporal': '40min'}, + ) with tempfile.TemporaryDirectory() as td: fp1 = os.path.join(td, 'model1') @@ -184,11 +204,16 @@ 
def test_temporal_then_spatial_gan(): {'u_100m': 0.3, 'v_100m': 0.9}, {'u_100m': 0.02, 'v_100m': 0.07} ) - model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '40min'} - model2.meta['input_resolution'] = {'spatial': '6km', 'temporal': '40min'} - - model1.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) - model2.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) + model1.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '12km', 'temporal': '40min'}, + ) + model2.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '6km', 'temporal': '40min'}, + ) with tempfile.TemporaryDirectory() as td: fp1 = os.path.join(td, 'model1') @@ -215,8 +240,11 @@ def test_spatial_gan_then_linear_interp(): model1.set_norm_stats( {'u_100m': 0.1, 'v_100m': 0.2}, {'u_100m': 0.04, 'v_100m': 0.02} ) - model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '60min'} - model1.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) + model1.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) with tempfile.TemporaryDirectory() as td: fp1 = os.path.join(td, 'model1') @@ -241,8 +269,13 @@ def test_solar_multistep(): model1 = Sup3rGan(fp_gen, fp_disc) _ = model1.generate(np.ones((4, 10, 10, len(features1)))) model1.set_norm_stats({'clearsky_ratio': 0.7}, {'clearsky_ratio': 0.04}) - model1.meta['input_resolution'] = {'spatial': '8km', 'temporal': '40min'} - model1.set_model_params(lr_features=features1, hr_out_features=features1) + model1.set_model_params( + lr_features=features1, + hr_out_features=features1, + s_enhance=2, + t_enhance=1, + input_resolution={'spatial': '8km', 'temporal': '40min'}, + ) features2 = ['U_200m', 'V_200m'] fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') @@ -252,8 +285,13 @@ def test_solar_multistep(): 
model2.set_norm_stats( {'U_200m': 4.2, 'V_200m': 5.6}, {'U_200m': 1.1, 'V_200m': 1.3} ) - model2.meta['input_resolution'] = {'spatial': '4km', 'temporal': '40min'} - model2.set_model_params(lr_features=features2, hr_out_features=features2) + model2.set_model_params( + lr_features=features2, + hr_out_features=features2, + s_enhance=2, + t_enhance=1, + input_resolution={'spatial': '4km', 'temporal': '40min'}, + ) features_in_3 = ['clearsky_ratio', 'U_200m', 'V_200m'] features_out_3 = ['clearsky_ratio'] @@ -265,9 +303,12 @@ def test_solar_multistep(): {'U_200m': 4.2, 'V_200m': 5.6, 'clearsky_ratio': 0.7}, {'U_200m': 1.1, 'V_200m': 1.3, 'clearsky_ratio': 0.04}, ) - model3.meta['input_resolution'] = {'spatial': '2km', 'temporal': '40min'} model3.set_model_params( - lr_features=features_in_3, hr_out_features=features_out_3 + lr_features=features_in_3, + hr_out_features=features_out_3, + s_enhance=1, + t_enhance=8, + input_resolution={'spatial': '2km', 'temporal': '40min'}, ) with tempfile.TemporaryDirectory() as td: diff --git a/tests/forward_pass/test_surface_model.py b/tests/forward_pass/test_surface_model.py index 5ea91720e7..9f6ae554b7 100644 --- a/tests/forward_pass/test_surface_model.py +++ b/tests/forward_pass/test_surface_model.py @@ -159,8 +159,8 @@ def test_multi_step_surface(s_enhance=2, t_enhance=2): temporal_model_kwargs={'model_dirs': temporal_dir}) for model in ms_model.models: - assert isinstance(model.s_enhance, int) - assert isinstance(model.t_enhance, int) + assert isinstance(model.s_enhance, (int, np.integer)) + assert isinstance(model.t_enhance, (int, np.integer)) x = np.ones((2, 10, 10, len(FEATURES))) with pytest.raises(AssertionError): diff --git a/tests/output/test_qa.py b/tests/output/test_qa.py index 9854591300..1b7f11ff81 100644 --- a/tests/output/test_qa.py +++ b/tests/output/test_qa.py @@ -55,10 +55,13 @@ def test_qa(input_files, ext): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 
10, 6, len(TRAIN_FEATURES)))) - model.meta['lr_features'] = TRAIN_FEATURES - model.meta['hr_out_features'] = MODEL_OUT_FEATURES - model.meta['s_enhance'] = 3 - model.meta['t_enhance'] = 4 + model.set_model_params( + lr_features=TRAIN_FEATURES, + hr_out_features=MODEL_OUT_FEATURES, + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) with tempfile.TemporaryDirectory() as td: out_dir = os.path.join(td, 'st_gan') model.save(out_dir) diff --git a/tests/pipeline/test_cli.py b/tests/pipeline/test_cli.py index a91d1d8c74..adfd8d68c3 100644 --- a/tests/pipeline/test_cli.py +++ b/tests/pipeline/test_cli.py @@ -201,20 +201,21 @@ def test_fwd_pass_with_bc_cli(runner, input_files): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - model.meta['hr_out_features'] = FEATURES[:2] - assert model.s_enhance == 3 - assert model.t_enhance == 4 + model.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES[:2], + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) with tempfile.TemporaryDirectory() as td: out_dir = os.path.join(td, 'st_gan') model.save(out_dir) - n_chunks = np.prod( - [ - int(np.ceil(ds / fs)) - for ds, fs in zip([*shape, data_shape[2]], fwp_chunk_shape) - ] - ) + n_chunks = np.prod([ + int(np.ceil(ds / fs)) + for ds, fs in zip([*shape, data_shape[2]], fwp_chunk_shape) + ]) out_files = os.path.join(td, 'out_{file_id}.nc') cache_pattern = os.path.join(td, 'cache_{feature}.nc') log_pattern = os.path.join(td, 'logs', 'log_{node_index}.log') @@ -299,20 +300,21 @@ def test_fwd_pass_cli(runner, input_files): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - model.meta['hr_out_features'] = FEATURES[:2] - assert model.s_enhance == 3 - assert 
model.t_enhance == 4 + model.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES[:2], + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) with tempfile.TemporaryDirectory() as td: out_dir = os.path.join(td, 'st_gan') model.save(out_dir) - n_chunks = np.prod( - [ - int(np.ceil(ds / fs)) - for ds, fs in zip([*shape, data_shape[2]], fwp_chunk_shape) - ] - ) + n_chunks = np.prod([ + int(np.ceil(ds / fs)) + for ds, fs in zip([*shape, data_shape[2]], fwp_chunk_shape) + ]) out_files = os.path.join(td, 'out_{file_id}.nc') cache_pattern = os.path.join(td, 'cache_{feature}.nc') log_pattern = os.path.join(td, 'logs', 'log_{node_index}.log') @@ -359,15 +361,17 @@ def test_pipeline_fwp_qa(runner, input_files): Sup3rGan.seed() model = Sup3rGan(pytest.ST_FP_GEN, pytest.ST_FP_DISC, learning_rate=1e-4) + model.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES[:2], + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) input_resolution = {'spatial': '12km', 'temporal': '60min'} - model.meta['input_resolution'] = input_resolution assert model.input_resolution == input_resolution assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'} _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - model.meta['hr_out_features'] = FEATURES[:2] - assert model.s_enhance == 3 - assert model.t_enhance == 4 with tempfile.TemporaryDirectory() as td: out_dir = os.path.join(td, 'st_gan') diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index d05225bb2c..465fec9f92 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -43,13 +43,15 @@ def test_fwp_pipeline_with_bc(input_files): model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) input_resolution = {'spatial': '12km', 'temporal': '60min'} - 
model.meta['input_resolution'] = input_resolution + model.set_model_params( + input_resolution=input_resolution, + s_enhance=3, + t_enhance=4, + lr_features=FEATURES, + hr_out_features=FEATURES[:2], + ) assert model.input_resolution == input_resolution assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'} - _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - model.meta['hr_out_features'] = FEATURES[:2] - model.meta['hr_exo_features'] = FEATURES[2:] assert model.s_enhance == 3 assert model.t_enhance == 4 @@ -179,13 +181,15 @@ def test_fwp_pipeline(input_files): model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) input_resolution = {'spatial': '12km', 'temporal': '60min'} - model.meta['input_resolution'] = input_resolution + model.set_model_params( + input_resolution=input_resolution, + s_enhance=3, + t_enhance=4, + lr_features=FEATURES, + hr_out_features=FEATURES[:2], + ) assert model.input_resolution == input_resolution assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'} - _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - model.meta['hr_out_features'] = FEATURES[:2] - model.meta['hr_exo_features'] = FEATURES[2:] assert model.s_enhance == 3 assert model.t_enhance == 4 @@ -290,11 +294,13 @@ def test_fwp_pipeline_with_mask(input_files): model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) input_resolution = {'spatial': '12km', 'temporal': '60min'} - model.meta['input_resolution'] = input_resolution - _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - model.meta['hr_out_features'] = FEATURES[:2] - model.meta['hr_exo_features'] = FEATURES[2:] + model.set_model_params( + input_resolution=input_resolution, + s_enhance=3, + t_enhance=4, + lr_features=FEATURES, + 
hr_out_features=FEATURES[:2], + ) test_context = click.Context(click.Command('pipeline'), obj={}) with tempfile.TemporaryDirectory() as td, test_context as ctx: @@ -391,13 +397,15 @@ def test_multiple_fwp_pipeline(input_files): model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) input_resolution = {'spatial': '12km', 'temporal': '60min'} - model.meta['input_resolution'] = input_resolution + model.set_model_params( + input_resolution=input_resolution, + s_enhance=3, + t_enhance=4, + lr_features=FEATURES, + hr_out_features=FEATURES[:2], + ) assert model.input_resolution == input_resolution assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'} - _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - model.meta['hr_out_features'] = FEATURES[:2] - model.meta['hr_exo_features'] = FEATURES[2:] assert model.s_enhance == 3 assert model.t_enhance == 4 diff --git a/tests/samplers/test_with_obs.py b/tests/samplers/test_with_obs.py index 6c722fad7b..38cda97758 100644 --- a/tests/samplers/test_with_obs.py +++ b/tests/samplers/test_with_obs.py @@ -139,7 +139,7 @@ def test_proxy_obs_onshore_offshore_topography_fractions(sampler_cls): sampler.data.high_res['topography'] = (topo_var.dims, topo) batch = _get_hr_batch(sampler) - topo_idx = sampler.hr_features.index('topography') + topo_idx = sampler.hr_source_features.index('topography') topo_sample = batch[..., topo_idx] obs = batch[..., -2:] diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index 3bd154132b..1a8f5f0100 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -210,10 +210,10 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): learning_rate=lr, loss={'MeanAbsoluteError': {}, 'MeanSquaredError': {}}, ) - dummy.meta['lr_features'] = model.meta['lr_features'] - dummy.meta['hr_features'] = 
model.meta['hr_features'] - dummy.meta['hr_exo_features'] = model.meta['hr_exo_features'] - dummy.meta['hr_out_features'] = model.meta['hr_out_features'] + dummy.set_model_params( + input_resolution={'spatial': '30km', 'temporal': '60min'}, + batch_handler=batch_handler, + ) for batch in batch_handler: out_og = model._tf_generate(batch.low_res) @@ -410,7 +410,9 @@ def test_input_res_check(): with pytest.raises(RuntimeError): model.set_model_params( - input_resolution={'spatial': '22km', 'temporal': '9min'} + input_resolution={'spatial': '22km', 'temporal': '9min'}, + s_enhance=3, + t_enhance=4, ) diff --git a/tests/training/test_train_with_obs.py b/tests/training/test_train_with_obs.py index f6cd8d359a..66b4bb78ea 100644 --- a/tests/training/test_train_with_obs.py +++ b/tests/training/test_train_with_obs.py @@ -236,7 +236,8 @@ def test_train_just_obs(gen_config, sample_shape, t_enhance, fp_disc, request): assert np.sum(np.diff(tloss)) < 0 -def test_train_obs_with_topo(request): +@pytest.mark.parametrize('lr_only_features', [[], ['temperature_2m']]) +def test_train_obs_with_topo(lr_only_features, request): """Test training with topo and obs. 
Make sure exo features are properly concatenated.""" @@ -250,8 +251,14 @@ def test_train_obs_with_topo(request): } train_handler = DataHandler(**kwargs, time_slice=slice(None, 3000, 10)) - val_handler = DataHandler(**kwargs, time_slice=slice(3000, None, 10)) + + # Add dummy lr only features + if lr_only_features: + for feat in lr_only_features: + train_handler[feat] = train_handler[FEATURES_W[0]].copy() + val_handler[feat] = val_handler[FEATURES_W[0]].copy() + batcher = BatchHandler( [train_handler], [val_handler], @@ -262,7 +269,7 @@ def test_train_obs_with_topo(request): sample_shape=(20, 20, 10), proxy_obs_kwargs={'onshore_obs_frac': {'spatial': 0.1}}, feature_sets={ - 'lr_features': FEATURES_W, + 'lr_features': [*lr_only_features, *FEATURES_W], 'hr_exo_features': [ 'topography', *[f'{feat}_obs' for feat in FEATURES_W], @@ -296,3 +303,8 @@ def test_train_obs_with_topo(request): } model.train(batcher, **model_kwargs) + + loss = model.history['train_geothermal_physics_loss_with_obs'].values + assert not np.isnan(loss).any() + gloss = model.history['train_loss_gen'].values + assert not np.isnan(gloss).any() diff --git a/tests/utilities/test_loss_metrics.py b/tests/utilities/test_loss_metrics.py index c4ba3f96ab..fa1b6e8b98 100644 --- a/tests/utilities/test_loss_metrics.py +++ b/tests/utilities/test_loss_metrics.py @@ -307,8 +307,11 @@ def test_multiterm_loss(): fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - model.meta['hr_out_features'] = ['u_100m', 'v_100m', 'temp_100m'] - model.meta['hr_features'] = ['u_100m', 'v_100m', 'temp_100m'] + model.set_model_params( + lr_features=['u_100m', 'v_100m', 'temp_100m'], + hr_out_features=['u_100m', 'v_100m', 'temp_100m'], + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) multi_loss = model.get_loss_fun({ 'MaterialDerivativeLoss': { 'gen_features': ['u_100m', 'v_100m', 'temp_100m'] From 
bbdcb0a6505c6d1b6039f4e28e79a3442fef93d4 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 15 Mar 2026 10:27:48 -0600 Subject: [PATCH 14/26] fix: update enhancement factor calculations and model initialization in tests --- sup3r/models/interface.py | 28 ++++++++++++++++++++++------ tests/training/test_train_solar.py | 3 ++- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index 12c92b4c6f..e4b0362d8b 100644 --- a/sup3r/models/interface.py +++ b/sup3r/models/interface.py @@ -100,22 +100,38 @@ def s_enhance(self): """Factor by which model will enhance spatial resolution. Used in model training during high res coarsening and also in forward pass routine to determine shape of needed exogenous data""" + + # If there are multiple steps, we want to return the product of the + # enhancement factors models = getattr(self, 'models', [self]) - s_enhances = [m.meta.get('s_enhance', 1) for m in models] + s_enhances = [m.meta.get('s_enhance', None) for m in models] + s_enhance = ( + None + if any(enh is None for enh in s_enhances) + else np.prod(s_enhances) + ) if len(models) == 1 and isinstance(self.meta, dict): - self.meta['s_enhance'] = np.prod(s_enhances) - return np.prod(s_enhances) + self.meta['s_enhance'] = s_enhance + return s_enhance @property def t_enhance(self): """Factor by which model will enhance temporal resolution. 
Used in model training during high res coarsening and also in forward pass routine to determine shape of needed exogenous data""" + + # If there are multiple steps, we want to return the product of the + # enhancement factors models = getattr(self, 'models', [self]) - t_enhances = [m.meta.get('t_enhance', 1) for m in models] + t_enhances = [m.meta.get('t_enhance', None) for m in models] + t_enhance = ( + None + if any(enh is None for enh in t_enhances) + else np.prod(t_enhances) + ) if len(models) == 1 and isinstance(self.meta, dict): - self.meta['t_enhance'] = np.prod(t_enhances) - return np.prod(t_enhances) + self.meta['t_enhance'] = t_enhance + return t_enhance @property def s_enhancements(self): diff --git a/tests/training/test_train_solar.py b/tests/training/test_train_solar.py index 27ddaff4dd..0b82b30a54 100644 --- a/tests/training/test_train_solar.py +++ b/tests/training/test_train_solar.py @@ -62,7 +62,8 @@ def test_solar_cc_model(hr_steps): Sup3rGan.seed() model = SolarCC( - fp_gen, fp_disc, learning_rate=1e-4, loss='MeanAbsoluteError' + fp_gen, fp_disc, learning_rate=1e-4, loss='MeanAbsoluteError', + t_enhance=8 ) with tempfile.TemporaryDirectory() as td: From e15ae1bfae5e14b77677a834f53ad21d449b2367 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 15 Mar 2026 11:02:43 -0600 Subject: [PATCH 15/26] fix: update model parameter settings for multi-step GAN tests --- tests/forward_pass/test_multi_step.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/forward_pass/test_multi_step.py b/tests/forward_pass/test_multi_step.py index 4977eca9eb..716b97321a 100644 --- a/tests/forward_pass/test_multi_step.py +++ b/tests/forward_pass/test_multi_step.py @@ -37,11 +37,15 @@ def test_multi_step_model(features): lr_features=FEATURES, hr_out_features=FEATURES, input_resolution={'spatial': '27km', 'temporal': '64min'}, + s_enhance=3, + t_enhance=4, ) model2.set_model_params( lr_features=features, hr_out_features=features, 
input_resolution={'spatial': '9km', 'temporal': '16min'}, + s_enhance=3, + t_enhance=4, ) _ = model1.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) @@ -103,16 +107,22 @@ def test_multi_step_norm(norm_option): lr_features=FEATURES, hr_out_features=FEATURES, input_resolution={'spatial': '27km', 'temporal': '64min'}, + s_enhance=3, + t_enhance=4, ) model2.set_model_params( lr_features=FEATURES, hr_out_features=FEATURES, input_resolution={'spatial': '9km', 'temporal': '16min'}, + s_enhance=3, + t_enhance=4, ) model3.set_model_params( lr_features=FEATURES, hr_out_features=FEATURES, input_resolution={'spatial': '3km', 'temporal': '4min'}, + s_enhance=3, + t_enhance=4, ) _ = model1.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) @@ -165,11 +175,15 @@ def test_spatial_then_temporal_gan(): lr_features=FEATURES, hr_out_features=FEATURES, input_resolution={'spatial': '12km', 'temporal': '40min'}, + s_enhance=2, + t_enhance=1, ) model2.set_model_params( lr_features=FEATURES, hr_out_features=FEATURES, input_resolution={'spatial': '6km', 'temporal': '40min'}, + s_enhance=3, + t_enhance=4, ) with tempfile.TemporaryDirectory() as td: @@ -208,11 +222,15 @@ def test_temporal_then_spatial_gan(): lr_features=FEATURES, hr_out_features=FEATURES, input_resolution={'spatial': '12km', 'temporal': '40min'}, + s_enhance=2, + t_enhance=1, ) model2.set_model_params( lr_features=FEATURES, hr_out_features=FEATURES, input_resolution={'spatial': '6km', 'temporal': '40min'}, + s_enhance=3, + t_enhance=4, ) with tempfile.TemporaryDirectory() as td: @@ -244,6 +262,8 @@ def test_spatial_gan_then_linear_interp(): lr_features=FEATURES, hr_out_features=FEATURES, input_resolution={'spatial': '12km', 'temporal': '60min'}, + s_enhance=2, + t_enhance=1, ) with tempfile.TemporaryDirectory() as td: From 17d568cecd97fdb78a9b0fcecc753f8df1440739 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 15 Mar 2026 11:30:37 -0600 Subject: [PATCH 16/26] fix: add s_enhance and t_enhance parameters to multiterm loss test 
--- tests/utilities/test_loss_metrics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/utilities/test_loss_metrics.py b/tests/utilities/test_loss_metrics.py index fa1b6e8b98..9669d58970 100644 --- a/tests/utilities/test_loss_metrics.py +++ b/tests/utilities/test_loss_metrics.py @@ -311,6 +311,8 @@ def test_multiterm_loss(): lr_features=['u_100m', 'v_100m', 'temp_100m'], hr_out_features=['u_100m', 'v_100m', 'temp_100m'], input_resolution={'spatial': '12km', 'temporal': '60min'}, + s_enhance=1, + t_enhance=1, ) multi_loss = model.get_loss_fun({ 'MaterialDerivativeLoss': { From 76484e8e4e69906c1d8a39104f17fb523510b3bf Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 18 Mar 2026 08:19:16 -0600 Subject: [PATCH 17/26] small doc string clarification --- sup3r/preprocessing/rasterizers/exo.py | 28 +++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/sup3r/preprocessing/rasterizers/exo.py b/sup3r/preprocessing/rasterizers/exo.py index 04ebfcf15f..b97b61915f 100644 --- a/sup3r/preprocessing/rasterizers/exo.py +++ b/sup3r/preprocessing/rasterizers/exo.py @@ -35,16 +35,14 @@ class BaseExoRasterizer(ABC): """Class to extract high-res (4km+) data rasters for new spatially-enhanced datasets (e.g. GCM files after spatial enhancement) using nearest neighbor - mapping and aggregation from high-res datasets (e.g. WTK or NSRDB) + mapping and aggregation from high-res datasets (e.g. WTK or NSRDB). Parameters ---------- feature : str Name of exogenous feature to rasterize. file_paths : str | list - Filepaths(s) for typically low-res WRF output or GCM netcdf data files - that is source low-resolution data intended to be sup3r resolved. - These are used to define the grid that the high-resolution exogenous + Filepaths(s) used to define the grid that the high-resolution exogenous data will be mapped onto. This can be either a single h5 file or a list of netcdf files with identical grid. 
The string can be a unix-style file path which will be passed through glob.glob. @@ -557,7 +555,27 @@ class ExoRasterizer(BaseExoRasterizer, metaclass=Sup3rMeta): def __new__(cls, feature, file_paths, source_files=None, **kwargs): """Override parent class to return type specific class based on - `source_files`""" + `source_files` + + Parameters + ---------- + feature : str + Name of exogenous feature to rasterize. If the feature name ends + with '_obs' then the `ObsRasterizer` will be used which is designed + for sparse spatiotemporal observation data. If the feature name is + 'sza' then the `SzaRasterizer` will be used which computes solar + zenith angle from the lat/lon and time. Otherwise, the + `BaseExoRasterizer` will be used which is designed for more dense + spatiotemporal data like topography or srl. + file_paths : str | list + Filepaths(s) used to define the grid that the high-resolution + exogenous data will be mapped onto. This can be either a single h5 + file or a list of netcdf files with identical grid. The string can + be a unix-style file path which will be passed through glob.glob. + source_files : str | list | None + Filepath(s) to hi-res exogenous data, which will be mapped to the + enhanced grid of the file_paths input. 
+ """ if feature.lower() == 'sza': ExoClass = SzaRasterizer elif feature.lower().endswith('_obs'): From b248a824469c2b5f7c7c5a63f9e7591025b367fa Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 18 Mar 2026 08:25:36 -0600 Subject: [PATCH 18/26] doc string edit and notebook typo removed --- examples/sup3rwind/running_sup3r_models.ipynb | 1 - sup3r/preprocessing/data_handlers/exo.py | 15 ++++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/sup3rwind/running_sup3r_models.ipynb b/examples/sup3rwind/running_sup3r_models.ipynb index 65affe6677..07577d42fa 100644 --- a/examples/sup3rwind/running_sup3r_models.ipynb +++ b/examples/sup3rwind/running_sup3r_models.ipynb @@ -241,7 +241,6 @@ " 'out_range': [-100, 100],\n", " 'temporal_avg': False,\n", " }\n", - "}\n", "\n", "config = {\n", " 'pass_workers': 1,\n", diff --git a/sup3r/preprocessing/data_handlers/exo.py b/sup3r/preprocessing/data_handlers/exo.py index 4d337663fb..e1aed21fc6 100644 --- a/sup3r/preprocessing/data_handlers/exo.py +++ b/sup3r/preprocessing/data_handlers/exo.py @@ -302,11 +302,10 @@ def __init__( feature : str Exogenous feature to extract from file_paths file_paths : str | list - A single source h5 file or netcdf file to extract raster - data from. The string can be a unix-style file path which - will be passed through glob.glob. This is typically low-res - WRF output or GCM netcdf data that is source low-resolution - data intended to be sup3r resolved. + Filepaths(s) used to define the grid that the high-resolution + exogenous data will be mapped onto. This can be either a single h5 + file or a list of netcdf files with identical grid. The string can + be a unix-style file path which will be passed through glob.glob. model : Sup3rGan | MultiStepGan Model used to get exogenous data. 
If a ``MultiStepGan`` ``lr_features``, ``hr_exo_features``, and @@ -331,7 +330,9 @@ def __init__( exo_rasterizer_kwargs : dict Keyword arguments passed to the :class:`~sup3r.preprocessing.rasterizers.exo.BaseExoRasterizer` - class. + class. This is used to specify parameters for exogenous data + rasterization such as the ``source_files`` for the exogenous data + and the method of rasterization. """ self.feature = feature self.file_paths = file_paths @@ -465,7 +466,7 @@ def _get_single_step_enhance(self, step): 'Received exo_kwargs entry without valid combine_type ' '(input/layer/output)' ) - assert combine_type.lower() in ('input', 'output', 'layer'), msg + assert combine_type.lower() in {'input', 'output', 'layer'}, msg if combine_type.lower() == 'input': if mstep == 0: s_enhance = 1 From 04208ad8a1e0bbef3725f79e8a475cff594bb965 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 18 Mar 2026 10:39:04 -0600 Subject: [PATCH 19/26] log message about exo rasterizer selection --- sup3r/preprocessing/rasterizers/exo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sup3r/preprocessing/rasterizers/exo.py b/sup3r/preprocessing/rasterizers/exo.py index b97b61915f..77219d3d1a 100644 --- a/sup3r/preprocessing/rasterizers/exo.py +++ b/sup3r/preprocessing/rasterizers/exo.py @@ -589,7 +589,6 @@ def __new__(cls, feature, file_paths, source_files=None, **kwargs): 'feature': feature, **kwargs, } - logger.info( f'Using {ExoClass.__name__} to rasterize feature "{feature}"' ) From bd9a8b9a153d689c90d96b2cc83f28f554de01b6 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 18 Mar 2026 12:17:51 -0600 Subject: [PATCH 20/26] fix: typo --- sup3r/preprocessing/rasterizers/exo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sup3r/preprocessing/rasterizers/exo.py b/sup3r/preprocessing/rasterizers/exo.py index 77219d3d1a..cc1a3b5a69 100644 --- a/sup3r/preprocessing/rasterizers/exo.py +++ b/sup3r/preprocessing/rasterizers/exo.py @@ -496,7 +496,7 @@ def source_handler(self): 
features=[feat], **self.source_handler_kwargs, ) - return self._source_handlers + return self._source_handler @property def source_data(self): From 93791d0884e06f9317eaf60e673b2ef2bb314734 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 19 Mar 2026 10:26:42 -0600 Subject: [PATCH 21/26] *removed sparse_disc argument - use train_disc arg to skip disc calcs instead. *added back raise error if missing gapless exo feature inputs to layers --- sup3r/models/abstract.py | 13 ++++-- sup3r/models/base.py | 60 ++++++++++++--------------- sup3r/models/solar_cc.py | 11 +++-- sup3r/preprocessing/samplers/base.py | 13 ++---- sup3r/preprocessing/samplers/dual.py | 5 +-- tests/training/test_train_with_obs.py | 4 +- 6 files changed, 52 insertions(+), 54 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 1141512827..65cecdcd26 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -49,7 +49,6 @@ def __init__(self): self._gen = None self._means = None self._stdevs = None - self._sparse_disc = None self._train_record = pd.DataFrame() self._val_record = pd.DataFrame() @@ -1047,8 +1046,8 @@ def run_exo_layer(self, layer, input_array, exogenous_data, norm_in=True): features = getattr(layer, 'features', [layer.name]) exo_features = getattr(layer, 'exo_features', []) for feat in features + exo_features: - missing_feat = feat in features and feat not in exogenous_data - if missing_feat: + missing_feat = feat not in exogenous_data + if missing_feat and '_obs' in feat: msg = ( f'{feat} does not match any features in exogenous_data ' f'({list(exogenous_data)}). Will try to run without this ' @@ -1056,6 +1055,14 @@ def run_exo_layer(self, layer, input_array, exogenous_data, norm_in=True): ) logger.warning(msg) continue + elif missing_feat: + msg = ( + f'{feat} does not match any features in exogenous_data ' + f'({list(exogenous_data)}). This feature is required for ' + f'layer {layer.name}.' 
+ ) + logger.error(msg) + raise KeyError(msg) exo = exogenous_data.get_combine_type_data(feat, 'layer') exo = self._reshape_norm_exo( input_array, diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 9de1e046b2..b00877e31d 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -38,7 +38,6 @@ def __init__( stdevs=None, default_device=None, name=None, - sparse_disc=False, ): """ Parameters @@ -100,11 +99,6 @@ def __init__( "/gpu:0" or "/cpu:0" name : str | None Optional name for the GAN. - sparse_disc : bool - Whether the discriminator can accept sparse features as input. - If False, the discriminator will only receive the dense features - as input. If True, the discriminator will receive both dense and - sparse features as input. """ super().__init__() @@ -136,7 +130,6 @@ def __init__( self._means = means self._stdevs = stdevs - self._sparse_disc = sparse_disc def save(self, out_dir): """Save the GAN with its sub-networks to a directory. @@ -306,15 +299,7 @@ def _tf_discriminate(self, hi_res): Discriminator output logits """ - hr = ( - hi_res - if len(self.obs_features) == 0 or self._sparse_disc - else hi_res[..., : -len(self.obs_features)] - ) - if hr.shape[-1] == 0: - return tf.constant([], dtype=tf.float32) - - out = self.discriminator.layers[0](hr) + out = self.discriminator.layers[0](hi_res) layer_num = 1 try: for i, layer in enumerate(self.discriminator.layers[1:]): @@ -408,7 +393,7 @@ def weights(self): """ return self.generator_weights + self.discriminator_weights - def init_weights(self, lr_shape, hr_shape, device=None): + def init_weights(self, lr_shape, hr_shape, train_disc=False, device=None): """Initialize the generator and discriminator weights with device placement. @@ -422,12 +407,15 @@ def init_weights(self, lr_shape, hr_shape, device=None): Shape of one batch of high res input data for sup3r resolution. Note that the batch size (axis=0) must be included, but the actual batch size doesnt really matter. 
+ train_disc : bool + Whether to initialize the discriminator weights. If False, only the + generator weights will be initialized. device : str | None Option to place model weights on a device. If None, self.default_device will be used. """ - if not self.generator_weights: + if not self.generator_weights or not self.discriminator_weights: if device is None: device = self.default_device @@ -451,7 +439,9 @@ def init_weights(self, lr_shape, hr_shape, device=None): f'{len(self.hr_out_features)}' ) assert out.shape[-1] == len(self.hr_out_features), msg - _ = self._tf_discriminate(hi_res) + + if train_disc: + _ = self._tf_discriminate(hi_res) @staticmethod def get_weight_update_fraction( @@ -870,30 +860,36 @@ def calc_loss( logger.error(msg) raise RuntimeError(msg) - disc_out_true = self._tf_discriminate(hi_res_true) - disc_out_gen = self._tf_discriminate(hi_res_gen) - loss_details = {} loss = None + disc_out_true = None + disc_out_gen = None + loss_gen_advers = None - if compute_disc or train_disc: + if train_disc or compute_disc: + disc_out_true = self._tf_discriminate(hi_res_true) + disc_out_gen = self._tf_discriminate(hi_res_gen) loss_details['loss_disc'] = self.calc_loss_disc( disc_out_true=disc_out_true, disc_out_gen=disc_out_gen ) + if train_gen and compute_disc: + loss_gen_advers = self.calc_loss_disc( + disc_out_true=disc_out_gen, disc_out_gen=disc_out_true + ) + loss_details['loss_gen_advers'] = loss_gen_advers + if train_gen: loss_gen_content, loss_gen_content_details = ( self.calc_loss_gen_content(hi_res_true, hi_res_gen) ) - loss_gen_advers = self.calc_loss_disc( - disc_out_true=disc_out_gen, disc_out_gen=disc_out_true + loss = ( + loss_gen_content + if loss_gen_advers is None + else weight_gen_advers * loss_gen_advers ) - loss = loss_gen_content - if weight_gen_advers > 0: - loss += weight_gen_advers * loss_gen_advers loss_details['loss_gen'] = loss loss_details['loss_gen_content'] = loss_gen_content - loss_details['loss_gen_advers'] = loss_gen_advers 
loss_details.update(loss_gen_content_details) elif train_disc: @@ -1131,11 +1127,7 @@ def _train_epoch( Namespace of the breakdown of loss components """ lr_shape, hr_shape = batch_handler.shapes - self.init_weights(lr_shape, hr_shape) - - self.init_weights( - (1, *batch_handler.lr_shape), (1, *batch_handler.hr_shape) - ) + self.init_weights(lr_shape, hr_shape, train_disc=train_disc) disc_th_low = np.min(disc_loss_bounds) disc_th_high = np.max(disc_loss_bounds) diff --git a/sup3r/models/solar_cc.py b/sup3r/models/solar_cc.py index db36fd86bb..ebcb9a808f 100644 --- a/sup3r/models/solar_cc.py +++ b/sup3r/models/solar_cc.py @@ -63,7 +63,7 @@ def __init__(self, *args, t_enhance=None, **kwargs): self._t_enhance = t_enhance or self.t_enhance self.meta['t_enhance'] = self._t_enhance - def init_weights(self, lr_shape, hr_shape, device=None): + def init_weights(self, lr_shape, hr_shape, train_disc=True, device=None): """Initialize the generator and discriminator weights with device placement. @@ -77,6 +77,9 @@ def init_weights(self, lr_shape, hr_shape, device=None): Shape of one batch of high res input data for sup3r resolution. Note that the batch size (axis=0) must be included, but the actual batch size doesn't really matter. + train_disc : bool + Whether to initialize the discriminator weights. If False, only the + generator weights will be initialized. device : str | None Option to place model weights on a device. If None, self.default_device will be used. 
@@ -87,7 +90,9 @@ def init_weights(self, lr_shape, hr_shape, device=None): if hr_shape[3] != self.DAYLIGHT_HOURS: hr_shape = hr_shape[0:3] + (self.DAYLIGHT_HOURS,) + hr_shape[-1:] - super().init_weights(lr_shape, hr_shape, device=device) + super().init_weights( + lr_shape, hr_shape, train_disc=train_disc, device=device + ) @tf.function def calc_loss( @@ -97,7 +102,7 @@ def calc_loss( weight_gen_advers=0.001, train_gen=True, train_disc=False, - compute_disc=False + compute_disc=False, ): """Calculate the GAN loss function using generated and true high resolution data. diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index d9f2e1fd1e..d819f8f983 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -123,14 +123,7 @@ def use_proxy_obs(self): be generated from the gridded ground truth feature named ``temperature``. """ - check = bool(self.proxy_obs_kwargs) - check = check or ( - len(self.obs_features) > 0 - and all( - f not in self.hr_source_features for f in self.obs_features - ) - ) - return check + return bool(self.proxy_obs_kwargs) @property def onshore_obs_frac(self): @@ -571,8 +564,8 @@ def obs_features(self): """List of feature names or patt*erns that should be treated as observations. These features will be included in the high-res data but not the low-res data and won't necessarily be expected to be output by - the generative model. These are different from the `hr_exo_features` in - that they are intended to be used as observation features with NaN + the generative model. 
These are different from other `hr_exo_features` + in that they are intended to be used as observation features with NaN values where observations are not available.""" return [f for f in self.hr_source_features if '_obs' in f] diff --git a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index 4b0a865a35..ccd1dd9d5c 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -97,10 +97,9 @@ def __init__( f'{self.__class__.__name__} requires a Sup3rDataset object ' 'with `.low_res` and `.high_res` data members, in that order' ) - dnames = ['low_res', 'high_res'] - check = ( + check = all( hasattr(data, dname) and getattr(data, dname) == data[i] - for i, dname in enumerate(dnames) + for i, dname in enumerate(['low_res', 'high_res']) ) assert check, msg diff --git a/tests/training/test_train_with_obs.py b/tests/training/test_train_with_obs.py index 66b4bb78ea..a425f8edf6 100644 --- a/tests/training/test_train_with_obs.py +++ b/tests/training/test_train_with_obs.py @@ -144,7 +144,9 @@ def test_train_cond_obs(gen_config, sample_shape, t_enhance, fp_disc, request): ], ) def test_train_just_obs(gen_config, sample_shape, t_enhance, fp_disc, request): - """Test model training with sparse high resolution ground truth data.""" + """Test model training with only sparse high resolution ground truth data. 
+ This should skip any calculations involving the discriminator - since + train_disc=False""" gen_config = request.getfixturevalue(gen_config)() kwargs = { From b84ab4fadbe409cfd35fa3e0f3f2786bf60a48f8 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 20 Mar 2026 09:43:34 -0600 Subject: [PATCH 22/26] refactor: update daily_reduction checks to use set for improved performance; enhance documentation in local_qdm_bc and Sampler class --- sup3r/bias/base.py | 10 +++++----- sup3r/bias/bias_transforms.py | 7 ++----- sup3r/models/abstract.py | 5 ++--- sup3r/models/base.py | 14 +++++++------- sup3r/preprocessing/samplers/base.py | 25 +++++++++++++++---------- 5 files changed, 31 insertions(+), 30 deletions(-) diff --git a/sup3r/bias/base.py b/sup3r/bias/base.py index d7761a8767..25664c714f 100644 --- a/sup3r/bias/base.py +++ b/sup3r/bias/base.py @@ -737,7 +737,7 @@ def _reduce_base_data( ) cs_ratio = ( - daily_reduction.lower() in ('avg', 'average', 'mean') + daily_reduction.lower() in {'avg', 'average', 'mean'} and base_dset == 'clearsky_ratio' ) @@ -756,16 +756,16 @@ def _reduce_base_data( ) assert not np.isnan(base_data).any(), msg - elif daily_reduction.lower() in ('avg', 'average', 'mean'): + elif daily_reduction.lower() in {'avg', 'average', 'mean'}: base_data = df.groupby('date').mean()['base_data'].values - elif daily_reduction.lower() in ('max', 'maximum'): + elif daily_reduction.lower() in {'max', 'maximum'}: base_data = df.groupby('date').max()['base_data'].values - elif daily_reduction.lower() in ('min', 'minimum'): + elif daily_reduction.lower() in {'min', 'minimum'}: base_data = df.groupby('date').min()['base_data'].values - elif daily_reduction.lower() in ('sum', 'total'): + elif daily_reduction.lower() in {'sum', 'total'}: base_data = df.groupby('date').sum()['base_data'].values msg = ( diff --git a/sup3r/bias/bias_transforms.py b/sup3r/bias/bias_transforms.py index d64418b62c..445440ca74 100644 --- a/sup3r/bias/bias_transforms.py +++ 
b/sup3r/bias/bias_transforms.py @@ -720,6 +720,8 @@ def local_qdm_bc( -------- sup3r.bias.qdm.QuantileDeltaMappingCorrection : Estimate probability distributions required by QDM method + rex.utilities.bc_utils.QuantileDeltaMapping : + Core QDM transformation. Notes ----- @@ -738,11 +740,6 @@ def local_qdm_bc( Also, :class:`rex.utilities.bc_utils.QuantileDeltaMapping` expects params to be 2D (space, N-params). - See Also - -------- - rex.utilities.bc_utils.QuantileDeltaMapping : - Core QDM transformation. - Examples -------- >>> unbiased = local_qdm_bc(biased_array, lat_lon_array, "ghi", "rsds", diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 65cecdcd26..df2c10886c 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -1231,9 +1231,8 @@ def _get_hr_exo_and_loss( **calc_loss_kwargs, ): """Get high-resolution exogenous data, generate synthetic output, and - compute loss. Obs features (if present at the end of hi_res_true) are - extracted and added to exo_data, and trimmed from hi_res_true before - loss calculation.""" + compute loss. All hr_exo_features are extracted from hi_res_true and + added to exo_data.""" hi_res_exo = self.get_hr_exo_input(hi_res_true) hi_res_gen = self._tf_generate(low_res, hi_res_exo) loss, loss_details = self.calc_loss( diff --git a/sup3r/models/base.py b/sup3r/models/base.py index b00877e31d..03b922ab22 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -516,12 +516,6 @@ def calc_loss_disc(disc_out_true, disc_out_gen): encourages the generator to produce output which is "more realistic" than the true high-res data. - References - ---------- - .. [Wang2018] Wang, Xintao, et al. "Esrgan: Enhanced super-resolution - generative adversarial networks." Proceedings of the European - conference on computer vision (ECCV) workshops. 2018. 
- Parameters ---------- disc_out_true : tf.Tensor @@ -536,6 +530,12 @@ def calc_loss_disc(disc_out_true, disc_out_gen): loss_disc : tf.Tensor 0D tensor discriminator model loss for either the spatial or temporal component of the super resolution generated output. + + References + ---------- + .. [Wang2018] Wang, Xintao, et al. "Esrgan: Enhanced super-resolution + generative adversarial networks." Proceedings of the European + conference on computer vision (ECCV) workshops. 2018. """ true_logits = disc_out_true - tf.reduce_mean(disc_out_gen) fake_logits = disc_out_gen - tf.reduce_mean(disc_out_true) @@ -886,7 +886,7 @@ def calc_loss( loss = ( loss_gen_content if loss_gen_advers is None - else weight_gen_advers * loss_gen_advers + else loss_gen_content + weight_gen_advers * loss_gen_advers ) loss_details['loss_gen'] = loss loss_details['loss_gen_content'] = loss_gen_content diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index d819f8f983..14e9598b81 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -86,8 +86,8 @@ def __init__( should be treated as onshore and offshore observations, respectively. For example, ``proxy_obs_kwargs={'onshore_obs_frac': {'spatial': 0.1, 'temporal': 0.2}, 'offshore_obs_frac': {'spatial': - 0.05, 'temporal': 0.1}}`` would specify that for the onshore - region observations cover 10% of the spatial domain and 20% of the + 0.05, 'time': 0.1}}`` would specify that for the onshore region + observations cover 10% of the spatial domain and 20% of the temporal domain, while for the offshore region observations cover 5% of the spatial domain and 10% of the temporal domain. Instead of a single float, these can also be lists to specify a lower and @@ -149,6 +149,12 @@ def offshore_obs_frac(self): def get_sample_index(self, n_obs=None): """Randomly gets spatiotemporal sample index. 
+ Returns + ------- + sample_index : tuple + Tuple of latitude slice, longitude slice, time slice, and features. + Used to get single observation like ``self.data[sample_index]`` + Notes ----- If ``n_obs > 1`` this will get a time slice with ``n_obs * @@ -156,12 +162,6 @@ def get_sample_index(self, n_obs=None): ``n_obs`` samples each with ``self.sample_shape[2]`` time steps. This is a much more efficient way of getting batches of samples but only works if there are enough continuous time steps to sample. - - Returns - ------- - sample_index : tuple - Tuple of latitude slice, longitude slice, time slice, and features. - Used to get single observation like ``self.data[sample_index]`` """ n_obs = n_obs or self.batch_size spatial_slice = uniform_box_sampler(self.shape, self.sample_shape[:2]) @@ -496,7 +496,7 @@ def _parse_features(self, unparsed_feats): if any('*' in fn for fn in parsed_feats): out = [] - for feature in self.hr_source_features: + for feature in self.features: match = any( fnmatch(feature.lower(), pattern.lower()) for pattern in parsed_feats @@ -632,6 +632,11 @@ def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): for locations that are observed. (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) + + Notes + ----- + The output mask is repeated along the feature dimension, so each + feature will have the same observation mask. 
""" s_range = ( spatial_frac @@ -644,7 +649,7 @@ def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): else [time_frac, time_frac] ) n_obs, n_spatial_1, n_spatial_2, n_temporal = hi_res.shape[:-1] - n_features = len(self.hr_out_features) + n_features = len(self.obs_features) s_fracs = RANDOM_GENERATOR.uniform(*s_range, size=n_obs) t_fracs = RANDOM_GENERATOR.uniform(*t_range, size=n_obs) From 65efd325c6297447c19e9410a4a2e295b39d57e3 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 21 Mar 2026 09:52:06 -0600 Subject: [PATCH 23/26] fix: still need `sparse_disc` flag for removing obs features from discrimination. --- sup3r/models/abstract.py | 71 +++++++++------ sup3r/models/base.py | 21 ++++- sup3r/models/interface.py | 8 +- sup3r/preprocessing/samplers/base.py | 10 ++- sup3r/preprocessing/samplers/dual.py | 2 +- sup3r/utilities/loss_metrics.py | 15 +++- tests/training/test_train_with_obs.py | 121 ++++++++++++++++++++++++-- 7 files changed, 203 insertions(+), 45 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index df2c10886c..dec28284d1 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -14,10 +14,10 @@ import pandas as pd import tensorflow as tf from gaps.config import load_config -from phygnn import CustomNetwork from tensorflow.keras import optimizers import sup3r.utilities.loss_metrics +from phygnn import CustomNetwork from sup3r.preprocessing.data_handlers import ExoData from sup3r.preprocessing.utilities import numpy_if_tensor from sup3r.utilities import VERSION_RECORD @@ -462,38 +462,36 @@ def _combine_loss_input(self, hi_res_true, hi_res_gen): def _get_loss_inputs(self, hi_res_gen, hi_res_true, loss_func): """Get inputs for the given loss function according to the required - input features""" + generator output features and ground truth features. 
If the loss + function doesn't specify required features, this will default to using + all output features that are not exogenous features.""" gen_feats = getattr(loss_func, 'gen_features', 'all') true_feats = getattr(loss_func, 'true_features', 'all') - if gen_feats != 'all' and not all( - f in self.hr_features for f in gen_feats - ): + if gen_feats == 'all': + gen_feats = self.hr_out_features + if true_feats == 'all': + true_feats = self.hr_out_features + + if not all(f in self.hr_out_features for f in gen_feats): msg = ( f'{loss_func} requires gen_features: ' - f'{loss_func.gen_features}, but these are not found ' - f'in the high-resolution features: {self.hr_features}' + f'{loss_func.gen_features}, but these are not found in the ' + f'high-resolution output features: {self.hr_out_features}' ) logger.error(msg) raise ValueError(msg) - if true_feats != 'all' and not all( - f in self.hr_features for f in true_feats - ): + if not all(f in self.hr_features for f in true_feats): msg = ( f'{loss_func} requires true_features: ' f'{loss_func.true_features}, but these are not found ' - f'in the high-resolution output features: {self.hr_features}' + f'in the high-resolution features: {self.hr_features}' ) logger.error(msg) raise ValueError(msg) - if gen_feats == 'all': - gen_feats = self.hr_features - if true_feats == 'all': - true_feats = self.hr_features - gen_inds = [self.hr_features.index(f) for f in gen_feats] true_inds = [self.hr_features.index(f) for f in true_feats] @@ -541,6 +539,15 @@ def loss_fun(hi_res_gen, hi_res_true): ) val = loss_func(hr_gen, hr_true) loss_details[camel_to_underscore(ln)] = val + if tf.math.reduce_any(tf.math.is_nan(val)): + msg = ( + f'NaN values found for loss term "{ln}" with value ' + f'{val} when running loss function {loss_func} on ' + f'generated tensor of shape {hi_res_gen.shape} and ' + f'true tensor of shape {hi_res_true.shape}' + ) + logger.error(msg) + raise ValueError(msg) loss += weights[i] * val return loss, loss_details @@ 
-1040,6 +1047,16 @@ def run_exo_layer(self, layer, input_array, exogenous_data, norm_in=True): of shape: (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) + exogenous_data : dict | ExoData + Special dictionary (class:`ExoData`) of exogenous feature data with + entries describing whether features should be combined at input, a + mid network layer, or with output. This doesn't have to include the + 'model' key since this data is for a single step model. + norm_in : bool + Flag to normalize low_res input data if the self._means, + self._stdevs attributes are available. The generator should always + received normalized data with mean=0 stdev=1. This also normalizes + exogenous data. """ feat_stack = [] extras = [] @@ -1055,7 +1072,7 @@ def run_exo_layer(self, layer, input_array, exogenous_data, norm_in=True): ) logger.warning(msg) continue - elif missing_feat: + if missing_feat: msg = ( f'{feat} does not match any features in exogenous_data ' f'({list(exogenous_data)}). This feature is required for ' @@ -1100,7 +1117,7 @@ def generate( Flag to normalize low_res input data if the self._means, self._stdevs attributes are available. The generator should always received normalized data with mean=0 stdev=1. This also normalizes - hi_res_topo. + exogenous data. 
un_norm_out : bool Flag to un-normalize synthetically generated output data to physical units @@ -1207,20 +1224,20 @@ def _tf_generate(self, low_res, hi_res_exo=None): """ hi_res = self.generator.layers[0](low_res) layer_num = 1 - for i, layer in enumerate(self.generator.layers[1:]): - try: + try: + for i, layer in enumerate(self.generator.layers[1:]): layer_num = i + 1 if isinstance(layer, SUP3R_LAYERS): hi_res = self._run_exo_layer(layer, hi_res, hi_res_exo) else: hi_res = layer(hi_res) - except Exception as e: - msg = ( - f'Could not run layer #{layer_num} "{layer}" on tensor ' - f'of shape {hi_res.shape}' - ) - logger.error(msg) - raise RuntimeError(msg) from e + except Exception as e: + msg = ( + f'Could not run layer #{layer_num} "{layer}" on tensor ' + f'of shape {hi_res.shape}' + ) + logger.error(msg) + raise RuntimeError(msg) from e return hi_res diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 03b922ab22..7f9a38a074 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -38,6 +38,7 @@ def __init__( stdevs=None, default_device=None, name=None, + sparse_disc=False, ): """ Parameters @@ -97,6 +98,15 @@ def __init__( (this was tested as most efficient given the custom multi-gpu strategy developed in self.run_gradient_descent()). Examples: "/gpu:0" or "/cpu:0" + sparse_disc : bool + Flag to indicate if the discriminator can handle sparse input data. + If False, the discriminator will expect input data with no missing + values. If True, the discriminator will be able to handle input + data with missing values, which may be the case when using + observations for training. Note that if True, the discriminator + model architecture should be designed to handle sparse data (e.g. + by using masking layers or other techniques). + name : str | None Optional name for the GAN. 
""" @@ -130,6 +140,7 @@ def __init__( self._means = means self._stdevs = stdevs + self._sparse_disc = sparse_disc def save(self, out_dir): """Save the GAN with its sub-networks to a directory. @@ -298,8 +309,12 @@ def _tf_discriminate(self, hi_res): out : np.ndarray Discriminator output logits """ - - out = self.discriminator.layers[0](hi_res) + hr = ( + hi_res + if self._sparse_disc + else tf.gather(hi_res, indices=self.hr_out_features_ind, axis=-1) + ) + out = self.discriminator.layers[0](hr) layer_num = 1 try: for i, layer in enumerate(self.discriminator.layers[1:]): @@ -425,7 +440,7 @@ def init_weights(self, lr_shape, hr_shape, train_disc=False, device=None): low_res = tf.cast(np.ones(lr_shape), dtype=tf.float32) hi_res = tf.cast(np.ones(hr_shape), dtype=tf.float32) - hr_exo_shape = hr_shape[:-1] + (1,) + hr_exo_shape = (*hr_shape[:-1], 1) hr_exo = tf.cast(np.ones(hr_exo_shape), dtype=tf.float32) with tf.device(device): diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index e4b0362d8b..e7f0d7d603 100644 --- a/sup3r/models/interface.py +++ b/sup3r/models/interface.py @@ -9,8 +9,8 @@ from warnings import warn import numpy as np -from phygnn import CustomNetwork +from phygnn import CustomNetwork from sup3r.preprocessing.data_handlers import ExoData from sup3r.utilities import VERSION_RECORD from sup3r.utilities.utilities import safe_cast @@ -347,6 +347,12 @@ def hr_out_features(self): generative model outputs.""" return self.meta.get('hr_out_features', []) + @property + def hr_out_features_ind(self): + """Get the indices of the high-resolution output features in the order + they are output by the model.""" + return [self.hr_features.index(feat) for feat in self.hr_out_features] + @property def obs_features(self): """Get list of exogenous observation feature names the model uses. 
diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index 14e9598b81..f7e69f428a 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -169,9 +169,9 @@ def get_sample_index(self, n_obs=None): self.shape, self.sample_shape[2] * n_obs ) feats = ( - self.hr_source_features - if not self.use_proxy_obs - else self.hr_source_features[: -len(self.obs_features)] + [f for f in self.hr_source_features if f not in self.obs_features] + if self.use_proxy_obs + else self.hr_source_features ) return (*spatial_slice, time_slice, feats) @@ -636,7 +636,9 @@ def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): Notes ----- The output mask is repeated along the feature dimension, so each - feature will have the same observation mask. + feature will have the same observation mask. The output mask is not + repeated along the batch dimension, so each sample in the batch will + have a different observation mask. """ s_range = ( spatial_frac diff --git a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index ccd1dd9d5c..a3c26bcfa2 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -182,7 +182,7 @@ def get_sample_index(self, n_obs=None): for s in lr_index[2:-1] ] hr_feats = ( - self.hr_source_features[: -len(self.obs_features)] + [f for f in self.hr_source_features if f not in self.obs_features] if self.use_proxy_obs else self.hr_source_features ) diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py index e7c950ddbd..cda2188898 100644 --- a/sup3r/utilities/loss_metrics.py +++ b/sup3r/utilities/loss_metrics.py @@ -681,7 +681,7 @@ def __init__(self, n_projections=1024): n_projections : int number of random 1D projections to use - Note + Note: ---- Experimentally, we get stability in the SW metric when n_projections is at least 30% of the number of projection dimensions, which for us @@ -888,7 +888,14 @@ def 
__call__(self, x1, x2): ) assert check, msg - mask = tf.math.is_nan(x2) - return self.LOSS_METRIC( - x1[tf.math.logical_not(mask)], x2[tf.math.logical_not(mask)] + mask = tf.math.logical_not(tf.math.is_nan(x2)) + x1m = tf.boolean_mask(x1, mask) + x2m = tf.boolean_mask(x2, mask) + + physics_loss = tf.constant(1e-3, dtype=x1.dtype) + obs_loss = ( + tf.constant(0, dtype=x1.dtype) + if tf.math.reduce_all(tf.math.is_nan(x2m)) + else self.LOSS_METRIC(x1m, x2m) ) + return physics_loss + obs_loss diff --git a/tests/training/test_train_with_obs.py b/tests/training/test_train_with_obs.py index a425f8edf6..d92e6ac342 100644 --- a/tests/training/test_train_with_obs.py +++ b/tests/training/test_train_with_obs.py @@ -36,7 +36,9 @@ ('gen_config_with_obs_3d', (20, 20, 10), 2, pytest.ST_FP_DISC), ], ) -def test_train_cond_obs(gen_config, sample_shape, t_enhance, fp_disc, request): +def test_train_proxy_obs( + gen_config, sample_shape, t_enhance, fp_disc, request +): """Test a special model which conditions model output on observations with a ``Sup3rConcatObs`` layer.""" @@ -143,7 +145,7 @@ def test_train_cond_obs(gen_config, sample_shape, t_enhance, fp_disc, request): ('gen_config_with_obs_3d', (20, 20, 10), 2, pytest.ST_FP_DISC), ], ) -def test_train_just_obs(gen_config, sample_shape, t_enhance, fp_disc, request): +def test_train_real_obs(gen_config, sample_shape, t_enhance, fp_disc, request): """Test model training with only sparse high resolution ground truth data. 
This should skip any calculations involving the discriminator - since train_disc=False""" @@ -175,8 +177,12 @@ def test_train_just_obs(gen_config, sample_shape, t_enhance, fp_disc, request): obs_data = dual_rasterizer.high_res.copy() for feat in FEATURES_W: tmp = np.full(obs_data[feat].shape, np.nan) - lat_ids = list(range(0, 20, 4)) - lon_ids = list(range(0, 20, 4)) + lat_ids = RANDOM_GENERATOR.choice( + obs_data[feat].shape[0], size=5, replace=False + ) + lon_ids = RANDOM_GENERATOR.choice( + obs_data[feat].shape[1], size=5, replace=False + ) for ilat, ilon in itertools.product(lat_ids, lon_ids): tmp[ilat, ilon, :] = obs_data[feat][ilat, ilon] obs_data[f'{feat}_obs'] = (obs_data[feat].dims, tmp) @@ -239,7 +245,7 @@ def test_train_just_obs(gen_config, sample_shape, t_enhance, fp_disc, request): @pytest.mark.parametrize('lr_only_features', [[], ['temperature_2m']]) -def test_train_obs_with_topo(lr_only_features, request): +def test_train_proxy_obs_with_topo(lr_only_features, request): """Test training with topo and obs. Make sure exo features are properly concatenated.""" @@ -310,3 +316,108 @@ def test_train_obs_with_topo(lr_only_features, request): assert not np.isnan(loss).any() gloss = model.history['train_loss_gen'].values assert not np.isnan(gloss).any() + + +@pytest.mark.parametrize('lr_only_features', [[], ['temperature_2m']]) +def test_train_real_obs_with_topo(lr_only_features, request): + """Test model training with only sparse high resolution ground truth data. 
+ This should skip any calculations involving the discriminator - since + train_disc=False""" + + gen_config = request.getfixturevalue('gen_config_with_obs_3d_topo')() + kwargs = { + 'features': [*FEATURES_W, 'topography'], + 'target': TARGET_W, + 'shape': SHAPE, + } + + hr_handler = DataHandler( + pytest.FP_WTK, + **kwargs, + time_slice=slice(None, None, 1), + ) + + lr_handler = DataHandler( + pytest.FP_ERA, + features=FEATURES_W, + time_slice=slice(None, None, 2), + ) + + # Add dummy lr only features + if lr_only_features: + for feat in lr_only_features: + lr_handler[feat] = lr_handler[FEATURES_W[0]].copy() + + dual_rasterizer = DualRasterizer( + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, + s_enhance=2, + t_enhance=2, + run_qa=False, + ) + obs_data = dual_rasterizer.high_res.copy() + for feat in FEATURES_W: + tmp = np.full(obs_data[feat].shape, np.nan) + lat_ids = RANDOM_GENERATOR.choice( + obs_data[feat].shape[0], size=5, replace=False + ) + lon_ids = RANDOM_GENERATOR.choice( + obs_data[feat].shape[1], size=5, replace=False + ) + for ilat, ilon in itertools.product(lat_ids, lon_ids): + tmp[ilat, ilon, :] = obs_data[feat][ilat, ilon] + obs_data[f'{feat}_obs'] = (obs_data[feat].dims, tmp) + + dual_with_obs = Container( + data={ + 'low_res': dual_rasterizer.low_res, + 'high_res': obs_data, + } + ) + + batch_handler = DualBatchHandlerWithObsTester( + train_containers=[dual_with_obs], + val_containers=[], + sample_shape=(20, 20, 10), + batch_size=2, + s_enhance=2, + t_enhance=2, + n_batches=1, + feature_sets={ + 'lr_features': [*lr_only_features, *FEATURES_W], + 'hr_exo_features': [ + 'topography', + *[f'{feat}_obs' for feat in FEATURES_W], + ], + 'hr_out_features': FEATURES_W, + }, + mode='lazy', + ) + + Sup3rGan.seed() + model = Sup3rGan( + gen_config, + pytest.ST_FP_DISC, + learning_rate=1e-4, + loss={ + 'GeothermalPhysicsLossWithObs': { + 'gen_features': FEATURES_W, + 'true_features': [f'{feat}_obs' for feat in FEATURES_W], + } + }, + ) + + 
with tempfile.TemporaryDirectory() as td: + model_kwargs = { + 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, + 'n_epoch': 5, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': True, + 'checkpoint_int': 1, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + model.train(batch_handler, **model_kwargs) + + tloss = model.history['train_geothermal_physics_loss_with_obs'].values + assert np.sum(np.diff(tloss)) < 0 From 16474dddbf8297143aa8b0b6899c5b2ed8769561 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 21 Mar 2026 15:50:05 -0600 Subject: [PATCH 24/26] fix: previous logic was overwriting layer.features and including layer.exo_features --- sup3r/models/interface.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index e7f0d7d603..9529d18b09 100644 --- a/sup3r/models/interface.py +++ b/sup3r/models/interface.py @@ -204,8 +204,17 @@ def _ensure_feature_consistency(self): if hasattr(self, '_gen'): for layer in self._gen.layers: if isinstance(layer, SUP3R_LAYERS): - feats = getattr(layer, 'features', [layer.name]) - features.extend(feats) + feats = ( + [layer.name] + if not hasattr(layer, 'features') + else layer.features + ) + exo_feats = ( + [] + if not hasattr(layer, 'exo_features') + else layer.exo_features + ) + features.extend(feats + exo_feats) if set(self.hr_exo_features) != set(features): msg = ( f'Model meta hr_exo_features {self.hr_exo_features} does not ' From ac47940e8c1f3914c2cc46af68559d5a51920833 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 22 Mar 2026 09:11:34 -0600 Subject: [PATCH 25/26] fix: lock on gradient - without it multi gpu was failing suddenly. 
--- sup3r/models/abstract.py | 42 ++++++++++++++++--------------- sup3r/models/base.py | 2 +- sup3r/models/interface.py | 2 +- tests/conftest.py | 6 ----- tests/training/test_end_to_end.py | 2 +- 5 files changed, 25 insertions(+), 29 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index dec28284d1..603afaade2 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -8,6 +8,7 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from inspect import signature +from threading import Lock from warnings import warn import numpy as np @@ -51,6 +52,7 @@ def __init__(self): self._stdevs = None self._train_record = pd.DataFrame() self._val_record = pd.DataFrame() + self._lock = Lock() def load_network(self, model, name): """Load a CustomNetwork object from hidden layers config, .json file @@ -105,10 +107,9 @@ def _load_model_from_string(self, model, name): return model['meta'][f'config_{name}']['hidden_layers'] msg = ( - 'Could not load model from json config, need ' - '"hidden_layers" key or ' - f'"meta/config_{name}/hidden_layers" ' - ' at top level but only found: {}'.format(model.keys()) + 'Could not load model from json config, need "hidden_layers" key ' + f'or "meta/config_{name}/hidden_layers" at top level but only ' + f'found: {model.keys()}' ) logger.error(msg) raise KeyError(msg) @@ -539,15 +540,15 @@ def loss_fun(hi_res_gen, hi_res_true): ) val = loss_func(hr_gen, hr_true) loss_details[camel_to_underscore(ln)] = val - if tf.math.reduce_any(tf.math.is_nan(val)): - msg = ( - f'NaN values found for loss term "{ln}" with value ' - f'{val} when running loss function {loss_func} on ' + tf.debugging.assert_all_finite( + val, + message=( + f'NaN or Inf values found for loss term "{ln}" with ' + f'value {val} when running loss function {loss_func} ' f'generated tensor of shape {hi_res_gen.shape} and ' f'true tensor of shape {hi_res_true.shape}' - ) - logger.error(msg) - raise ValueError(msg) + 
), + ) loss += weights[i] * val return loss, loss_details @@ -1297,15 +1298,16 @@ def get_single_grad( loss_details : dict Namespace of the breakdown of loss components """ - with ( - tf.device(device_name), - tf.GradientTape(watch_accessed_variables=False) as tape, - ): - tape.watch(training_weights) - loss, loss_details, _, _ = self._get_hr_exo_and_loss( - low_res, hi_res_true, **calc_loss_kwargs - ) - grad = tape.gradient(loss, training_weights) + with self._lock: + with ( + tf.device(device_name), + tf.GradientTape(watch_accessed_variables=False) as tape, + ): + tape.watch(training_weights) + loss, loss_details, _, _ = self._get_hr_exo_and_loss( + low_res, hi_res_true, **calc_loss_kwargs + ) + grad = tape.gradient(loss, training_weights) return grad, loss_details @abstractmethod diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 7f9a38a074..31d6d64266 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -973,7 +973,7 @@ def _train_batch( Flag whether to train the discriminator for this set of epochs only_disc : bool Flag whether to only train the discriminator for this set of epochs - gen_too_good : bool + disc_too_good : bool Flag whether to skip training the discriminator and only train the generator, due to superior performance, for this batch. weight_gen_advers : float diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index 9529d18b09..f1acdd2356 100644 --- a/sup3r/models/interface.py +++ b/sup3r/models/interface.py @@ -217,7 +217,7 @@ def _ensure_feature_consistency(self): features.extend(feats + exo_feats) if set(self.hr_exo_features) != set(features): msg = ( - f'Model meta hr_exo_features {self.hr_exo_features} does not ' + f'Specified hr_exo_features {self.hr_exo_features} does not ' f'match features {features} found in model layers.' 
) logger.error(msg) diff --git a/tests/conftest.py b/tests/conftest.py index 52193d4716..3f0d1c4a32 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,12 +63,6 @@ def set_random_state(): RANDOM_GENERATOR.bit_generator.state = GLOBAL_STATE -@pytest.fixture(autouse=True) -def train_on_cpu(): - """Train on cpu for tests.""" - os.environ['CUDA_VISIBLE_DEVICES'] = '-1' - - @pytest.fixture(scope='package') def gen_config_with_obs_2d(): """Get generator config with observation layers.""" diff --git a/tests/training/test_end_to_end.py b/tests/training/test_end_to_end.py index 597a2d17b0..a81d647857 100644 --- a/tests/training/test_end_to_end.py +++ b/tests/training/test_end_to_end.py @@ -59,7 +59,7 @@ def test_end_to_end(): train_containers=[train_dh], val_containers=[val_dh], n_batches=2, - batch_size=10, + batch_size=4, sample_shape=(12, 12, 16), s_enhance=3, t_enhance=4, From 514b7c90126c67af6ed0c10406f33831f4adb75f Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 22 Mar 2026 11:28:07 -0600 Subject: [PATCH 26/26] fix: update loss assertions to check for NaN values and adjust training logic for discriminator --- tests/training/test_train_with_obs.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/training/test_train_with_obs.py b/tests/training/test_train_with_obs.py index d92e6ac342..0a637783b3 100644 --- a/tests/training/test_train_with_obs.py +++ b/tests/training/test_train_with_obs.py @@ -241,7 +241,7 @@ def test_train_real_obs(gen_config, sample_shape, t_enhance, fp_disc, request): model.train(batch_handler, **model_kwargs) tloss = model.history['train_geothermal_physics_loss_with_obs'].values - assert np.sum(np.diff(tloss)) < 0 + assert not np.isnan(tloss).any() @pytest.mark.parametrize('lr_only_features', [[], ['temperature_2m']]) @@ -399,6 +399,7 @@ def test_train_real_obs_with_topo(lr_only_features, request): pytest.ST_FP_DISC, learning_rate=1e-4, loss={ + 'MeanAbsoluteError': {}, 'GeothermalPhysicsLossWithObs': 
{ 'gen_features': FEATURES_W, 'true_features': [f'{feat}_obs' for feat in FEATURES_W], @@ -412,12 +413,18 @@ def test_train_real_obs_with_topo(lr_only_features, request): 'n_epoch': 5, 'weight_gen_advers': 0.0, 'train_gen': True, - 'train_disc': True, + 'train_disc': False, 'checkpoint_int': 1, 'out_dir': os.path.join(td, 'test_{epoch}'), } model.train(batch_handler, **model_kwargs) - tloss = model.history['train_geothermal_physics_loss_with_obs'].values - assert np.sum(np.diff(tloss)) < 0 + tloss = model.history['train_loss_gen'].values + assert not np.isnan(tloss).any() + + model_kwargs['train_disc'] = True + model.train(batch_handler, **model_kwargs) + + tloss_disc = model.history['train_loss_disc'].values + assert not np.isnan(tloss_disc).any()